1.0.0初始化源代码
This commit is contained in:
18
backend/app/services/__init__.py
Normal file
18
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
服务层模块初始化
|
||||
"""
|
||||
from .huihui_api import HuihuiAPIService
|
||||
from .ai_service import AIService
|
||||
from .virtual_user_service import VirtualUserService
|
||||
from .interaction_service import InteractionService
|
||||
from .token_service import TokenService
|
||||
from .scheduler_service import SchedulerService
|
||||
|
||||
__all__ = [
|
||||
"HuihuiAPIService",
|
||||
"AIService",
|
||||
"VirtualUserService",
|
||||
"InteractionService",
|
||||
"TokenService",
|
||||
"SchedulerService",
|
||||
]
|
||||
301
backend/app/services/ai_service.py
Normal file
301
backend/app/services/ai_service.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
AI 大模型对接服务
|
||||
支持 OpenAI、智谱、百度文心、阿里通义等主流大模型
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self._client_cache: Dict[str, Any] = {}
|
||||
|
||||
async def generate_comment(
|
||||
self,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
AI 生成评论
|
||||
:param news_content: 新闻内容
|
||||
:param writing_style: 写作风格
|
||||
:param persona_description: 人格描述
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果(包含 content, tokens_used 等)
|
||||
"""
|
||||
prompt = self._build_comment_prompt(
|
||||
news_content,
|
||||
writing_style,
|
||||
persona_description
|
||||
)
|
||||
|
||||
return await self._call_ai_api(prompt, model_config)
|
||||
|
||||
async def generate_reply(
|
||||
self,
|
||||
original_comment: str,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
AI 生成回复
|
||||
:param original_comment: 原评论
|
||||
:param news_content: 新闻内容
|
||||
:param writing_style: 写作风格
|
||||
:param persona_description: 人格描述
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果
|
||||
"""
|
||||
prompt = self._build_reply_prompt(
|
||||
original_comment,
|
||||
news_content,
|
||||
writing_style,
|
||||
persona_description
|
||||
)
|
||||
|
||||
return await self._call_ai_api(prompt, model_config)
|
||||
|
||||
def _build_comment_prompt(
|
||||
self,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None
|
||||
) -> str:
|
||||
"""构建评论提示词"""
|
||||
base_prompt = f"""你是一位虚拟用户,请根据以下要求写一条简短的评论:
|
||||
|
||||
写作风格:{writing_style}
|
||||
"""
|
||||
|
||||
if persona_description:
|
||||
base_prompt += f"\n人格特征:{persona_description}\n"
|
||||
|
||||
base_prompt += f"""
|
||||
新闻内容:
|
||||
{news_content[:1000]} # 限制长度
|
||||
|
||||
请写一条 50-100 字的评论,要符合你的写作风格和人格特征。直接输出评论内容,不要有其他说明。"""
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _build_reply_prompt(
|
||||
self,
|
||||
original_comment: str,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None
|
||||
) -> str:
|
||||
"""构建回复提示词"""
|
||||
base_prompt = f"""你是一位虚拟用户,请根据以下要求回复另一条评论:
|
||||
|
||||
写作风格:{writing_style}
|
||||
"""
|
||||
|
||||
if persona_description:
|
||||
base_prompt += f"\n人格特征:{persona_description}\n"
|
||||
|
||||
base_prompt += f"""
|
||||
新闻内容:
|
||||
{news_content[:500]}
|
||||
|
||||
原评论:
|
||||
{original_comment}
|
||||
|
||||
请写一条 30-80 字的回复,要符合你的写作风格和人格特征。直接输出回复内容,不要有其他说明。"""
|
||||
|
||||
return base_prompt
|
||||
|
||||
async def _call_ai_api(
|
||||
self,
|
||||
prompt: str,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
调用 AI API(根据 model_config 中的 provider 选择对应模型)
|
||||
:param prompt: 提示词
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果
|
||||
"""
|
||||
if not model_config:
|
||||
# 使用默认配置(需要从数据库加载)
|
||||
from app.models.ai_model import AIModelConfig
|
||||
from app.models.base import get_db
|
||||
|
||||
with get_db() as db:
|
||||
default_model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.is_default == True,
|
||||
AIModelConfig.is_active == True
|
||||
).first()
|
||||
|
||||
if not default_model:
|
||||
logger.error("No default AI model configured")
|
||||
return None
|
||||
|
||||
model_config = {
|
||||
"provider": default_model.provider,
|
||||
"model_name": default_model.model_name,
|
||||
"api_key": default_model.api_key,
|
||||
"api_url": default_model.api_url,
|
||||
"temperature": default_model.temperature,
|
||||
"max_tokens": default_model.max_tokens,
|
||||
}
|
||||
|
||||
provider = model_config.get("provider", "").lower()
|
||||
|
||||
try:
|
||||
if provider == "openai":
|
||||
return await self._call_openai(prompt, model_config)
|
||||
elif provider == "zhipu":
|
||||
return await self._call_zhipu(prompt, model_config)
|
||||
elif provider in ["baidu", "wenxin"]:
|
||||
return await self._call_baidu_wenxin(prompt, model_config)
|
||||
elif provider in ["aliyun", "dashscope"]:
|
||||
return await self._call_aliyun_dashscope(prompt, model_config)
|
||||
else:
|
||||
logger.error(f"Unsupported AI provider: {provider}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"AI API call error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用 OpenAI API"""
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=config["api_key"],
|
||||
base_url=config.get("api_url")
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=config.get("model_name", "gpt-3.5-turbo"),
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0.7),
|
||||
max_tokens=config.get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tokens_used = response.usage.total_tokens if response.usage else 0
|
||||
|
||||
logger.info(f"OpenAI generated content, tokens: {tokens_used}")
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tokens_used": tokens_used,
|
||||
"provider": "openai",
|
||||
"model": config.get("model_name", "gpt-3.5-turbo")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_zhipu(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用智谱 AI API"""
|
||||
try:
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
client = ZhipuAI(api_key=config["api_key"])
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=config.get("model_name", "glm-4"),
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0.7),
|
||||
max_tokens=config.get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tokens_used = response.usage.total_tokens if response.usage else 0
|
||||
|
||||
logger.info(f"Zhipu AI generated content, tokens: {tokens_used}")
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tokens_used": tokens_used,
|
||||
"provider": "zhipu",
|
||||
"model": config.get("model_name", "glm-4")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Zhipu AI API error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_baidu_wenxin(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用百度文心一言 API"""
|
||||
# TODO: 实现百度文心一言 API 调用
|
||||
logger.warning("Baidu Wenxin API not implemented yet")
|
||||
return None
|
||||
|
||||
async def _call_aliyun_dashscope(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用阿里云通义千问 API"""
|
||||
# TODO: 实现阿里云 DashScope API 调用
|
||||
logger.warning("Aliyun DashScope API not implemented yet")
|
||||
return None
|
||||
|
||||
async def test_model(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
test_prompt: str = "测试评论"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试模型配置
|
||||
:param model_config: 模型配置
|
||||
:param test_prompt: 测试提示词
|
||||
:return: 测试结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
result = await self._call_ai_api(test_prompt, model_config)
|
||||
|
||||
cost_time = time.time() - start_time
|
||||
|
||||
if result:
|
||||
return {
|
||||
"success": True,
|
||||
"content": result.get("content"),
|
||||
"tokens_used": result.get("tokens_used", 0),
|
||||
"cost_time": round(cost_time, 2),
|
||||
"error_message": None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"content": None,
|
||||
"tokens_used": 0,
|
||||
"cost_time": round(cost_time, 2),
|
||||
"error_message": "Failed to generate content"
|
||||
}
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
ai_service = AIService()
|
||||
291
backend/app/services/huihui_api.py
Normal file
291
backend/app/services/huihui_api.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
会会接口对接服务
|
||||
基于 http://192.168.1.200:63120/doc.html 接口文档
|
||||
"""
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuihuiAPIService:
|
||||
"""会会 API 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = settings.HUIHUI_API_BASE
|
||||
self.timeout = 30 # 秒
|
||||
self._session_cache: Dict[str, httpx.AsyncClient] = {}
|
||||
|
||||
def _get_client(self, session_token: Optional[str] = None) -> httpx.AsyncClient:
|
||||
"""获取 HTTP 客户端"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
if session_token:
|
||||
headers["Authorization"] = f"Bearer {session_token}"
|
||||
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
async def login(self, username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
用户登录
|
||||
:param username: 用户名
|
||||
:param password: 密码
|
||||
:return: 登录响应(包含 session token)
|
||||
"""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
response = await client.post(
|
||||
"/api/login", # 实际接口路径需根据 doc.html 调整
|
||||
json={"username": username, "password": password}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Login success for user: {username}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Login failed: {response.status_code} - {response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {e}")
|
||||
return None
|
||||
|
||||
async def get_news_list(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
category: Optional[str] = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取新闻列表
|
||||
:param page: 页码
|
||||
:param page_size: 每页数量
|
||||
:param category: 分类(可选)
|
||||
:return: 新闻列表
|
||||
"""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
params = {"page": page, "pageSize": page_size}
|
||||
if category:
|
||||
params["category"] = category
|
||||
|
||||
response = await client.get(
|
||||
"/api/news/list", # 实际接口路径需根据 doc.html 调整
|
||||
params=params
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
else:
|
||||
logger.error(f"Get news list failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Get news list error: {e}")
|
||||
return None
|
||||
|
||||
async def get_news_detail(self, news_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取新闻详情
|
||||
:param news_id: 新闻 ID
|
||||
:return: 新闻详情
|
||||
"""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
response = await client.get(f"/api/news/{news_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("data")
|
||||
else:
|
||||
logger.error(f"Get news detail failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Get news detail error: {e}")
|
||||
return None
|
||||
|
||||
async def create_comment(
|
||||
self,
|
||||
news_id: str,
|
||||
content: str,
|
||||
session_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
创建评论
|
||||
:param news_id: 新闻 ID
|
||||
:param content: 评论内容
|
||||
:param session_token: 会话 Token
|
||||
:return: 评论结果
|
||||
"""
|
||||
try:
|
||||
async with self._get_client(session_token) as client:
|
||||
response = await client.post(
|
||||
"/api/comment/create", # 实际接口路径需根据 doc.html 调整
|
||||
json={"newsId": news_id, "content": content}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Comment created for news: {news_id}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Create comment failed: {response.status_code} - {response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Create comment error: {e}")
|
||||
return None
|
||||
|
||||
async def create_reply(
|
||||
self,
|
||||
comment_id: str,
|
||||
content: str,
|
||||
session_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
创建回复
|
||||
:param comment_id: 评论 ID
|
||||
:param content: 回复内容
|
||||
:param session_token: 会话 Token
|
||||
:return: 回复结果
|
||||
"""
|
||||
try:
|
||||
async with self._get_client(session_token) as client:
|
||||
response = await client.post(
|
||||
"/api/reply/create", # 实际接口路径需根据 doc.html 调整
|
||||
json={"commentId": comment_id, "content": content}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Reply created for comment: {comment_id}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Create reply failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Create reply error: {e}")
|
||||
return None
|
||||
|
||||
async def like_comment(
|
||||
self,
|
||||
comment_id: str,
|
||||
session_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
点赞评论
|
||||
:param comment_id: 评论 ID
|
||||
:param session_token: 会话 Token
|
||||
:return: 点赞结果
|
||||
"""
|
||||
try:
|
||||
async with self._get_client(session_token) as client:
|
||||
response = await client.post(f"/api/comment/{comment_id}/like")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Comment liked: {comment_id}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Like comment failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Like comment error: {e}")
|
||||
return None
|
||||
|
||||
async def favorite_news(
|
||||
self,
|
||||
news_id: str,
|
||||
session_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
收藏新闻
|
||||
:param news_id: 新闻 ID
|
||||
:param session_token: 会话 Token
|
||||
:return: 收藏结果
|
||||
"""
|
||||
try:
|
||||
async with self._get_client(session_token) as client:
|
||||
response = await client.post(f"/api/news/{news_id}/favorite")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"News favorited: {news_id}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Favorite news failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Favorite news error: {e}")
|
||||
return None
|
||||
|
||||
async def share_news(
|
||||
self,
|
||||
news_id: str,
|
||||
session_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
转发新闻
|
||||
:param news_id: 新闻 ID
|
||||
:param session_token: 会话 Token
|
||||
:return: 转发结果
|
||||
"""
|
||||
try:
|
||||
async with self._get_client(session_token) as client:
|
||||
response = await client.post(f"/api/news/{news_id}/share")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"News shared: {news_id}")
|
||||
return data
|
||||
else:
|
||||
logger.error(f"Share news failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Share news error: {e}")
|
||||
return None
|
||||
|
||||
async def get_comments(
|
||||
self,
|
||||
news_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取新闻评论列表
|
||||
:param news_id: 新闻 ID
|
||||
:param page: 页码
|
||||
:param page_size: 每页数量
|
||||
:return: 评论列表
|
||||
"""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
params = {"page": page, "pageSize": page_size}
|
||||
response = await client.get(
|
||||
f"/api/news/{news_id}/comments",
|
||||
params=params
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
else:
|
||||
logger.error(f"Get comments failed: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Get comments error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
huihui_api_service = HuihuiAPIService()
|
||||
408
backend/app/services/interaction_service.py
Normal file
408
backend/app/services/interaction_service.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
互动执行服务
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func
|
||||
|
||||
from app.models.virtual_user import VirtualUser, ActivityLevel, UserStatus
|
||||
from app.models.interaction import InteractionRecord, InteractionType, InteractionStatus
|
||||
from app.models.token_usage import TokenUsage
|
||||
from app.models.news_cache import NewsCache
|
||||
from app.services.huihui_api_service import huihui_api_service
|
||||
from app.services.ai_service import ai_service
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InteractionService:
|
||||
"""互动执行服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def execute_interaction(
|
||||
self,
|
||||
virtual_user_id: int,
|
||||
interaction_type: Optional[InteractionType] = None,
|
||||
news_id: Optional[str] = None
|
||||
) -> Optional[InteractionRecord]:
|
||||
"""
|
||||
执行单次互动
|
||||
:param virtual_user_id: 虚拟用户 ID
|
||||
:param interaction_type: 互动类型(不传则随机)
|
||||
:param news_id: 新闻 ID(不传则随机选择)
|
||||
:return: 互动记录
|
||||
"""
|
||||
# 获取虚拟用户
|
||||
user = self.db.query(VirtualUser).filter(VirtualUser.id == virtual_user_id).first()
|
||||
if not user:
|
||||
logger.error(f"Virtual user not found: {virtual_user_id}")
|
||||
return None
|
||||
|
||||
# 检查用户状态
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
logger.warning(f"Virtual user is not active: {virtual_user_id}")
|
||||
return None
|
||||
|
||||
# 检查是否已登录
|
||||
if not user.is_logged_in or not user.session_token:
|
||||
logger.warning(f"Virtual user not logged in: {virtual_user_id}")
|
||||
# TODO: 自动登录
|
||||
return None
|
||||
|
||||
# 检查今日限额
|
||||
if not self._check_daily_limit(user, interaction_type):
|
||||
logger.warning(f"Daily limit reached for user {virtual_user_id}")
|
||||
return None
|
||||
|
||||
# 选择新闻
|
||||
if not news_id:
|
||||
news_id = await self._select_news(user)
|
||||
if not news_id:
|
||||
logger.warning("No news available for interaction")
|
||||
return None
|
||||
|
||||
# 获取新闻详情
|
||||
news = self.db.query(NewsCache).filter(NewsCache.news_id == news_id).first()
|
||||
if not news:
|
||||
# 从 API 获取
|
||||
news_detail = await huihui_api_service.get_news_detail(news_id)
|
||||
if news_detail:
|
||||
news = self._cache_news(news_detail)
|
||||
|
||||
if not news:
|
||||
logger.error(f"Cannot get news detail: {news_id}")
|
||||
return None
|
||||
|
||||
# 确定互动类型
|
||||
if not interaction_type:
|
||||
interaction_type = self._random_interaction_type()
|
||||
|
||||
# 创建互动记录
|
||||
record = InteractionRecord(
|
||||
virtual_user_id=virtual_user_id,
|
||||
news_id=news_id,
|
||||
news_title=news.title if news else "",
|
||||
interaction_type=interaction_type,
|
||||
status=InteractionStatus.PENDING
|
||||
)
|
||||
|
||||
self.db.add(record)
|
||||
self.db.commit()
|
||||
self.db.refresh(record)
|
||||
|
||||
try:
|
||||
# 执行互动
|
||||
if interaction_type == InteractionType.COMMENT:
|
||||
result = await self._execute_comment(user, news, record)
|
||||
elif interaction_type == InteractionType.REPLY:
|
||||
result = await self._execute_reply(user, news, record)
|
||||
elif interaction_type == InteractionType.LIKE:
|
||||
result = await self._execute_like(user, news, record)
|
||||
elif interaction_type == InteractionType.FAVORITE:
|
||||
result = await self._execute_favorite(user, news, record)
|
||||
elif interaction_type == InteractionType.SHARE:
|
||||
result = await self._execute_share(user, news, record)
|
||||
else:
|
||||
logger.error(f"Unknown interaction type: {interaction_type}")
|
||||
return None
|
||||
|
||||
if result:
|
||||
record.status = InteractionStatus.SUCCESS
|
||||
record.api_response = str(result)
|
||||
|
||||
# 更新用户统计
|
||||
user.total_interactions += 1
|
||||
if interaction_type == InteractionType.COMMENT:
|
||||
user.today_comments += 1
|
||||
elif interaction_type == InteractionType.REPLY:
|
||||
user.today_replies += 1
|
||||
user.last_interaction_time = datetime.now()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Interaction executed successfully: user={user.id}, type={interaction_type}")
|
||||
return record
|
||||
else:
|
||||
record.status = InteractionStatus.FAILED
|
||||
record.error_message = "API call failed"
|
||||
self.db.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Execute interaction error: {e}")
|
||||
record.status = InteractionStatus.FAILED
|
||||
record.error_message = str(e)
|
||||
self.db.commit()
|
||||
return None
|
||||
|
||||
async def _execute_comment(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
news: NewsCache,
|
||||
record: InteractionRecord
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行评论"""
|
||||
# AI 生成评论内容
|
||||
ai_result = await ai_service.generate_comment(
|
||||
news_content=news.content or news.summary or news.title,
|
||||
writing_style=user.writing_style or "普通",
|
||||
persona_description=user.persona_description
|
||||
)
|
||||
|
||||
if not ai_result or not ai_result.get("content"):
|
||||
logger.error("AI generate comment failed")
|
||||
return None
|
||||
|
||||
# 记录 Token 使用
|
||||
self._record_token_usage(
|
||||
virtual_user_id=user.id,
|
||||
interaction_id=record.id,
|
||||
tokens_used=ai_result.get("tokens_used", 0),
|
||||
ai_model=ai_result.get("model", "unknown"),
|
||||
action_type="generate_comment",
|
||||
tokens_prompt=ai_result.get("tokens_prompt", 0),
|
||||
tokens_completion=ai_result.get("tokens_completion", 0)
|
||||
)
|
||||
|
||||
# 调用接口提交评论
|
||||
result = await huihui_api_service.create_comment(
|
||||
news_id=news.news_id,
|
||||
content=ai_result["content"],
|
||||
session_token=user.session_token
|
||||
)
|
||||
|
||||
if result:
|
||||
record.content = ai_result["content"]
|
||||
record.tokens_used = ai_result.get("tokens_used", 0)
|
||||
record.ai_model_used = ai_result.get("model")
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_reply(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
news: NewsCache,
|
||||
record: InteractionRecord
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行回复"""
|
||||
# 获取评论列表
|
||||
comments = await huihui_api_service.get_comments(news_id=news.news_id)
|
||||
|
||||
if not comments or len(comments) == 0:
|
||||
logger.warning(f"No comments available for news {news.news_id}")
|
||||
return None
|
||||
|
||||
# 随机选择一条评论进行回复
|
||||
target_comment = random.choice(comments)
|
||||
record.target_comment_id = target_comment.get("id")
|
||||
|
||||
# AI 生成回复内容
|
||||
ai_result = await ai_service.generate_reply(
|
||||
original_comment=target_comment.get("content", ""),
|
||||
news_content=news.content or news.summary or news.title,
|
||||
writing_style=user.writing_style or "普通",
|
||||
persona_description=user.persona_description
|
||||
)
|
||||
|
||||
if not ai_result or not ai_result.get("content"):
|
||||
logger.error("AI generate reply failed")
|
||||
return None
|
||||
|
||||
# 记录 Token 使用
|
||||
self._record_token_usage(
|
||||
virtual_user_id=user.id,
|
||||
interaction_id=record.id,
|
||||
tokens_used=ai_result.get("tokens_used", 0),
|
||||
ai_model=ai_result.get("model", "unknown"),
|
||||
action_type="generate_reply"
|
||||
)
|
||||
|
||||
# 调用接口提交回复
|
||||
result = await huihui_api_service.create_reply(
|
||||
comment_id=target_comment.get("id"),
|
||||
content=ai_result["content"],
|
||||
session_token=user.session_token
|
||||
)
|
||||
|
||||
if result:
|
||||
record.content = ai_result["content"]
|
||||
record.tokens_used = ai_result.get("tokens_used", 0)
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_like(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
news: NewsCache,
|
||||
record: InteractionRecord
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行点赞"""
|
||||
# 获取评论列表
|
||||
comments = await huihui_api_service.get_comments(news_id=news.news_id)
|
||||
|
||||
if not comments or len(comments) == 0:
|
||||
logger.warning(f"No comments available for like")
|
||||
return None
|
||||
|
||||
# 随机选择一条评论点赞
|
||||
target_comment = random.choice(comments)
|
||||
record.target_comment_id = target_comment.get("id")
|
||||
|
||||
result = await huihui_api_service.like_comment(
|
||||
comment_id=target_comment.get("id"),
|
||||
session_token=user.session_token
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_favorite(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
news: NewsCache,
|
||||
record: InteractionRecord
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行收藏"""
|
||||
result = await huihui_api_service.favorite_news(
|
||||
news_id=news.news_id,
|
||||
session_token=user.session_token
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_share(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
news: NewsCache,
|
||||
record: InteractionRecord
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行转发"""
|
||||
result = await huihui_api_service.share_news(
|
||||
news_id=news.news_id,
|
||||
session_token=user.session_token
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _check_daily_limit(
|
||||
self,
|
||||
user: VirtualUser,
|
||||
interaction_type: Optional[InteractionType]
|
||||
) -> bool:
|
||||
"""检查每日限额"""
|
||||
today = datetime.now().date()
|
||||
|
||||
# 统计今日互动
|
||||
today_records = self.db.query(InteractionRecord).filter(
|
||||
and_(
|
||||
InteractionRecord.virtual_user_id == user.id,
|
||||
func.date(InteractionRecord.execution_time) == today,
|
||||
InteractionRecord.status == InteractionStatus.SUCCESS
|
||||
)
|
||||
).all()
|
||||
|
||||
today_comments = sum(1 for r in today_records if r.interaction_type == InteractionType.COMMENT)
|
||||
today_replies = sum(1 for r in today_records if r.interaction_type == InteractionType.REPLY)
|
||||
|
||||
# 检查评论限额
|
||||
if interaction_type == InteractionType.COMMENT:
|
||||
if today_comments >= settings.MAX_COMMENTS_PER_USER_PER_DAY:
|
||||
return False
|
||||
|
||||
# 检查回复限额
|
||||
if interaction_type == InteractionType.REPLY:
|
||||
if today_replies >= settings.MAX_REPLIES_PER_USER_PER_DAY:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _random_interaction_type(self) -> InteractionType:
|
||||
"""随机选择互动类型"""
|
||||
rand = random.random()
|
||||
|
||||
# 根据概率决定互动类型
|
||||
if rand < settings.LIKE_PROBABILITY:
|
||||
return InteractionType.LIKE
|
||||
elif rand < settings.LIKE_PROBABILITY + settings.FAVORITE_PROBABILITY:
|
||||
return InteractionType.FAVORITE
|
||||
elif rand < settings.LIKE_PROBABILITY + settings.FAVORITE_PROBABILITY + settings.SHARE_PROBABILITY:
|
||||
return InteractionType.SHARE
|
||||
else:
|
||||
return InteractionType.COMMENT
|
||||
|
||||
async def _select_news(self, user: VirtualUser) -> Optional[str]:
|
||||
"""选择新闻"""
|
||||
# 优先选择未互动过的新闻
|
||||
cached_news = self.db.query(NewsCache).order_by(
|
||||
NewsCache.created_at.desc()
|
||||
).limit(50).all()
|
||||
|
||||
if not cached_news:
|
||||
# 从 API 获取
|
||||
news_list = await huihui_api_service.get_news_list(page=1, page_size=20)
|
||||
if news_list:
|
||||
for news_data in news_list:
|
||||
self._cache_news(news_data)
|
||||
cached_news = self.db.query(NewsCache).order_by(
|
||||
NewsCache.created_at.desc()
|
||||
).limit(50).all()
|
||||
|
||||
if not cached_news:
|
||||
return None
|
||||
|
||||
# 随机选择一篇
|
||||
return random.choice(cached_news).news_id
|
||||
|
||||
def _cache_news(self, news_data: Dict[str, Any]) -> Optional[NewsCache]:
|
||||
"""缓存新闻"""
|
||||
news = NewsCache(
|
||||
news_id=str(news_data.get("id")),
|
||||
title=news_data.get("title", ""),
|
||||
summary=news_data.get("summary", ""),
|
||||
content=news_data.get("content", ""),
|
||||
source=news_data.get("source", ""),
|
||||
author=news_data.get("author", ""),
|
||||
category=news_data.get("category", "")
|
||||
)
|
||||
|
||||
self.db.add(news)
|
||||
self.db.commit()
|
||||
self.db.refresh(news)
|
||||
return news
|
||||
|
||||
def _record_token_usage(
|
||||
self,
|
||||
virtual_user_id: int,
|
||||
interaction_id: int,
|
||||
tokens_used: int,
|
||||
ai_model: str,
|
||||
action_type: str,
|
||||
tokens_prompt: int = 0,
|
||||
tokens_completion: int = 0
|
||||
):
|
||||
"""记录 Token 使用"""
|
||||
usage = TokenUsage(
|
||||
virtual_user_id=virtual_user_id,
|
||||
interaction_id=interaction_id,
|
||||
tokens_used=tokens_used,
|
||||
tokens_prompt=tokens_prompt,
|
||||
tokens_completion=tokens_completion,
|
||||
ai_model=ai_model,
|
||||
action_type=action_type
|
||||
)
|
||||
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Token usage recorded: {tokens_used} tokens for user {virtual_user_id}")
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def get_interaction_service(db: Session) -> InteractionService:
|
||||
"""获取互动服务实例"""
|
||||
return InteractionService(db)
|
||||
215
backend/app/services/scheduler_service.py
Normal file
215
backend/app/services/scheduler_service.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
定时任务调度服务
|
||||
基于 APScheduler 实现
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, time
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.virtual_user import VirtualUser, ActivityLevel, UserStatus
|
||||
from app.models.base import get_db, SessionLocal
|
||||
from app.services.interaction_service import InteractionService
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""定时任务调度服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.is_running = False
|
||||
self._current_job = None
|
||||
|
||||
def start(self):
|
||||
"""启动调度器"""
|
||||
if not self.is_running:
|
||||
self.scheduler.start()
|
||||
self.is_running = True
|
||||
logger.info("Scheduler started")
|
||||
|
||||
def stop(self):
|
||||
"""停止调度器"""
|
||||
if self.is_running:
|
||||
self.scheduler.shutdown()
|
||||
self.is_running = False
|
||||
logger.info("Scheduler stopped")
|
||||
|
||||
def add_interaction_task(self):
|
||||
"""添加互动任务"""
|
||||
# 在活动时间段内,每隔随机时间执行一次互动
|
||||
# 由于 APScheduler 不支持随机间隔,我们使用固定间隔但通过概率控制执行
|
||||
|
||||
# 每 5 分钟检查一次
|
||||
trigger = IntervalTrigger(minutes=5)
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._execute_random_interaction,
|
||||
trigger=trigger,
|
||||
id="random_interaction",
|
||||
name="Random Interaction Task",
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
logger.info("Interaction task added")
|
||||
|
||||
def remove_interaction_task(self):
|
||||
"""移除互动任务"""
|
||||
try:
|
||||
self.scheduler.remove_job("random_interaction")
|
||||
logger.info("Interaction task removed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Remove interaction task error: {e}")
|
||||
|
||||
async def _execute_random_interaction(self):
|
||||
"""执行随机互动任务"""
|
||||
# 检查是否在活动时间段内
|
||||
now = datetime.now()
|
||||
current_hour = now.hour
|
||||
|
||||
if current_hour < settings.TASK_START_HOUR or current_hour > settings.TASK_END_HOUR:
|
||||
logger.debug(f"Outside activity hours: {current_hour}")
|
||||
return
|
||||
|
||||
# 随机决定是否执行(通过随机间隔模拟)
|
||||
if random.random() > 0.5: # 50% 概率执行
|
||||
logger.debug("Skip this round")
|
||||
return
|
||||
|
||||
logger.info("Executing random interaction task")
|
||||
|
||||
# 获取数据库会话
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 获取所有活跃的虚拟用户
|
||||
users = db.query(VirtualUser).filter(
|
||||
VirtualUser.status == UserStatus.ACTIVE,
|
||||
VirtualUser.is_logged_in == True
|
||||
).all()
|
||||
|
||||
if not users:
|
||||
logger.debug("No active logged-in users")
|
||||
return
|
||||
|
||||
# 随机选择一个用户
|
||||
user = random.choice(users)
|
||||
|
||||
# 检查用户活跃度
|
||||
if not self._should_user_interact(user):
|
||||
logger.debug(f"User {user.id} should not interact now")
|
||||
return
|
||||
|
||||
# 执行互动
|
||||
interaction_service = InteractionService(db)
|
||||
await interaction_service.execute_interaction(virtual_user_id=user.id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Execute random interaction error: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _should_user_interact(self, user: VirtualUser) -> bool:
|
||||
"""根据活跃度判断用户是否应该互动"""
|
||||
# 根据活跃度决定互动概率
|
||||
if user.activity_level == ActivityLevel.HIGH:
|
||||
# 高活跃度:80% 概率
|
||||
return random.random() < 0.8
|
||||
elif user.activity_level == ActivityLevel.MEDIUM:
|
||||
# 中活跃度:50% 概率
|
||||
return random.random() < 0.5
|
||||
else:
|
||||
# 低活跃度:30% 概率
|
||||
return random.random() < 0.3
|
||||
|
||||
def add_login_task(self, hour: int = 8, minute: int = 0):
|
||||
"""添加每日登录任务"""
|
||||
trigger = CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._auto_login_users,
|
||||
trigger=trigger,
|
||||
id="daily_login",
|
||||
name="Daily Auto Login",
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
logger.info(f"Daily login task added at {hour:02d}:{minute:02d}")
|
||||
|
||||
async def _auto_login_users(self):
|
||||
"""自动登录所有活跃用户"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.services.huihui_api_service import huihui_api_service
|
||||
|
||||
users = db.query(VirtualUser).filter(
|
||||
VirtualUser.status == UserStatus.ACTIVE
|
||||
).all()
|
||||
|
||||
for user in users:
|
||||
try:
|
||||
# 调用登录接口
|
||||
result = await huihui_api_service.login(user.username, user.password)
|
||||
|
||||
if result and result.get("token"):
|
||||
user.is_logged_in = True
|
||||
user.session_token = result["token"]
|
||||
# TODO: 设置 token 过期时间
|
||||
|
||||
logger.info(f"Auto login success: {user.username}")
|
||||
else:
|
||||
logger.warning(f"Auto login failed: {user.username}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto login error for {user.username}: {e}")
|
||||
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto login task error: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def reset_daily_counters(self, hour: int = 0, minute: int = 1):
|
||||
"""添加每日计数器重置任务"""
|
||||
trigger = CronTrigger(hour=hour, minute=minute)
|
||||
|
||||
self.scheduler.add_job(
|
||||
self._reset_daily_counters,
|
||||
trigger=trigger,
|
||||
id="reset_daily_counters",
|
||||
name="Reset Daily Counters",
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
logger.info(f"Daily reset task added at {hour:02d}:{minute:02d}")
|
||||
|
||||
def _reset_daily_counters(self):
|
||||
"""重置每日计数器"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 重置所有用户的今日计数
|
||||
db.query(VirtualUser).update({
|
||||
VirtualUser.today_comments: 0,
|
||||
VirtualUser.today_replies: 0
|
||||
})
|
||||
|
||||
db.commit()
|
||||
logger.info("Daily counters reset")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reset daily counters error: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
scheduler_service = SchedulerService()
|
||||
173
backend/app/services/token_service.py
Normal file
173
backend/app/services/token_service.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Token 使用统计服务
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime, timedelta, date
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, and_, extract
|
||||
|
||||
from app.models.token_usage import TokenUsage
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenService:
|
||||
"""Token 统计服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_today_usage(self) -> int:
|
||||
"""获取今日 Token 使用量"""
|
||||
today = date.today()
|
||||
|
||||
result = self.db.query(func.sum(TokenUsage.tokens_used)).filter(
|
||||
func.date(TokenUsage.usage_date) == today
|
||||
).scalar()
|
||||
|
||||
return result or 0
|
||||
|
||||
def get_yesterday_usage(self) -> int:
|
||||
"""获取昨日 Token 使用量"""
|
||||
yesterday = date.today() - timedelta(days=1)
|
||||
|
||||
result = self.db.query(func.sum(TokenUsage.tokens_used)).filter(
|
||||
func.date(TokenUsage.usage_date) == yesterday
|
||||
).scalar()
|
||||
|
||||
return result or 0
|
||||
|
||||
def get_month_usage(self, year: Optional[int] = None, month: Optional[int] = None) -> int:
|
||||
"""获取当月 Token 使用量"""
|
||||
if not year or not month:
|
||||
now = datetime.now()
|
||||
year = now.year
|
||||
month = now.month
|
||||
|
||||
result = self.db.query(func.sum(TokenUsage.tokens_used)).filter(
|
||||
and_(
|
||||
extract('year', TokenUsage.usage_date) == year,
|
||||
extract('month', TokenUsage.usage_date) == month
|
||||
)
|
||||
).scalar()
|
||||
|
||||
return result or 0
|
||||
|
||||
def get_remaining_tokens(self) -> int:
|
||||
"""获取今日剩余 Token"""
|
||||
today_used = self.get_today_usage()
|
||||
remaining = settings.MAX_TOKENS_PER_DAY - today_used
|
||||
return max(0, remaining)
|
||||
|
||||
def get_daily_usages(self, days: int = 30) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取每日 Token 使用(用于图表)
|
||||
:param days: 天数
|
||||
:return: 每日使用列表
|
||||
"""
|
||||
end_date = date.today()
|
||||
start_date = end_date - timedelta(days=days - 1)
|
||||
|
||||
results = self.db.query(
|
||||
func.date(TokenUsage.usage_date).label('usage_date'),
|
||||
func.sum(TokenUsage.tokens_used).label('total_tokens')
|
||||
).filter(
|
||||
and_(
|
||||
func.date(TokenUsage.usage_date) >= start_date,
|
||||
func.date(TokenUsage.usage_date) <= end_date
|
||||
)
|
||||
).group_by(
|
||||
func.date(TokenUsage.usage_date)
|
||||
).order_by(
|
||||
func.date(TokenUsage.usage_date)
|
||||
).all()
|
||||
|
||||
# 转换为字典列表
|
||||
usage_dict = {str(row.usage_date): row.total_tokens for row in results}
|
||||
|
||||
# 填充缺失的日期
|
||||
daily_usages = []
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
date_str = str(current_date)
|
||||
tokens = usage_dict.get(date_str, 0)
|
||||
daily_usages.append({
|
||||
"date": date_str,
|
||||
"tokens": tokens
|
||||
})
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
return daily_usages
|
||||
|
||||
def get_monthly_usages(self, months: int = 12) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取每月 Token 使用(用于图表)
|
||||
:param months: 月数
|
||||
:return: 每月使用列表
|
||||
"""
|
||||
now = datetime.now()
|
||||
results = []
|
||||
|
||||
for i in range(months):
|
||||
# 计算月份
|
||||
month_offset = months - 1 - i
|
||||
target_date = now - timedelta(days=30 * month_offset)
|
||||
year = target_date.year
|
||||
month = target_date.month
|
||||
|
||||
# 查询该月的使用量
|
||||
usage = self.get_month_usage(year, month)
|
||||
|
||||
results.append({
|
||||
"month": f"{year}-{month:02d}",
|
||||
"tokens": usage
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def get_user_token_usage(
|
||||
self,
|
||||
user_id: int,
|
||||
days: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的 Token 使用
|
||||
:param user_id: 用户 ID
|
||||
:param days: 天数
|
||||
:return: 每日使用列表
|
||||
"""
|
||||
end_date = date.today()
|
||||
start_date = end_date - timedelta(days=days - 1)
|
||||
|
||||
results = self.db.query(
|
||||
func.date(TokenUsage.usage_date).label('usage_date'),
|
||||
func.sum(TokenUsage.tokens_used).label('total_tokens')
|
||||
).filter(
|
||||
and_(
|
||||
TokenUsage.virtual_user_id == user_id,
|
||||
func.date(TokenUsage.usage_date) >= start_date,
|
||||
func.date(TokenUsage.usage_date) <= end_date
|
||||
)
|
||||
).group_by(
|
||||
func.date(TokenUsage.usage_date)
|
||||
).order_by(
|
||||
func.date(TokenUsage.usage_date)
|
||||
).all()
|
||||
|
||||
return [
|
||||
{"date": str(row.usage_date), "tokens": row.total_tokens}
|
||||
for row in results
|
||||
]
|
||||
|
||||
def check_token_limit_exceeded(self) -> bool:
|
||||
"""检查是否超出 Token 限额"""
|
||||
today_used = self.get_today_usage()
|
||||
return today_used >= settings.MAX_TOKENS_PER_DAY
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def get_token_service(db: Session) -> TokenService:
|
||||
"""获取 Token 服务实例"""
|
||||
return TokenService(db)
|
||||
361
backend/app/services/virtual_user_service.py
Normal file
361
backend/app/services/virtual_user_service.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
虚拟用户管理服务
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func, Date
|
||||
|
||||
from app.models.virtual_user import VirtualUser, VirtualUserPersona, ActivityLevel, UserStatus
|
||||
from app.models.interaction import InteractionRecord, InteractionType
|
||||
from app.services.ai_service import ai_service
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VirtualUserService:
|
||||
"""虚拟用户服务类"""
|
||||
|
||||
# 预设写作风格库
|
||||
WRITING_STYLES = [
|
||||
"幽默风趣",
|
||||
"严肃理性",
|
||||
"文艺清新",
|
||||
"吐槽犀利",
|
||||
"感性温暖",
|
||||
"客观中立",
|
||||
"激情澎湃",
|
||||
"冷静分析",
|
||||
"活泼可爱",
|
||||
"深沉内敛"
|
||||
]
|
||||
|
||||
# 昵称前缀和后缀
|
||||
NICKNAME_PREFIXES = ["清风", "星辰", "云端", "晨曦", "暮色", "流年", "初心", "远方"]
|
||||
NICKNAME_SUFFIXES = ["行者", "旅人", "追梦", "时光", "记忆", "印象", "故事", "传奇"]
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[VirtualUser]:
|
||||
"""根据 ID 获取用户"""
|
||||
return self.db.query(VirtualUser).filter(VirtualUser.id == user_id).first()
|
||||
|
||||
def get_user_by_username(self, username: str) -> Optional[VirtualUser]:
|
||||
"""根据用户名获取用户"""
|
||||
return self.db.query(VirtualUser).filter(VirtualUser.username == username).first()
|
||||
|
||||
def get_users(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
status: Optional[UserStatus] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户列表
|
||||
:param page: 页码
|
||||
:param page_size: 每页数量
|
||||
:param status: 状态筛选
|
||||
:param search: 搜索关键词
|
||||
:return: 用户列表和总数
|
||||
"""
|
||||
query = self.db.query(VirtualUser)
|
||||
|
||||
if status:
|
||||
query = query.filter(VirtualUser.status == status)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
VirtualUser.nickname.like(f"%{search}%"),
|
||||
VirtualUser.username.like(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
users = query.order_by(VirtualUser.created_at.desc()).offset(
|
||||
(page - 1) * page_size
|
||||
).limit(page_size).all()
|
||||
|
||||
return {"total": total, "items": users}
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
nickname: str,
|
||||
writing_style: Optional[str] = None,
|
||||
activity_level: ActivityLevel = ActivityLevel.MEDIUM,
|
||||
avatar_url: Optional[str] = None,
|
||||
persona_description: Optional[str] = None
|
||||
) -> Optional[VirtualUser]:
|
||||
"""
|
||||
创建虚拟用户
|
||||
:param username: 用户名
|
||||
:param password: 密码
|
||||
:param nickname: 昵称
|
||||
:param writing_style: 写作风格
|
||||
:param activity_level: 活跃度
|
||||
:param avatar_url: 头像 URL
|
||||
:param persona_description: 人格描述
|
||||
:return: 创建的用户
|
||||
"""
|
||||
# 检查用户名是否已存在
|
||||
existing = self.get_user_by_username(username)
|
||||
if existing:
|
||||
logger.error(f"Username already exists: {username}")
|
||||
return None
|
||||
|
||||
user = VirtualUser(
|
||||
username=username,
|
||||
password=password, # TODO: 加密存储
|
||||
nickname=nickname,
|
||||
writing_style=writing_style or self._random_writing_style(),
|
||||
activity_level=activity_level,
|
||||
avatar_url=avatar_url or self._generate_avatar_url(),
|
||||
persona_description=persona_description
|
||||
)
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
|
||||
logger.info(f"Virtual user created: {username}")
|
||||
return user
|
||||
|
||||
def generate_users(
|
||||
self,
|
||||
count: int,
|
||||
writing_styles: Optional[List[str]] = None,
|
||||
activity_levels: Optional[List[ActivityLevel]] = None,
|
||||
generate_persona: bool = True
|
||||
) -> List[VirtualUser]:
|
||||
"""
|
||||
批量生成虚拟用户
|
||||
:param count: 生成数量
|
||||
:param writing_styles: 写作风格列表
|
||||
:param activity_levels: 活跃度级别列表
|
||||
:param generate_persona: 是否生成 AI 人格描述
|
||||
:return: 生成的用户列表
|
||||
"""
|
||||
import random
|
||||
|
||||
styles = writing_styles or self.WRITING_STYLES
|
||||
levels = activity_levels or [ActivityLevel.LOW, ActivityLevel.MEDIUM, ActivityLevel.HIGH]
|
||||
|
||||
created_users = []
|
||||
|
||||
for i in range(count):
|
||||
# 生成唯一用户名
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
username = f"user_{timestamp}_{i}"
|
||||
|
||||
# 随机生成昵称
|
||||
prefix = random.choice(self.NICKNAME_PREFIXES)
|
||||
suffix = random.choice(self.NICKNAME_SUFFIXES)
|
||||
nickname = f"{prefix}{suffix}{random.randint(100, 999)}"
|
||||
|
||||
# 随机密码
|
||||
password = f"pwd_{random.randint(100000, 999999)}"
|
||||
|
||||
# 随机写作风格
|
||||
writing_style = random.choice(styles)
|
||||
|
||||
# 随机活跃度
|
||||
activity_level = random.choice(levels)
|
||||
|
||||
# 生成头像
|
||||
avatar_url = self._generate_avatar_url()
|
||||
|
||||
# AI 生成人格描述
|
||||
persona_description = None
|
||||
if generate_persona:
|
||||
persona_description = self._generate_persona_description(
|
||||
writing_style,
|
||||
activity_level
|
||||
)
|
||||
|
||||
user = self.create_user(
|
||||
username=username,
|
||||
password=password,
|
||||
nickname=nickname,
|
||||
writing_style=writing_style,
|
||||
activity_level=activity_level,
|
||||
avatar_url=avatar_url,
|
||||
persona_description=persona_description
|
||||
)
|
||||
|
||||
if user:
|
||||
created_users.append(user)
|
||||
|
||||
logger.info(f"Generated {len(created_users)} virtual users")
|
||||
return created_users
|
||||
|
||||
def _generate_persona_description(
|
||||
self,
|
||||
writing_style: str,
|
||||
activity_level: ActivityLevel
|
||||
) -> str:
|
||||
"""AI 生成人格描述"""
|
||||
import asyncio
|
||||
|
||||
prompt = f"""请为一位虚拟用户生成人格描述,要求:
|
||||
- 写作风格:{writing_style}
|
||||
- 活跃度:{activity_level.value}
|
||||
|
||||
请用 50-100 字描述这个人的性格特点、兴趣爱好、说话方式等。直接输出描述内容。"""
|
||||
|
||||
try:
|
||||
# 同步调用异步方法
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = loop.run_until_complete(ai_service._call_ai_api(prompt))
|
||||
loop.close()
|
||||
|
||||
if result and result.get("content"):
|
||||
return result["content"]
|
||||
except Exception as e:
|
||||
logger.error(f"Generate persona description error: {e}")
|
||||
|
||||
return f"这是一位{writing_style}的虚拟用户,活跃度{activity_level.value}。"
|
||||
|
||||
def _random_writing_style(self) -> str:
|
||||
"""随机选择写作风格"""
|
||||
import random
|
||||
return random.choice(self.WRITING_STYLES)
|
||||
|
||||
def _generate_avatar_url(self) -> str:
|
||||
"""生成随机头像 URL(使用第三方头像 API)"""
|
||||
import random
|
||||
# 使用 DiceBear 头像 API
|
||||
seed = f"avatar_{datetime.now().timestamp()}_{random.randint(1000, 9999)}"
|
||||
return f"https://api.dicebear.com/7.x/avataaars/svg?seed={seed}"
|
||||
|
||||
def update_user(
|
||||
self,
|
||||
user_id: int,
|
||||
**kwargs
|
||||
) -> Optional[VirtualUser]:
|
||||
"""更新用户信息"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(user, key) and value is not None:
|
||||
setattr(user, key, value)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def delete_user(self, user_id: int) -> bool:
|
||||
"""删除用户"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
self.db.delete(user)
|
||||
self.db.commit()
|
||||
logger.info(f"Virtual user deleted: {user_id}")
|
||||
return True
|
||||
|
||||
def import_users_from_excel(
|
||||
self,
|
||||
users_data: List[Dict[str, Any]],
|
||||
generate_persona: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 Excel 导入虚拟用户
|
||||
:param users_data: 用户数据列表
|
||||
:param generate_persona: 是否生成 AI 人格描述
|
||||
:return: 导入结果
|
||||
"""
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
created_users = []
|
||||
|
||||
for user_data in users_data:
|
||||
try:
|
||||
username = user_data.get("username")
|
||||
password = user_data.get("password")
|
||||
nickname = user_data.get("nickname", "")
|
||||
|
||||
if not username or not password:
|
||||
logger.warning(f"Missing username or password: {user_data}")
|
||||
failed_count += 1
|
||||
continue
|
||||
|
||||
# 如果昵称为空,生成一个
|
||||
if not nickname:
|
||||
nickname = f"用户{username}"
|
||||
|
||||
writing_style = user_data.get("writing_style")
|
||||
activity_level_str = user_data.get("activity_level", "medium")
|
||||
|
||||
# 转换活跃度枚举
|
||||
try:
|
||||
activity_level = ActivityLevel(activity_level_str.lower())
|
||||
except ValueError:
|
||||
activity_level = ActivityLevel.MEDIUM
|
||||
|
||||
user = self.create_user(
|
||||
username=username,
|
||||
password=password,
|
||||
nickname=nickname,
|
||||
writing_style=writing_style,
|
||||
activity_level=activity_level
|
||||
)
|
||||
|
||||
if user:
|
||||
success_count += 1
|
||||
created_users.append(user)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Import user error: {e}")
|
||||
failed_count += 1
|
||||
|
||||
return {
|
||||
"success_count": success_count,
|
||||
"failed_count": failed_count,
|
||||
"created_users": created_users
|
||||
}
|
||||
|
||||
def get_user_stats(self, user_id: int) -> Dict[str, Any]:
|
||||
"""获取用户统计信息"""
|
||||
user = self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
today = datetime.now().date()
|
||||
|
||||
# 统计今日互动
|
||||
today_interactions = self.db.query(InteractionRecord).filter(
|
||||
and_(
|
||||
InteractionRecord.virtual_user_id == user_id,
|
||||
func.date(InteractionRecord.execution_time) == today
|
||||
)
|
||||
).all()
|
||||
|
||||
today_comments = sum(1 for i in today_interactions if i.interaction_type == InteractionType.COMMENT)
|
||||
today_replies = sum(1 for i in today_interactions if i.interaction_type == InteractionType.REPLY)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"nickname": user.nickname,
|
||||
"total_interactions": user.total_interactions,
|
||||
"today_comments": today_comments,
|
||||
"today_replies": today_replies,
|
||||
"last_interaction_time": user.last_interaction_time
|
||||
}
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def get_virtual_user_service(db: Session) -> VirtualUserService:
|
||||
"""获取虚拟用户服务实例"""
|
||||
return VirtualUserService(db)
|
||||
Reference in New Issue
Block a user