feat: AI虚拟用户新闻互动系统 v1.3.0 初始提交
- 虚拟用户管理(昵称/头像/性别/简介/邮箱同步到目标平台) - AI互动调度(点赞/收藏/评论/转发) - 日志时间改为北京时间 - 评论达上限后继续执行点赞收藏转发 - 一键登出全部功能 - 浅色主题UI
This commit is contained in:
21
backend/Dockerfile
Normal file
21
backend/Dockerfile
Normal file
@@ -0,0 +1,21 @@
|
||||
FROM python:3.10-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
default-libmysqlclient-dev \
|
||||
pkg-config \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p /app/logs /app/config
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
12
backend/app/api/__init__.py
Normal file
12
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""API路由汇总"""
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import users, interactions, ai_models, dashboard, system, logs
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(users.router, prefix="/users", tags=["虚拟用户管理"])
|
||||
router.include_router(interactions.router, prefix="/interactions", tags=["互动记录"])
|
||||
router.include_router(ai_models.router, prefix="/ai-models", tags=["AI模型配置"])
|
||||
router.include_router(dashboard.router, prefix="/dashboard", tags=["数据看板"])
|
||||
router.include_router(system.router, prefix="/system", tags=["系统设置"])
|
||||
router.include_router(logs.router, prefix="/logs", tags=["日志管理"])
|
||||
0
backend/app/api/endpoints/__init__.py
Normal file
0
backend/app/api/endpoints/__init__.py
Normal file
87
backend/app/api/endpoints/ai_models.py
Normal file
87
backend/app/api/endpoints/ai_models.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""AI模型配置接口"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse, AIModelCreateRequest, AIModelUpdateRequest, AIModelTestRequest
|
||||
from app.models import AIModelConfig
|
||||
from app.utils.crypto import encrypt, decrypt
|
||||
from app.services.ai_service import ai_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_models(db=Depends(get_db)):
|
||||
result = await db.execute(select(AIModelConfig).order_by(AIModelConfig.created_at.desc()))
|
||||
models = result.scalars().all()
|
||||
items = [_format_model(m) for m in models]
|
||||
return ApiResponse(data=items)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_model(req: AIModelCreateRequest, db=Depends(get_db)):
|
||||
if req.is_default:
|
||||
await db.execute(update(AIModelConfig).values(is_default=0))
|
||||
model = AIModelConfig(
|
||||
model_name=req.model_name,
|
||||
provider=req.provider,
|
||||
api_base_url=req.api_base_url,
|
||||
api_key_enc=encrypt(req.api_key) if req.api_key else None,
|
||||
model_version=req.model_version,
|
||||
temperature=req.temperature,
|
||||
max_tokens=req.max_tokens,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_default=req.is_default,
|
||||
is_enabled=1,
|
||||
)
|
||||
db.add(model)
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
return ApiResponse(data=_format_model(model), message="模型添加成功")
|
||||
|
||||
|
||||
@router.put("/{model_id}")
|
||||
async def update_model(model_id: int, req: AIModelUpdateRequest, db=Depends(get_db)):
|
||||
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="模型不存在")
|
||||
if req.is_default:
|
||||
await db.execute(update(AIModelConfig).where(AIModelConfig.id != model_id).values(is_default=0))
|
||||
for field, val in req.model_dump(exclude_none=True).items():
|
||||
if field == "api_key":
|
||||
model.api_key_enc = encrypt(val) if val else None
|
||||
else:
|
||||
setattr(model, field, val)
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
return ApiResponse(data=_format_model(model), message="更新成功")
|
||||
|
||||
|
||||
@router.delete("/{model_id}")
|
||||
async def delete_model(model_id: int, db=Depends(get_db)):
|
||||
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="模型不存在")
|
||||
await db.delete(model)
|
||||
await db.commit()
|
||||
return ApiResponse(message="删除成功")
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_model(req: AIModelTestRequest, db=Depends(get_db)):
|
||||
result = await ai_service.test_model(db, req.model_id, req.test_prompt)
|
||||
return ApiResponse(data=result)
|
||||
|
||||
|
||||
def _format_model(m: AIModelConfig) -> dict:
|
||||
return {
|
||||
"id": m.id, "model_name": m.model_name, "provider": m.provider,
|
||||
"api_base_url": m.api_base_url, "has_api_key": bool(m.api_key_enc),
|
||||
"model_version": m.model_version, "temperature": m.temperature,
|
||||
"max_tokens": m.max_tokens, "timeout_seconds": m.timeout_seconds,
|
||||
"is_default": m.is_default, "is_enabled": m.is_enabled,
|
||||
"created_at": m.created_at.isoformat(),
|
||||
}
|
||||
25
backend/app/api/endpoints/dashboard.py
Normal file
25
backend/app/api/endpoints/dashboard.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""数据看板接口"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse
|
||||
from app.services.stats_service import stats_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_dashboard(db=Depends(get_db)):
|
||||
data = await stats_service.get_dashboard(db)
|
||||
return ApiResponse(data=data)
|
||||
|
||||
|
||||
@router.get("/token-trend")
|
||||
async def get_token_trend(days: int = Query(default=30, ge=7, le=90), db=Depends(get_db)):
|
||||
trend = await stats_service.get_token_trend(db, days)
|
||||
return ApiResponse(data=trend)
|
||||
|
||||
|
||||
@router.get("/monthly-token-trend")
|
||||
async def get_monthly_token_trend(db=Depends(get_db)):
|
||||
trend = await stats_service.get_monthly_token_trend(db)
|
||||
return ApiResponse(data=trend)
|
||||
170
backend/app/api/endpoints/interactions.py
Normal file
170
backend/app/api/endpoints/interactions.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""互动记录接口"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import io, pandas as pd
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse
|
||||
from app.services.stats_service import stats_service
|
||||
from app.models import InteractionRecord
|
||||
from sqlalchemy import select, update
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_interactions(
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=20, ge=1, le=100),
|
||||
user_id: Optional[int] = None,
|
||||
interact_type: Optional[str] = None,
|
||||
status: Optional[int] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
keyword: Optional[str] = None,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
result = await stats_service.get_interaction_records(
|
||||
db, page, page_size, user_id, interact_type, status, start_date, end_date, keyword
|
||||
)
|
||||
return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/{record_id}/retry")
|
||||
async def retry_interaction(record_id: int, db=Depends(get_db)):
|
||||
"""手动重试失败任务"""
|
||||
result = await db.execute(select(InteractionRecord).where(InteractionRecord.id == record_id))
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="记录不存在")
|
||||
if record.status != 2:
|
||||
raise HTTPException(status_code=400, detail="只能重试失败的任务")
|
||||
if record.retry_count >= 3:
|
||||
raise HTTPException(status_code=400, detail="已超过最大重试次数(3次)")
|
||||
|
||||
from app.services.news_service import news_service
|
||||
from app.services.ai_service import ai_service
|
||||
from app.models import VirtualUser, UserPersonality
|
||||
|
||||
user_result = await db.execute(select(VirtualUser).where(VirtualUser.id == record.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user or user.status != 2:
|
||||
raise HTTPException(status_code=400, detail="用户未登录,无法重试")
|
||||
|
||||
success, err = False, "未知类型"
|
||||
if record.interact_type == "comment" and record.content:
|
||||
success, err = await news_service.post_comment(db, user, record.article_id, record.article_title or "", record.content)
|
||||
elif record.interact_type == "like":
|
||||
success, err = await news_service.like_news(db, user, record.article_id, org_id="", title=record.article_title or "")
|
||||
elif record.interact_type == "collect":
|
||||
success, err = await news_service.collect_news(db, user, record.article_id, title=record.article_title or "")
|
||||
elif record.interact_type == "forward":
|
||||
success, err = await news_service.forward_news(db, user, record.article_id)
|
||||
|
||||
await db.execute(
|
||||
update(InteractionRecord).where(InteractionRecord.id == record_id).values(
|
||||
status=1 if success else 2,
|
||||
error_msg=None if success else err,
|
||||
retry_count=record.retry_count + 1,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return ApiResponse(message="重试成功" if success else f"重试失败: {err}")
|
||||
|
||||
|
||||
@router.get("/export")
|
||||
async def export_interactions(
|
||||
user_id: Optional[int] = None,
|
||||
interact_type: Optional[str] = None,
|
||||
status: Optional[int] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""导出互动记录"""
|
||||
data = await stats_service.get_interaction_records(
|
||||
db, 1, 10000, user_id, interact_type, status, start_date, end_date
|
||||
)
|
||||
rows = [{
|
||||
"ID": r["id"], "用户昵称": r["user_nickname"], "用户账号": r["user_account"],
|
||||
"文章标题": r["article_title"], "互动类型": r["interact_type_label"],
|
||||
"内容": r["content"] or "", "Token消耗": r["token_consumed"],
|
||||
"状态": r["status_label"], "失败原因": r["error_msg"] or "",
|
||||
"重试次数": r["retry_count"], "执行时间": r["executed_at"],
|
||||
} for r in data["items"]]
|
||||
df = pd.DataFrame(rows)
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False, sheet_name="互动记录")
|
||||
buf.seek(0)
|
||||
return StreamingResponse(
|
||||
buf,
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
headers={"Content-Disposition": "attachment; filename=interactions_export.xlsx"}
|
||||
)
|
||||
|
||||
@router.post("/{record_id}/cancel")
|
||||
async def cancel_interaction(record_id: int, db=Depends(get_db)):
|
||||
"""取消互动(取消点赞/收藏/删除评论),转发不支持取消"""
|
||||
from sqlalchemy import select, update
|
||||
from app.models import InteractionRecord, VirtualUser
|
||||
from app.services.news_service import news_service
|
||||
|
||||
# 查找互动记录
|
||||
r = await db.execute(select(InteractionRecord).where(InteractionRecord.id == record_id))
|
||||
record = r.scalar_one_or_none()
|
||||
if not record:
|
||||
return ApiResponse(code=404, message="记录不存在")
|
||||
|
||||
if record.status != 1:
|
||||
return ApiResponse(code=400, message="只能取消成功的互动")
|
||||
|
||||
if record.interact_type == "forward":
|
||||
return ApiResponse(code=400, message="转发互动无法取消")
|
||||
|
||||
if record.interact_type == "read":
|
||||
return ApiResponse(code=400, message="阅读记录无法取消")
|
||||
|
||||
# 查找对应用户
|
||||
ur = await db.execute(select(VirtualUser).where(VirtualUser.id == record.user_id))
|
||||
user = ur.scalar_one_or_none()
|
||||
if not user:
|
||||
return ApiResponse(code=404, message="用户不存在")
|
||||
|
||||
# 执行取消
|
||||
ok = False
|
||||
err = ""
|
||||
if record.interact_type in ("like",):
|
||||
ok, err = await news_service.cancel_like(
|
||||
db, user,
|
||||
news_id=record.article_id or "",
|
||||
org_id=record.session_id or "", # session_id 字段暂存 org_id
|
||||
title=record.article_title or "",
|
||||
)
|
||||
elif record.interact_type == "collect":
|
||||
ok, err = await news_service.cancel_collect(
|
||||
db, user,
|
||||
news_id=record.article_id or "",
|
||||
title=record.article_title or "",
|
||||
)
|
||||
elif record.interact_type in ("comment", "reply"):
|
||||
comment_id = record.platform_record_id or ""
|
||||
if not comment_id:
|
||||
return ApiResponse(code=400, message="缺少评论ID,无法删除")
|
||||
ok, err = await news_service.cancel_comment(
|
||||
db, user,
|
||||
news_id=record.article_id or "",
|
||||
comment_id=comment_id,
|
||||
)
|
||||
|
||||
if ok:
|
||||
# 更新状态为手动取消(status=3)
|
||||
await db.execute(
|
||||
update(InteractionRecord).where(InteractionRecord.id == record_id).values(
|
||||
status=3, error_msg="手动取消"
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return ApiResponse(message="取消成功")
|
||||
else:
|
||||
return ApiResponse(code=500, message=f"取消失败: {err}")
|
||||
83
backend/app/api/endpoints/logs.py
Normal file
83
backend/app/api/endpoints/logs.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""日志管理接口"""
|
||||
import os
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse
|
||||
from app.models import LoginLog
|
||||
from app.core.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/login")
|
||||
async def get_login_logs(
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[int] = None,
|
||||
action: Optional[str] = None,
|
||||
db=Depends(get_db)
|
||||
):
|
||||
query = select(LoginLog)
|
||||
conditions = []
|
||||
if user_id:
|
||||
conditions.append(LoginLog.user_id == user_id)
|
||||
if action:
|
||||
conditions.append(LoginLog.action == action)
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
total = (await db.execute(select(func.count()).select_from(query.subquery()))).scalar()
|
||||
query = query.order_by(LoginLog.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
|
||||
result = await db.execute(query)
|
||||
logs = result.scalars().all()
|
||||
|
||||
items = [{
|
||||
"id": l.id, "user_id": l.user_id, "user_account": l.user_account,
|
||||
"action": l.action, "session_id": l.session_id,
|
||||
"error_msg": l.error_msg, "created_at": l.created_at.isoformat()
|
||||
} for l in logs]
|
||||
return ApiResponse(data={"total": total, "page": page, "page_size": page_size, "items": items})
|
||||
|
||||
|
||||
@router.get("/files")
|
||||
async def list_log_files():
|
||||
"""列出日志文件"""
|
||||
log_dir = settings.LOG_DIR
|
||||
files = []
|
||||
if os.path.exists(log_dir):
|
||||
for fname in sorted(os.listdir(log_dir), reverse=True):
|
||||
if fname.endswith(".log"):
|
||||
fpath = os.path.join(log_dir, fname)
|
||||
size = os.path.getsize(fpath)
|
||||
files.append({"name": fname, "size": size,
|
||||
"size_kb": round(size / 1024, 1)})
|
||||
return ApiResponse(data=files)
|
||||
|
||||
|
||||
@router.get("/files/{filename}/tail")
|
||||
async def tail_log_file(filename: str, lines: int = Query(default=100, ge=10, le=1000)):
|
||||
"""读取日志文件末尾"""
|
||||
# 安全校验
|
||||
if ".." in filename or "/" in filename:
|
||||
raise HTTPException(status_code=400, detail="非法文件名")
|
||||
fpath = os.path.join(settings.LOG_DIR, filename)
|
||||
if not os.path.exists(fpath):
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
with open(fpath, "r", encoding="utf-8", errors="replace") as f:
|
||||
all_lines = f.readlines()
|
||||
tail = all_lines[-lines:]
|
||||
return ApiResponse(data={"filename": filename, "lines": tail, "total_lines": len(all_lines)})
|
||||
|
||||
|
||||
@router.get("/files/{filename}/download")
|
||||
async def download_log_file(filename: str):
|
||||
"""下载日志文件"""
|
||||
if ".." in filename or "/" in filename:
|
||||
raise HTTPException(status_code=400, detail="非法文件名")
|
||||
fpath = os.path.join(settings.LOG_DIR, filename)
|
||||
if not os.path.exists(fpath):
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
return FileResponse(fpath, filename=filename, media_type="text/plain")
|
||||
115
backend/app/api/endpoints/system.py
Normal file
115
backend/app/api/endpoints/system.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""系统设置接口"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select, update as sql_update
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse
|
||||
from app.models import SystemConfig
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
async def get_configs(db=Depends(get_db)):
|
||||
result = await db.execute(select(SystemConfig).order_by(SystemConfig.config_key))
|
||||
configs = result.scalars().all()
|
||||
data = {c.config_key: {"value": c.config_value, "type": c.config_type, "desc": c.description}
|
||||
for c in configs}
|
||||
return ApiResponse(data=data)
|
||||
|
||||
|
||||
@router.put("/configs")
|
||||
async def update_configs(body: dict, db=Depends(get_db)):
|
||||
"""批量更新配置"""
|
||||
for key, value in body.items():
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
|
||||
cfg = result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.config_value = str(value)
|
||||
else:
|
||||
db.add(SystemConfig(config_key=key, config_value=str(value)))
|
||||
await db.commit()
|
||||
return ApiResponse(message="配置已保存")
|
||||
|
||||
|
||||
@router.post("/scheduler/toggle")
|
||||
async def toggle_scheduler(body: dict, db=Depends(get_db)):
|
||||
enabled = body.get("enabled", True)
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == "scheduler_enabled"))
|
||||
cfg = result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.config_value = "true" if enabled else "false"
|
||||
await db.commit()
|
||||
return ApiResponse(message=f"调度器已{'启用' if enabled else '暂停'}")
|
||||
|
||||
|
||||
@router.post("/sessions/reset-all")
|
||||
async def reset_all_sessions(db=Depends(get_db)):
|
||||
"""重置所有用户会话"""
|
||||
from app.models import VirtualUser
|
||||
await db.execute(
|
||||
sql_update(VirtualUser).values(status=0, session_token=None, session_expires_at=None)
|
||||
)
|
||||
await db.commit()
|
||||
return ApiResponse(message="所有会话已重置")
|
||||
|
||||
|
||||
@router.post("/login/diagnose")
|
||||
async def diagnose_login(body: dict, db=Depends(get_db)):
|
||||
"""
|
||||
诊断登录接口原始响应 - 临时调试用
|
||||
传入: {"username": "xxx", "password": "xxx"}
|
||||
"""
|
||||
import httpx, hashlib, uuid
|
||||
from datetime import datetime
|
||||
from app.services.news_service import news_service
|
||||
|
||||
cfg = await news_service._client(db)
|
||||
auth = await news_service._auth_url(db)
|
||||
|
||||
username = body.get("username", "")
|
||||
password = body.get("password", "")
|
||||
|
||||
# 构建 formData(与真实登录完全一致)
|
||||
extra = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"loginType": "password",
|
||||
"grantType": "password",
|
||||
"isRegister": "false",
|
||||
}
|
||||
if cfg.get("clientCode"):
|
||||
extra["clientCode"] = cfg["clientCode"]
|
||||
|
||||
form = news_service._build_form(extra, cfg)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as c:
|
||||
resp = await c.post(f"{auth}/open/login/token", data=form)
|
||||
|
||||
# 返回完整诊断信息
|
||||
try:
|
||||
resp_json = resp.json()
|
||||
except Exception:
|
||||
resp_json = None
|
||||
|
||||
return ApiResponse(data={
|
||||
"status_code": resp.status_code,
|
||||
"response_text": resp.text[:2000],
|
||||
"response_json": resp_json,
|
||||
"request_url": f"{auth}/open/login/token",
|
||||
"request_form": {k: v if k not in ("password","accessSecret") else "***" for k, v in form.items()},
|
||||
"content_type": resp.headers.get("content-type", ""),
|
||||
})
|
||||
except Exception as e:
|
||||
return ApiResponse(code=500, message=str(e), data={"error": str(e)})
|
||||
|
||||
|
||||
@router.post("/interaction/run-now")
|
||||
async def run_interaction_now(db=Depends(get_db)):
|
||||
"""立即触发一次互动任务(不受时间段限制)"""
|
||||
from app.services.scheduler import scheduler_service
|
||||
try:
|
||||
result = await scheduler_service.run_once_now(db)
|
||||
return ApiResponse(data=result, message="互动任务已触发")
|
||||
except Exception as e:
|
||||
return ApiResponse(code=500, message=f"触发失败: {e}")
|
||||
370
backend/app/api/endpoints/users.py
Normal file
370
backend/app/api/endpoints/users.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""虚拟用户管理接口"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile, File, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import io
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.schemas import ApiResponse, UserCreateRequest, UserUpdateRequest, UserBatchRequest, PersonalityUpdateRequest
|
||||
from app.services.user_service import user_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
page: int = Query(default=1, ge=1),
|
||||
page_size: int = Query(default=20, ge=1, le=100),
|
||||
keyword: Optional[str] = Query(default=None),
|
||||
status: Optional[int] = Query(default=None),
|
||||
is_enabled: Optional[int] = Query(default=None),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""获取虚拟用户列表"""
|
||||
total, items = await user_service.get_users(db, page, page_size, keyword, status, is_enabled)
|
||||
return ApiResponse(data={"total": total, "page": page, "page_size": page_size, "items": items})
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_user(req: UserCreateRequest, db=Depends(get_db)):
|
||||
"""创建虚拟用户"""
|
||||
user = await user_service.create_user(db, req)
|
||||
return ApiResponse(data=user, message="用户创建成功")
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(user_id: int, db=Depends(get_db)):
|
||||
"""获取单个用户详情"""
|
||||
total, items = await user_service.get_users(db, 1, 1)
|
||||
from sqlalchemy import select
|
||||
from app.models import VirtualUser, UserPersonality
|
||||
from app.services.user_service import user_service as svc
|
||||
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user_id))
|
||||
personality = p_result.scalar_one_or_none()
|
||||
return ApiResponse(data=svc._format_user(user, personality))
|
||||
|
||||
|
||||
@router.put("/{user_id}")
|
||||
async def update_user(user_id: int, req: UserUpdateRequest, db=Depends(get_db)):
|
||||
"""更新用户信息(sync_to_platform=true 时同步到目标平台)"""
|
||||
result = await user_service.update_user(db, user_id, req)
|
||||
|
||||
if req.sync_to_platform:
|
||||
from sqlalchemy import select
|
||||
from app.models import VirtualUser as _VU
|
||||
from app.services.news_service import news_service
|
||||
ur = await db.execute(select(_VU).where(_VU.id == user_id))
|
||||
user = ur.scalar_one_or_none()
|
||||
if user and user.status == 2:
|
||||
ok, err = await news_service.update_user_profile(
|
||||
db, user,
|
||||
nick_name=req.nickname,
|
||||
real_name=req.real_name,
|
||||
sex=req.sex,
|
||||
description=req.description,
|
||||
email=req.email,
|
||||
)
|
||||
if not ok:
|
||||
return ApiResponse(data=result,
|
||||
message=f"本地已保存,同步到平台失败: {err}", code=206)
|
||||
|
||||
return ApiResponse(data=result, message="更新成功")
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(user_id: int, db=Depends(get_db)):
|
||||
"""删除用户"""
|
||||
await user_service.delete_user(db, user_id)
|
||||
return ApiResponse(message="删除成功")
|
||||
|
||||
|
||||
@router.post("/batch/action")
|
||||
async def batch_action(req: UserBatchRequest, db=Depends(get_db)):
|
||||
"""批量操作用户"""
|
||||
result = await user_service.batch_action(db, req.user_ids, req.action)
|
||||
return ApiResponse(data=result, message="批量操作成功")
|
||||
|
||||
|
||||
@router.post("/{user_id}/login")
|
||||
async def manual_login(user_id: int, db=Depends(get_db)):
|
||||
"""手动触发用户登录"""
|
||||
from app.services.news_service import news_service
|
||||
from sqlalchemy import select
|
||||
from app.models import VirtualUser
|
||||
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
success = await news_service.login(db, user)
|
||||
if success:
|
||||
return ApiResponse(message="登录成功")
|
||||
raise HTTPException(status_code=400, detail="登录失败,请检查账号密码")
|
||||
|
||||
|
||||
@router.post("/{user_id}/logout")
|
||||
async def manual_logout(user_id: int, db=Depends(get_db)):
|
||||
"""手动登出"""
|
||||
from app.services.news_service import news_service
|
||||
await news_service.logout(db, user_id)
|
||||
return ApiResponse(message="已登出")
|
||||
|
||||
|
||||
@router.post("/{user_id}/personality/generate")
|
||||
async def generate_personality(user_id: int, db=Depends(get_db)):
|
||||
"""重新生成AI人格"""
|
||||
personality = await user_service.generate_personality(db, user_id)
|
||||
return ApiResponse(data=personality, message="人格生成成功")
|
||||
|
||||
|
||||
@router.put("/{user_id}/personality")
|
||||
async def update_personality(user_id: int, req: PersonalityUpdateRequest, db=Depends(get_db)):
|
||||
"""手动编辑人格属性"""
|
||||
personality = await user_service.update_personality(db, user_id, req)
|
||||
return ApiResponse(data=personality, message="人格更新成功")
|
||||
|
||||
|
||||
@router.get("/excel/template")
|
||||
async def download_template():
|
||||
"""下载Excel导入模板"""
|
||||
content = await user_service.get_excel_template()
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
headers={"Content-Disposition": "attachment; filename=virtual_users_template.xlsx"}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/excel/import")
|
||||
async def import_excel(file: UploadFile = File(...), db=Depends(get_db)):
|
||||
"""Excel批量导入"""
|
||||
if not file.filename.endswith((".xlsx", ".xls")):
|
||||
raise HTTPException(status_code=400, detail="仅支持Excel文件(.xlsx/.xls)")
|
||||
content = await file.read()
|
||||
result = await user_service.import_from_excel(db, content)
|
||||
return ApiResponse(data=result, message=f"导入完成:成功{result['success']}条,失败{result['failed']}条")
|
||||
|
||||
|
||||
@router.get("/excel/export")
|
||||
async def export_excel(db=Depends(get_db)):
|
||||
"""导出用户数据Excel"""
|
||||
content = await user_service.export_to_excel(db)
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
headers={"Content-Disposition": "attachment; filename=virtual_users_export.xlsx"}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/deduplicate")
|
||||
async def deduplicate_users(db=Depends(get_db)):
|
||||
"""删除重复用户(保留最早创建的一条)"""
|
||||
from sqlalchemy import text
|
||||
# 找出重复的账号,保留 id 最小的,删除其他的
|
||||
result = await db.execute(
|
||||
text("""
|
||||
DELETE FROM virtual_users
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id) FROM virtual_users GROUP BY account
|
||||
)
|
||||
""")
|
||||
)
|
||||
await db.commit()
|
||||
deleted = result.rowcount
|
||||
return ApiResponse(data={"deleted": deleted}, message=f"已清理 {deleted} 条重复数据")
|
||||
|
||||
|
||||
@router.post("/clear-all")
|
||||
async def clear_all_users(db=Depends(get_db)):
|
||||
"""清空所有用户(慎用)"""
|
||||
from sqlalchemy import text
|
||||
from app.core.redis_client import get_redis
|
||||
await db.execute(text("DELETE FROM user_personalities"))
|
||||
await db.execute(text("DELETE FROM virtual_users"))
|
||||
await db.commit()
|
||||
return ApiResponse(message="已清空所有用户数据")
|
||||
|
||||
|
||||
@router.post("/login-all")
|
||||
async def batch_login_all(db=Depends(get_db)):
|
||||
"""一键登录所有未登录/登录失效的用户"""
|
||||
from sqlalchemy import select
|
||||
from app.services.news_service import news_service
|
||||
from app.models import VirtualUser as _VU
|
||||
from app.core.database import AsyncSessionLocal
|
||||
import asyncio
|
||||
|
||||
# 先用当前 session 查出所有待登录用户 ID
|
||||
result_r = await db.execute(
|
||||
select(_VU.id, _VU.account).where(
|
||||
_VU.is_enabled == 1,
|
||||
_VU.status.in_([0, 3])
|
||||
)
|
||||
)
|
||||
rows = result_r.all()
|
||||
if not rows:
|
||||
return ApiResponse(message="没有需要登录的用户", data={"count": 0})
|
||||
|
||||
user_ids = [r[0] for r in rows]
|
||||
total = len(user_ids)
|
||||
success = failed = 0
|
||||
|
||||
# 每个用户独立 session,避免事务污染
|
||||
async def login_one(uid: int):
|
||||
async with AsyncSessionLocal() as s:
|
||||
try:
|
||||
ur = await s.execute(select(_VU).where(_VU.id == uid))
|
||||
u = ur.scalar_one_or_none()
|
||||
if u:
|
||||
return await news_service.login(s, u)
|
||||
except Exception as e:
|
||||
logger.warning(f"login_one {uid} 异常: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
batch_size = 5
|
||||
for i in range(0, total, batch_size):
|
||||
batch_ids = user_ids[i:i+batch_size]
|
||||
results = await asyncio.gather(*[login_one(uid) for uid in batch_ids], return_exceptions=True)
|
||||
for r in results:
|
||||
if r is True: success += 1
|
||||
else: failed += 1
|
||||
if i + batch_size < total:
|
||||
await asyncio.sleep(1) # 批次间隔避免过于集中
|
||||
|
||||
return ApiResponse(
|
||||
message=f"登录完成:成功 {success} 个,失败 {failed} 个",
|
||||
data={"success": success, "failed": failed, "total": total}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync-all-profiles")
|
||||
async def sync_all_profiles(db=Depends(get_db)):
|
||||
"""
|
||||
同步所有已登录用户的平台信息(昵称/真实姓名/性别/头像)到本系统
|
||||
从登录 session 中的 token 调用目标平台接口获取最新用户信息
|
||||
"""
|
||||
from sqlalchemy import select, update
|
||||
from app.models import VirtualUser as _VU
|
||||
from app.core.database import AsyncSessionLocal
|
||||
from app.core.redis_client import get_session
|
||||
import httpx, asyncio
|
||||
|
||||
# 查出所有已登录用户
|
||||
result_r = await db.execute(select(_VU).where(_VU.status == 2, _VU.is_enabled == 1))
|
||||
users = result_r.scalars().all()
|
||||
if not users:
|
||||
return ApiResponse(message="没有已登录的用户", data={"synced": 0})
|
||||
|
||||
synced = failed = 0
|
||||
|
||||
async def sync_one(uid: int):
|
||||
"""从登录 session 中提取已缓存的用户信息,直接写入数据库,无需调用外部接口"""
|
||||
async with AsyncSessionLocal() as s:
|
||||
try:
|
||||
sess = await get_session(uid)
|
||||
if not sess:
|
||||
return False
|
||||
platform_uid = sess.get("platform_uid", "")
|
||||
# 登录成功时 session 里已存有用户信息
|
||||
vals = {}
|
||||
if platform_uid: vals["platform_uid"] = platform_uid
|
||||
# session 里的字段(登录时写入)
|
||||
if sess.get("nickname"): vals["nickname"] = sess["nickname"]
|
||||
if sess.get("real_name"): vals["real_name"] = sess["real_name"]
|
||||
if sess.get("sex"): vals["sex"] = int(sess["sex"])
|
||||
if sess.get("avatar"): vals["avatar_url"] = sess["avatar"]
|
||||
if vals:
|
||||
await s.execute(update(_VU).where(_VU.id == uid).values(**vals))
|
||||
await s.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"sync_one {uid} 失败: {e}")
|
||||
return False
|
||||
|
||||
results = await asyncio.gather(*[sync_one(u.id) for u in users], return_exceptions=True)
|
||||
for r in results:
|
||||
if r is True: synced += 1
|
||||
else: failed += 1
|
||||
|
||||
return ApiResponse(
|
||||
message=f"同步完成:成功 {synced} 个,失败/跳过 {failed} 个",
|
||||
data={"synced": synced, "failed": failed, "total": len(users)}
|
||||
)
|
||||
|
||||
@router.post("/{user_id}/upload-avatar")
|
||||
async def upload_avatar(
|
||||
user_id: int,
|
||||
file: UploadFile = File(...),
|
||||
sync_to_platform: bool = Query(default=True),
|
||||
db=Depends(get_db)
|
||||
):
|
||||
"""上传头像并可选同步到目标平台"""
|
||||
from sqlalchemy import select, update
|
||||
from app.models import VirtualUser as _VU
|
||||
from app.services.news_service import news_service
|
||||
|
||||
ur = await db.execute(select(_VU).where(_VU.id == user_id))
|
||||
user = ur.scalar_one_or_none()
|
||||
if not user:
|
||||
return ApiResponse(code=404, message="用户不存在")
|
||||
|
||||
# 读取文件内容
|
||||
file_bytes = await file.read()
|
||||
if len(file_bytes) > 5 * 1024 * 1024:
|
||||
return ApiResponse(code=400, message="头像文件不能超过5MB")
|
||||
|
||||
avatar_url = None
|
||||
|
||||
if sync_to_platform and user.status == 2:
|
||||
# 上传到目标平台
|
||||
ok, result = await news_service.upload_avatar(db, user, file_bytes, file.filename)
|
||||
if ok:
|
||||
avatar_url = result
|
||||
else:
|
||||
return ApiResponse(code=500, message=f"头像上传到平台失败: {result}")
|
||||
else:
|
||||
# 仅本地存储(转 base64 或存储到本地)
|
||||
import base64
|
||||
avatar_url = f"data:{file.content_type};base64,{base64.b64encode(file_bytes).decode()}"
|
||||
|
||||
# 更新数据库
|
||||
await db.execute(update(_VU).where(_VU.id == user_id).values(avatar_url=avatar_url))
|
||||
await db.commit()
|
||||
|
||||
# 如果已同步到平台,再调用 update_user_profile 更新头像字段
|
||||
if sync_to_platform and user.status == 2 and avatar_url:
|
||||
await news_service.update_user_profile(db, user, avatar=avatar_url)
|
||||
|
||||
return ApiResponse(data={"avatar_url": avatar_url}, message="头像更新成功")
|
||||
@router.post("/logout-all")
|
||||
async def batch_logout_all(db=Depends(get_db)):
|
||||
"""一键登出所有已登录用户"""
|
||||
from sqlalchemy import select, update
|
||||
from app.models import VirtualUser as _VU
|
||||
from app.core.redis_client import delete_session
|
||||
|
||||
result_r = await db.execute(
|
||||
select(_VU.id).where(_VU.status == 2, _VU.is_enabled == 1)
|
||||
)
|
||||
rows = result_r.all()
|
||||
if not rows:
|
||||
return ApiResponse(message="没有已登录的用户", data={"count": 0})
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
await delete_session(row[0])
|
||||
count += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 更新所有用户状态为未登录
|
||||
await db.execute(
|
||||
update(_VU).where(_VU.status == 2).values(status=0)
|
||||
)
|
||||
await db.commit()
|
||||
return ApiResponse(message=f"已登出 {count} 个用户", data={"count": count})
|
||||
1
backend/app/core/__init__.py
Normal file
1
backend/app/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# app.core package
|
||||
46
backend/app/core/config.py
Normal file
46
backend/app/core/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""系统配置"""
|
||||
import os
|
||||
from urllib.parse import quote_plus
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 数据库
|
||||
DB_HOST: str = os.getenv("DB_HOST", "localhost")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "3306"))
|
||||
DB_USER: str = os.getenv("DB_USER", "aivirtual")
|
||||
DB_PASSWORD: str = os.getenv("DB_PASSWORD", "AiVirtual2024")
|
||||
DB_NAME: str = os.getenv("DB_NAME", "ai_virtual_news")
|
||||
|
||||
# Redis
|
||||
REDIS_HOST: str = os.getenv("REDIS_HOST", "localhost")
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
|
||||
# 安全
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "dev-secret-key-change-in-prod")
|
||||
AES_KEY: str = os.getenv("AES_KEY", "your-aes-key-32-chars-change-now!")
|
||||
|
||||
# 新闻平台
|
||||
NEWS_PLATFORM_BASE_URL: str = os.getenv(
|
||||
"NEWS_PLATFORM_BASE_URL", "http://192.168.1.200:63120"
|
||||
)
|
||||
|
||||
# 日志目录
|
||||
LOG_DIR: str = "/app/logs"
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
# 对密码做 URL 编码,防止 @ # ! 等特殊字符破坏连接字符串
|
||||
pwd = quote_plus(self.DB_PASSWORD)
|
||||
return f"mysql+aiomysql://{self.DB_USER}:{pwd}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}?charset=utf8mb4"
|
||||
|
||||
@property
|
||||
def SYNC_DATABASE_URL(self) -> str:
|
||||
pwd = quote_plus(self.DB_PASSWORD)
|
||||
return f"mysql+pymysql://{self.DB_USER}:{pwd}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}?charset=utf8mb4"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
72
backend/app/core/database.py
Normal file
72
backend/app/core/database.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""数据库连接管理"""
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.core.config import settings
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""获取数据库会话"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def wait_for_db(max_retries: int = 30, interval: int = 2):
|
||||
"""等待 MySQL 就绪,最多重试 max_retries 次"""
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(__import__("sqlalchemy").text("SELECT 1"))
|
||||
logger.info(f"✅ 数据库连接成功(第 {attempt} 次尝试)")
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt == max_retries:
|
||||
logger.error(f"数据库连接失败,已重试 {max_retries} 次: {e}")
|
||||
raise
|
||||
logger.warning(f"数据库未就绪,{interval}s 后重试({attempt}/{max_retries}): {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""初始化数据库 - 等待 MySQL 就绪并注册所有模型"""
|
||||
try:
|
||||
# 等待 MySQL 容器真正就绪
|
||||
await wait_for_db(max_retries=30, interval=2)
|
||||
|
||||
# 导入所有模型类,确保 SQLAlchemy ORM 元数据注册
|
||||
from app.models import (
|
||||
VirtualUser, UserPersonality, InteractionRecord,
|
||||
TokenStat, AIModelConfig, SystemConfig, LoginLog
|
||||
)
|
||||
logger.info("✅ 数据库模型注册成功")
|
||||
logger.info("✅ 数据库初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化失败: {e}")
|
||||
raise
|
||||
48
backend/app/core/logger.py
Normal file
48
backend/app/core/logger.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""日志配置"""
|
||||
import sys
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
LOG_DIR = os.getenv("LOG_DIR", "./logs")
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
# 移除默认处理器
|
||||
logger.remove()
|
||||
|
||||
# 控制台输出
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level="INFO",
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||
)
|
||||
|
||||
# 通用日志文件
|
||||
logger.add(
|
||||
f"{LOG_DIR}/app_{{time:YYYY-MM-DD}}.log",
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
level="INFO",
|
||||
encoding="utf-8",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
|
||||
)
|
||||
|
||||
# 错误日志文件
|
||||
logger.add(
|
||||
f"{LOG_DIR}/error_{{time:YYYY-MM-DD}}.log",
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
level="ERROR",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# AI调用日志
|
||||
logger.add(
|
||||
f"{LOG_DIR}/ai_{{time:YYYY-MM-DD}}.log",
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
level="INFO",
|
||||
encoding="utf-8",
|
||||
filter=lambda record: "ai_call" in record["extra"],
|
||||
)
|
||||
|
||||
__all__ = ["logger"]
|
||||
81
backend/app/core/redis_client.py
Normal file
81
backend/app/core/redis_client.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Redis缓存客户端"""
|
||||
import json
|
||||
import redis.asyncio as aioredis
|
||||
from app.core.config import settings
|
||||
from app.core.logger import logger
|
||||
|
||||
_redis_client = None
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = aioredis.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
return _redis_client
|
||||
|
||||
|
||||
# Session键前缀
|
||||
SESSION_PREFIX = "session:"
|
||||
LOCK_PREFIX = "lock:"
|
||||
RATE_PREFIX = "rate:"
|
||||
|
||||
|
||||
async def set_session(user_id: int, session_data: dict, expire: int = 86400):
|
||||
"""存储用户会话"""
|
||||
r = await get_redis()
|
||||
key = f"{SESSION_PREFIX}{user_id}"
|
||||
await r.setex(key, expire, json.dumps(session_data, ensure_ascii=False))
|
||||
|
||||
|
||||
async def get_session(user_id: int) -> dict | None:
|
||||
"""获取用户会话"""
|
||||
r = await get_redis()
|
||||
key = f"{SESSION_PREFIX}{user_id}"
|
||||
data = await r.get(key)
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
|
||||
|
||||
async def delete_session(user_id: int):
|
||||
"""删除用户会话"""
|
||||
r = await get_redis()
|
||||
key = f"{SESSION_PREFIX}{user_id}"
|
||||
await r.delete(key)
|
||||
|
||||
|
||||
async def acquire_lock(name: str, expire: int = 60) -> bool:
|
||||
"""获取分布式锁"""
|
||||
r = await get_redis()
|
||||
key = f"{LOCK_PREFIX}{name}"
|
||||
result = await r.set(key, "1", nx=True, ex=expire)
|
||||
return result is True
|
||||
|
||||
|
||||
async def release_lock(name: str):
|
||||
"""释放分布式锁"""
|
||||
r = await get_redis()
|
||||
key = f"{LOCK_PREFIX}{name}"
|
||||
await r.delete(key)
|
||||
|
||||
|
||||
async def incr_rate(key: str, expire: int = 86400) -> int:
|
||||
"""限流计数"""
|
||||
r = await get_redis()
|
||||
rate_key = f"{RATE_PREFIX}{key}"
|
||||
count = await r.incr(rate_key)
|
||||
if count == 1:
|
||||
await r.expire(rate_key, expire)
|
||||
return count
|
||||
|
||||
|
||||
async def get_counter(key: str) -> int:
|
||||
"""获取计数"""
|
||||
r = await get_redis()
|
||||
rate_key = f"{RATE_PREFIX}{key}"
|
||||
val = await r.get(rate_key)
|
||||
return int(val) if val else 0
|
||||
65
backend/app/main.py
Normal file
65
backend/app/main.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
AI虚拟用户新闻互动系统 - 后端主入口
|
||||
"""
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.core.logger import logger
|
||||
from app.api import router
|
||||
from app.services.scheduler import scheduler_service
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("🚀 AI虚拟用户新闻互动系统启动中...")
|
||||
# 初始化数据库
|
||||
await init_db()
|
||||
# 启动调度器
|
||||
await scheduler_service.start()
|
||||
logger.info("✅ 系统启动完成")
|
||||
yield
|
||||
# 关闭调度器
|
||||
await scheduler_service.stop()
|
||||
logger.info("🛑 系统已关闭")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="AI虚拟用户新闻互动系统",
|
||||
description="基于AI驱动的虚拟用户新闻互动自动化平台",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
docs_url="/api/docs",
|
||||
redoc_url="/api/redoc",
|
||||
)
|
||||
|
||||
# CORS配置
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok", "service": "ai-virtual-news-backend"}
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request, exc):
|
||||
logger.error(f"全局异常: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"code": 500, "message": f"服务器内部错误: {str(exc)}"},
|
||||
)
|
||||
132
backend/app/models/__init__.py
Normal file
132
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""SQLAlchemy ORM 模型"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import (
|
||||
BigInteger, Integer, SmallInteger, String, Text, DateTime,
|
||||
Boolean, Float, Date, JSON, func
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class VirtualUser(Base):
|
||||
__tablename__ = "virtual_users"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
nickname: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
account: Mapped[str] = mapped_column(String(128), nullable=False, unique=True)
|
||||
password_enc: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(512))
|
||||
status: Mapped[int] = mapped_column(SmallInteger, default=0)
|
||||
activity_level: Mapped[int] = mapped_column(SmallInteger, default=1)
|
||||
daily_comment_limit: Mapped[int] = mapped_column(Integer, default=10)
|
||||
daily_like_limit: Mapped[int] = mapped_column(Integer, default=30)
|
||||
today_comment_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
today_like_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_interactions: Mapped[int] = mapped_column(Integer, default=0)
|
||||
session_token: Mapped[str | None] = mapped_column(Text)
|
||||
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
last_interact_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
real_name: Mapped[str | None] = mapped_column(String(64)) # 真实姓名(从平台同步)
|
||||
sex: Mapped[int] = mapped_column(SmallInteger, default=0) # 性别 0未知 1男 2女
|
||||
platform_uid: Mapped[str | None] = mapped_column(String(64)) # 平台用户ID
|
||||
remark: Mapped[str | None] = mapped_column(String(256))
|
||||
is_enabled: Mapped[int] = mapped_column(SmallInteger, default=1)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
|
||||
class UserPersonality(Base):
|
||||
__tablename__ = "user_personalities"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, unique=True)
|
||||
character_type: Mapped[str | None] = mapped_column(String(32))
|
||||
language_style: Mapped[str | None] = mapped_column(String(32))
|
||||
interest_tags: Mapped[dict | None] = mapped_column(JSON)
|
||||
interact_tendency: Mapped[str | None] = mapped_column(String(32))
|
||||
word_count_min: Mapped[int] = mapped_column(Integer, default=20)
|
||||
word_count_max: Mapped[int] = mapped_column(Integer, default=100)
|
||||
personality_desc: Mapped[str | None] = mapped_column(Text)
|
||||
comment_style_prompt: Mapped[str | None] = mapped_column(Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
|
||||
class InteractionRecord(Base):
|
||||
__tablename__ = "interaction_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
|
||||
user_nickname: Mapped[str | None] = mapped_column(String(64))
|
||||
user_account: Mapped[str | None] = mapped_column(String(128))
|
||||
article_id: Mapped[str | None] = mapped_column(String(64))
|
||||
article_title: Mapped[str | None] = mapped_column(String(256))
|
||||
interact_type: Mapped[str] = mapped_column(String(16), nullable=False, index=True)
|
||||
content: Mapped[str | None] = mapped_column(Text)
|
||||
platform_record_id: Mapped[str | None] = mapped_column(String(64)) # 平台返回的记录ID(用于取消互动)
|
||||
parent_comment_id: Mapped[str | None] = mapped_column(String(64))
|
||||
session_id: Mapped[str | None] = mapped_column(String(128))
|
||||
token_consumed: Mapped[int] = mapped_column(Integer, default=0)
|
||||
status: Mapped[int] = mapped_column(SmallInteger, default=0)
|
||||
error_msg: Mapped[str | None] = mapped_column(String(512))
|
||||
retry_count: Mapped[int] = mapped_column(SmallInteger, default=0)
|
||||
executed_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
|
||||
class TokenStat(Base):
|
||||
__tablename__ = "token_stats"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
stat_date: Mapped[datetime] = mapped_column(Date, nullable=False, unique=True)
|
||||
model_name: Mapped[str | None] = mapped_column(String(64))
|
||||
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
call_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
|
||||
class AIModelConfig(Base):
|
||||
__tablename__ = "ai_model_configs"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
model_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
api_base_url: Mapped[str | None] = mapped_column(String(256))
|
||||
api_key_enc: Mapped[str | None] = mapped_column(String(512))
|
||||
model_version: Mapped[str | None] = mapped_column(String(64))
|
||||
temperature: Mapped[float] = mapped_column(Float, default=0.7)
|
||||
max_tokens: Mapped[int] = mapped_column(Integer, default=1000)
|
||||
timeout_seconds: Mapped[int] = mapped_column(Integer, default=30)
|
||||
is_default: Mapped[int] = mapped_column(SmallInteger, default=0)
|
||||
is_enabled: Mapped[int] = mapped_column(SmallInteger, default=1)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
__tablename__ = "system_configs"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
config_key: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
config_value: Mapped[str | None] = mapped_column(Text)
|
||||
config_type: Mapped[str] = mapped_column(String(16), default="string")
|
||||
description: Mapped[str | None] = mapped_column(String(256))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
|
||||
|
||||
|
||||
class LoginLog(Base):
|
||||
__tablename__ = "login_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
|
||||
user_account: Mapped[str | None] = mapped_column(String(128))
|
||||
action: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
session_id: Mapped[str | None] = mapped_column(String(128))
|
||||
ip_address: Mapped[str | None] = mapped_column(String(64))
|
||||
error_msg: Mapped[str | None] = mapped_column(String(512))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), index=True)
|
||||
19
backend/app/models/all_models.py
Normal file
19
backend/app/models/all_models.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# models package - re-export all models
|
||||
from app.models import (
|
||||
VirtualUser, UserPersonality, InteractionRecord,
|
||||
TokenStat, AIModelConfig, SystemConfig, LoginLog
|
||||
)
|
||||
|
||||
# Aliases for import compatibility
|
||||
virtual_user = VirtualUser
|
||||
personality = UserPersonality
|
||||
interaction = InteractionRecord
|
||||
token_stat = TokenStat
|
||||
ai_model = AIModelConfig
|
||||
system_config = SystemConfig
|
||||
login_log = LoginLog
|
||||
|
||||
__all__ = [
|
||||
"VirtualUser", "UserPersonality", "InteractionRecord",
|
||||
"TokenStat", "AIModelConfig", "SystemConfig", "LoginLog",
|
||||
]
|
||||
220
backend/app/schemas/__init__.py
Executable file
220
backend/app/schemas/__init__.py
Executable file
@@ -0,0 +1,220 @@
|
||||
"""Pydantic数据模型 - 请求/响应模式"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ===== 通用响应 =====
|
||||
class ApiResponse(BaseModel):
|
||||
code: int = 200
|
||||
message: str = "success"
|
||||
data: Any = None
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
items: List[Any]
|
||||
|
||||
|
||||
# ===== 虚拟用户 =====
|
||||
class UserCreateRequest(BaseModel):
|
||||
# 必填
|
||||
account: str = Field(..., min_length=1, max_length=128, description="新闻平台账号(必填)")
|
||||
password: str = Field(..., min_length=6, max_length=64, description="登录密码(必填)")
|
||||
# 选填
|
||||
nickname: Optional[str] = Field(None, max_length=64, description="昵称(选填,为空自动生成)")
|
||||
avatar_url: Optional[str] = None
|
||||
activity_level: int = Field(default=1, ge=0, le=2)
|
||||
daily_comment_limit: int = Field(default=10, ge=1, le=100)
|
||||
daily_like_limit: int = Field(default=30, ge=1, le=200)
|
||||
remark: Optional[str] = None
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
nickname: Optional[str] = Field(None, min_length=1, max_length=64)
|
||||
password: Optional[str] = Field(None, min_length=6, max_length=64)
|
||||
avatar_url: Optional[str] = None
|
||||
real_name: Optional[str] = None
|
||||
sex: Optional[int] = None
|
||||
description: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
activity_level: Optional[int] = Field(None, ge=0, le=2)
|
||||
daily_comment_limit: Optional[int] = Field(None, ge=1, le=100)
|
||||
daily_like_limit: Optional[int] = Field(None, ge=1, le=200)
|
||||
remark: Optional[str] = None
|
||||
is_enabled: Optional[int] = None
|
||||
sync_to_platform: bool = False
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
nickname: str
|
||||
account: str
|
||||
avatar_url: Optional[str]
|
||||
real_name: Optional[str] = None
|
||||
sex: int = 0
|
||||
platform_uid: Optional[str] = None
|
||||
status: int
|
||||
status_label: str
|
||||
activity_level: int
|
||||
activity_label: str
|
||||
daily_comment_limit: int
|
||||
daily_like_limit: int
|
||||
today_comment_count: int
|
||||
today_like_count: int
|
||||
total_interactions: int
|
||||
last_login_at: Optional[datetime]
|
||||
last_interact_at: Optional[datetime]
|
||||
remark: Optional[str]
|
||||
is_enabled: int
|
||||
created_at: datetime
|
||||
personality: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserBatchRequest(BaseModel):
|
||||
user_ids: List[int]
|
||||
action: str # enable/disable/logout/delete
|
||||
|
||||
|
||||
# ===== 人格 =====
|
||||
class PersonalityUpdateRequest(BaseModel):
|
||||
character_type: Optional[str] = None
|
||||
language_style: Optional[str] = None
|
||||
interest_tags: Optional[List[str]] = None
|
||||
interact_tendency: Optional[str] = None
|
||||
word_count_min: Optional[int] = Field(None, ge=10, le=500)
|
||||
word_count_max: Optional[int] = Field(None, ge=10, le=1000)
|
||||
personality_desc: Optional[str] = None
|
||||
|
||||
|
||||
class PersonalityResponse(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
character_type: Optional[str]
|
||||
language_style: Optional[str]
|
||||
interest_tags: Optional[List[str]]
|
||||
interact_tendency: Optional[str]
|
||||
word_count_min: int
|
||||
word_count_max: int
|
||||
personality_desc: Optional[str]
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ===== 互动记录 =====
|
||||
class InteractionQueryParams(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=20, ge=1, le=100)
|
||||
user_id: Optional[int] = None
|
||||
interact_type: Optional[str] = None
|
||||
status: Optional[int] = None
|
||||
start_date: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
keyword: Optional[str] = None
|
||||
|
||||
|
||||
class InteractionResponse(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
user_nickname: Optional[str]
|
||||
user_account: Optional[str]
|
||||
article_id: Optional[str]
|
||||
article_title: Optional[str]
|
||||
interact_type: str
|
||||
interact_type_label: str
|
||||
content: Optional[str]
|
||||
token_consumed: int
|
||||
status: int
|
||||
status_label: str
|
||||
error_msg: Optional[str]
|
||||
retry_count: int
|
||||
executed_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ===== AI模型配置 =====
|
||||
class AIModelCreateRequest(BaseModel):
|
||||
model_name: str = Field(..., min_length=1, max_length=64)
|
||||
provider: str = Field(..., pattern="^(openai|zhipu|wenxin|qianwen|local)$")
|
||||
api_base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_version: Optional[str] = None
|
||||
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
|
||||
max_tokens: int = Field(default=1000, ge=1, le=32000)
|
||||
timeout_seconds: int = Field(default=30, ge=5, le=300)
|
||||
is_default: int = Field(default=0, ge=0, le=1)
|
||||
|
||||
|
||||
class AIModelUpdateRequest(BaseModel):
|
||||
model_name: Optional[str] = None
|
||||
api_base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_version: Optional[str] = None
|
||||
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(None, ge=1, le=32000)
|
||||
timeout_seconds: Optional[int] = Field(None, ge=5, le=300)
|
||||
is_default: Optional[int] = None
|
||||
is_enabled: Optional[int] = None
|
||||
|
||||
|
||||
class AIModelResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
api_base_url: Optional[str]
|
||||
has_api_key: bool
|
||||
model_version: Optional[str]
|
||||
temperature: float
|
||||
max_tokens: int
|
||||
timeout_seconds: int
|
||||
is_default: int
|
||||
is_enabled: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AIModelTestRequest(BaseModel):
|
||||
model_id: int
|
||||
test_prompt: str = "你好,请简单介绍一下自己。"
|
||||
|
||||
|
||||
# ===== 系统配置 =====
|
||||
class SystemConfigUpdateRequest(BaseModel):
|
||||
configs: dict
|
||||
|
||||
|
||||
# ===== 数据统计 =====
|
||||
class DashboardResponse(BaseModel):
|
||||
user_stats: dict
|
||||
today_interactions: dict
|
||||
monthly_stats: dict
|
||||
token_stats: dict
|
||||
system_status: dict
|
||||
online_users: int
|
||||
|
||||
|
||||
# ===== 调度配置 =====
|
||||
class SchedulerConfigRequest(BaseModel):
|
||||
interact_time_start: Optional[str] = None
|
||||
interact_time_end: Optional[str] = None
|
||||
interact_interval_min: Optional[int] = None
|
||||
interact_interval_max: Optional[int] = None
|
||||
max_concurrent_users: Optional[int] = None
|
||||
daily_token_limit: Optional[int] = None
|
||||
comment_probability: Optional[float] = None
|
||||
reply_probability: Optional[float] = None
|
||||
like_probability: Optional[float] = None
|
||||
collect_probability: Optional[float] = None
|
||||
forward_probability: Optional[float] = None
|
||||
scheduler_enabled: Optional[bool] = None
|
||||
215
backend/app/schemas/__init__.pybuckup
Normal file
215
backend/app/schemas/__init__.pybuckup
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Pydantic数据模型 - 请求/响应模式"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ===== 通用响应 =====
|
||||
class ApiResponse(BaseModel):
|
||||
code: int = 200
|
||||
message: str = "success"
|
||||
data: Any = None
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
items: List[Any]
|
||||
|
||||
|
||||
# ===== 虚拟用户 =====
|
||||
class UserCreateRequest(BaseModel):
|
||||
# 必填
|
||||
account: str = Field(..., min_length=1, max_length=128, description="新闻平台账号(必填)")
|
||||
password: str = Field(..., min_length=6, max_length=64, description="登录密码(必填)")
|
||||
# 选填
|
||||
nickname: Optional[str] = Field(None, max_length=64, description="昵称(选填,为空自动生成)")
|
||||
avatar_url: Optional[str] = None
|
||||
activity_level: int = Field(default=1, ge=0, le=2)
|
||||
daily_comment_limit: int = Field(default=10, ge=1, le=100)
|
||||
daily_like_limit: int = Field(default=30, ge=1, le=200)
|
||||
remark: Optional[str] = None
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
nickname: Optional[str] = Field(None, min_length=1, max_length=64)
|
||||
password: Optional[str] = Field(None, min_length=6, max_length=64)
|
||||
avatar_url: Optional[str] = None
|
||||
activity_level: Optional[int] = Field(None, ge=0, le=2)
|
||||
daily_comment_limit: Optional[int] = Field(None, ge=1, le=100)
|
||||
daily_like_limit: Optional[int] = Field(None, ge=1, le=200)
|
||||
remark: Optional[str] = None
|
||||
is_enabled: Optional[int] = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
nickname: str
|
||||
account: str
|
||||
avatar_url: Optional[str]
|
||||
real_name: Optional[str] = None
|
||||
sex: int = 0
|
||||
platform_uid: Optional[str] = None
|
||||
status: int
|
||||
status_label: str
|
||||
activity_level: int
|
||||
activity_label: str
|
||||
daily_comment_limit: int
|
||||
daily_like_limit: int
|
||||
today_comment_count: int
|
||||
today_like_count: int
|
||||
total_interactions: int
|
||||
last_login_at: Optional[datetime]
|
||||
last_interact_at: Optional[datetime]
|
||||
remark: Optional[str]
|
||||
is_enabled: int
|
||||
created_at: datetime
|
||||
personality: Optional[dict] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class UserBatchRequest(BaseModel):
|
||||
user_ids: List[int]
|
||||
action: str # enable/disable/logout/delete
|
||||
|
||||
|
||||
# ===== 人格 =====
|
||||
class PersonalityUpdateRequest(BaseModel):
|
||||
character_type: Optional[str] = None
|
||||
language_style: Optional[str] = None
|
||||
interest_tags: Optional[List[str]] = None
|
||||
interact_tendency: Optional[str] = None
|
||||
word_count_min: Optional[int] = Field(None, ge=10, le=500)
|
||||
word_count_max: Optional[int] = Field(None, ge=10, le=1000)
|
||||
personality_desc: Optional[str] = None
|
||||
|
||||
|
||||
class PersonalityResponse(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
character_type: Optional[str]
|
||||
language_style: Optional[str]
|
||||
interest_tags: Optional[List[str]]
|
||||
interact_tendency: Optional[str]
|
||||
word_count_min: int
|
||||
word_count_max: int
|
||||
personality_desc: Optional[str]
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ===== 互动记录 =====
|
||||
class InteractionQueryParams(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=20, ge=1, le=100)
|
||||
user_id: Optional[int] = None
|
||||
interact_type: Optional[str] = None
|
||||
status: Optional[int] = None
|
||||
start_date: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
keyword: Optional[str] = None
|
||||
|
||||
|
||||
class InteractionResponse(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
user_nickname: Optional[str]
|
||||
user_account: Optional[str]
|
||||
article_id: Optional[str]
|
||||
article_title: Optional[str]
|
||||
interact_type: str
|
||||
interact_type_label: str
|
||||
content: Optional[str]
|
||||
token_consumed: int
|
||||
status: int
|
||||
status_label: str
|
||||
error_msg: Optional[str]
|
||||
retry_count: int
|
||||
executed_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ===== AI模型配置 =====
|
||||
class AIModelCreateRequest(BaseModel):
|
||||
model_name: str = Field(..., min_length=1, max_length=64)
|
||||
provider: str = Field(..., pattern="^(openai|zhipu|wenxin|qianwen|local)$")
|
||||
api_base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_version: Optional[str] = None
|
||||
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
|
||||
max_tokens: int = Field(default=1000, ge=1, le=32000)
|
||||
timeout_seconds: int = Field(default=30, ge=5, le=300)
|
||||
is_default: int = Field(default=0, ge=0, le=1)
|
||||
|
||||
|
||||
class AIModelUpdateRequest(BaseModel):
|
||||
model_name: Optional[str] = None
|
||||
api_base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_version: Optional[str] = None
|
||||
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(None, ge=1, le=32000)
|
||||
timeout_seconds: Optional[int] = Field(None, ge=5, le=300)
|
||||
is_default: Optional[int] = None
|
||||
is_enabled: Optional[int] = None
|
||||
|
||||
|
||||
class AIModelResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
api_base_url: Optional[str]
|
||||
has_api_key: bool
|
||||
model_version: Optional[str]
|
||||
temperature: float
|
||||
max_tokens: int
|
||||
timeout_seconds: int
|
||||
is_default: int
|
||||
is_enabled: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AIModelTestRequest(BaseModel):
|
||||
model_id: int
|
||||
test_prompt: str = "你好,请简单介绍一下自己。"
|
||||
|
||||
|
||||
# ===== 系统配置 =====
|
||||
class SystemConfigUpdateRequest(BaseModel):
|
||||
configs: dict
|
||||
|
||||
|
||||
# ===== 数据统计 =====
|
||||
class DashboardResponse(BaseModel):
|
||||
user_stats: dict
|
||||
today_interactions: dict
|
||||
monthly_stats: dict
|
||||
token_stats: dict
|
||||
system_status: dict
|
||||
online_users: int
|
||||
|
||||
|
||||
# ===== 调度配置 =====
|
||||
class SchedulerConfigRequest(BaseModel):
|
||||
interact_time_start: Optional[str] = None
|
||||
interact_time_end: Optional[str] = None
|
||||
interact_interval_min: Optional[int] = None
|
||||
interact_interval_max: Optional[int] = None
|
||||
max_concurrent_users: Optional[int] = None
|
||||
daily_token_limit: Optional[int] = None
|
||||
comment_probability: Optional[float] = None
|
||||
reply_probability: Optional[float] = None
|
||||
like_probability: Optional[float] = None
|
||||
collect_probability: Optional[float] = None
|
||||
forward_probability: Optional[float] = None
|
||||
scheduler_enabled: Optional[bool] = None
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
258
backend/app/services/ai_service.py
Normal file
258
backend/app/services/ai_service.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""AI服务 - 人格生成、内容创作"""
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Optional
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.models import AIModelConfig, TokenStat
|
||||
from app.utils.crypto import decrypt
|
||||
from app.core.logger import logger
|
||||
from datetime import date
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI大模型服务"""
|
||||
|
||||
# 人格候选池
|
||||
CHARACTER_TYPES = ["开朗", "内敛", "毒舌", "温和", "理性", "感性", "幽默", "严谨"]
|
||||
LANGUAGE_STYLES = ["严肃", "幽默", "文艺", "吐槽", "口语化", "学术", "简洁", "丰富"]
|
||||
INTEREST_TAGS_POOL = [
|
||||
"科技", "财经", "娱乐", "体育", "政治", "文化", "教育", "医疗",
|
||||
"汽车", "房产", "旅游", "美食", "军事", "国际", "环保", "农业"
|
||||
]
|
||||
INTERACT_TENDENCIES = ["爱评论", "爱点赞", "爱收藏", "潜水", "爱转发", "爱回复"]
|
||||
|
||||
async def _get_default_model(self, db: AsyncSession) -> Optional[AIModelConfig]:
|
||||
result = await db.execute(
|
||||
select(AIModelConfig).where(
|
||||
AIModelConfig.is_default == 1, AIModelConfig.is_enabled == 1
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _call_api(
|
||||
self, db: AsyncSession, prompt: str, system_prompt: str = None,
|
||||
max_tokens: int = None
|
||||
) -> tuple[str, int]:
|
||||
"""调用AI接口,返回(内容, token数)"""
|
||||
model = await self._get_default_model(db)
|
||||
if not model:
|
||||
# 无模型配置时返回随机预设
|
||||
return "", 0
|
||||
|
||||
api_key = decrypt(model.api_key_enc) if model.api_key_enc else ""
|
||||
base_url = model.api_base_url or "https://api.openai.com/v1"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload = {
|
||||
"model": model.model_version or "gpt-3.5-turbo",
|
||||
"messages": messages,
|
||||
"temperature": model.temperature,
|
||||
"max_tokens": max_tokens or model.max_tokens,
|
||||
}
|
||||
|
||||
import asyncio as _asyncio
|
||||
last_err = None
|
||||
for attempt in range(3): # 最多重试3次
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=model.timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
f"{base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
# 429 限流:等待后重试
|
||||
if resp.status_code == 429:
|
||||
wait = 30 * (attempt + 1) # 30s, 60s, 90s
|
||||
logger.warning(f"AI接口限流(429),{wait}s后重试({attempt+1}/3)")
|
||||
await _asyncio.sleep(wait)
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
text = data["choices"][0]["message"]["content"].strip()
|
||||
tokens = data.get("usage", {}).get("total_tokens", 0)
|
||||
await self._record_token_usage(db, tokens, data.get("usage", {}), model.model_name)
|
||||
logger.bind(ai_call=True).info(
|
||||
f"AI调用成功 model={model.model_name} tokens={tokens}"
|
||||
)
|
||||
return text, tokens
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
if attempt < 2:
|
||||
await _asyncio.sleep(5 * (attempt + 1))
|
||||
logger.error(f"AI调用失败(已重试3次): {last_err}")
|
||||
return "", 0
|
||||
|
||||
async def generate_personality(self, nickname: str, account: str) -> dict:
|
||||
"""生成用户人格(含fallback随机生成)"""
|
||||
# 如无AI配置,使用随机生成
|
||||
from app.core.database import AsyncSessionLocal
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
model = await self._get_default_model(db)
|
||||
if not model:
|
||||
return self._random_personality()
|
||||
|
||||
prompt = f"""请为以下虚拟新闻读者生成一个独特的人格档案,要求真实自然、贴合中国用户特征。
|
||||
用户昵称:{nickname}
|
||||
|
||||
请严格以JSON格式返回,不要有其他内容:
|
||||
{{
|
||||
"character_type": "从[开朗/内敛/毒舌/温和/理性/感性/幽默/严谨]中选一个",
|
||||
"language_style": "从[严肃/幽默/文艺/吐槽/口语化/学术/简洁/丰富]中选一个",
|
||||
"interest_tags": ["兴趣1", "兴趣2", "兴趣3"],
|
||||
"interact_tendency": "从[爱评论/爱点赞/爱收藏/潜水/爱转发/爱回复]中选一个",
|
||||
"word_count_min": 最少字数(10-50整数),
|
||||
"word_count_max": 最多字数(50-200整数),
|
||||
"personality_desc": "一句话描述此人的性格特点(30字以内)"
|
||||
}}"""
|
||||
content, _ = await self._call_api(db, prompt, max_tokens=300)
|
||||
if content:
|
||||
try:
|
||||
# 提取JSON
|
||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
except Exception:
|
||||
pass
|
||||
return self._random_personality()
|
||||
except Exception as e:
|
||||
logger.error(f"人格生成失败: {e}")
|
||||
return self._random_personality()
|
||||
|
||||
def _random_personality(self) -> dict:
|
||||
"""随机生成人格(无AI时的备用方案)"""
|
||||
interests = random.sample(self.INTEREST_TAGS_POOL, random.randint(2, 4))
|
||||
char = random.choice(self.CHARACTER_TYPES)
|
||||
style = random.choice(self.LANGUAGE_STYLES)
|
||||
tendency = random.choice(self.INTERACT_TENDENCIES)
|
||||
w_min = random.randint(15, 40)
|
||||
w_max = random.randint(60, 150)
|
||||
return {
|
||||
"character_type": char,
|
||||
"language_style": style,
|
||||
"interest_tags": interests,
|
||||
"interact_tendency": tendency,
|
||||
"word_count_min": w_min,
|
||||
"word_count_max": w_max,
|
||||
"personality_desc": f"一个{char}性格、{tendency}的新闻读者",
|
||||
}
|
||||
|
||||
async def generate_comment(
|
||||
self, db: AsyncSession, article_title: str, article_content: str,
|
||||
personality_prompt: str, word_min: int = 20, word_max: int = 80
|
||||
) -> tuple[str, int]:
|
||||
"""生成文章评论"""
|
||||
system_prompt = f"""你是一名真实的社区用户,正在阅读新闻后发表评论。{personality_prompt}
|
||||
|
||||
重要规则:
|
||||
- 评论必须积极正面、文明友善,绝对不包含任何政治敏感、色情、暴力、侮辱、歧视内容
|
||||
- 不要提及具体政治人物、党派、政策批评、社会矛盾等敏感话题
|
||||
- 内容围绕文章本身展开,表达个人感受、分享观点、提出建设性问题
|
||||
- 语言朴实自然,像普通网友留言,不夸张不煽情"""
|
||||
|
||||
prompt = f"""请根据以下新闻文章写一条评论。
|
||||
|
||||
文章标题:{article_title}
|
||||
文章摘要:{article_content[:200] if article_content else '(无摘要)'}
|
||||
|
||||
要求:
|
||||
1. 评论字数 {word_min}~{word_max} 字
|
||||
2. 内容积极正面,贴近文章主题
|
||||
3. 语气自然真实,符合普通读者口吻
|
||||
4. 必须是完整的句子,不能被截断,以句号/感叹号/问号结尾
|
||||
5. 只输出评论正文,不要加任何前缀或解释
|
||||
|
||||
评论:"""
|
||||
return await self._call_api(db, prompt, system_prompt, max_tokens=300)
|
||||
|
||||
async def generate_reply(
|
||||
self, db: AsyncSession, article_title: str, parent_comment: str,
|
||||
personality_prompt: str, word_min: int = 15, word_max: int = 60
|
||||
) -> tuple[str, int]:
|
||||
"""生成回复"""
|
||||
system_prompt = f"""你是一名真实的社区用户。{personality_prompt}
|
||||
|
||||
重要规则:回复必须积极正面、文明友善,不含任何敏感违规内容。"""
|
||||
prompt = f"""文章:{article_title}
|
||||
原评论:{parent_comment}
|
||||
|
||||
请对上面的评论写一条友善自然的回复,{word_min}~{word_max}字,直接输出回复内容。"""
|
||||
return await self._call_api(db, prompt, system_prompt, max_tokens=150)
|
||||
|
||||
async def test_model(self, db: AsyncSession, model_id: int, test_prompt: str) -> dict:
|
||||
"""测试模型可用性"""
|
||||
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
return {"success": False, "error": "模型不存在"}
|
||||
|
||||
api_key = decrypt(model.api_key_enc) if model.api_key_enc else ""
|
||||
base_url = model.api_base_url or "https://api.openai.com/v1"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload = {
|
||||
"model": model.model_version or "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": test_prompt}],
|
||||
"max_tokens": 200,
|
||||
}
|
||||
try:
|
||||
import time
|
||||
start = time.time()
|
||||
async with httpx.AsyncClient(timeout=model.timeout_seconds) as client:
|
||||
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
elapsed = round(time.time() - start, 2)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
tokens = data.get("usage", {}).get("total_tokens", 0)
|
||||
return {
|
||||
"success": True, "content": content,
|
||||
"tokens": tokens, "elapsed_seconds": elapsed,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _record_token_usage(
|
||||
self, db: AsyncSession, total: int, usage: dict, model_name: str
|
||||
):
|
||||
"""记录Token消耗"""
|
||||
today = date.today()
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
try:
|
||||
existing = await db.execute(
|
||||
select(TokenStat).where(TokenStat.stat_date == today)
|
||||
)
|
||||
stat = existing.scalar_one_or_none()
|
||||
if stat:
|
||||
stat.total_tokens += total
|
||||
stat.prompt_tokens += usage.get("prompt_tokens", 0)
|
||||
stat.completion_tokens += usage.get("completion_tokens", 0)
|
||||
stat.call_count += 1
|
||||
else:
|
||||
stat = TokenStat(
|
||||
stat_date=today,
|
||||
model_name=model_name,
|
||||
total_tokens=total,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
call_count=1,
|
||||
)
|
||||
db.add(stat)
|
||||
except Exception as e:
|
||||
logger.error(f"记录Token消耗失败: {e}")
|
||||
|
||||
|
||||
ai_service = AIService()
|
||||
729
backend/app/services/news_service.py
Executable file
729
backend/app/services/news_service.py
Executable file
@@ -0,0 +1,729 @@
|
||||
"""
|
||||
新闻平台对接服务
|
||||
登录: POST {auth}/open/login/token (formData)
|
||||
签名: 完全对应 sign.js 的实现
|
||||
"""
|
||||
import uuid
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.models import VirtualUser, SystemConfig, LoginLog
|
||||
from app.utils.crypto import decrypt
|
||||
from app.core.redis_client import set_session, get_session, delete_session
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
class NewsPlatformService:
|
||||
|
||||
# ─── 配置读取 ──────────────────────────────────────────────
|
||||
async def _cfg(self, db: AsyncSession, key: str, default: str = "") -> str:
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
|
||||
row = result.scalar_one_or_none()
|
||||
return row.config_value if row else default
|
||||
|
||||
async def _biz_url(self, db: AsyncSession) -> str:
|
||||
return await self._cfg(db, "news_platform_base_url", "http://192.168.1.200:63120")
|
||||
|
||||
async def _auth_url(self, db: AsyncSession) -> str:
|
||||
return await self._cfg(db, "auth_base_url", "http://192.168.1.200:60040")
|
||||
|
||||
async def _client(self, db: AsyncSession) -> dict:
|
||||
return {
|
||||
"appId": await self._cfg(db, "platform_app_id", ""),
|
||||
"accessId": await self._cfg(db, "platform_access_id", ""),
|
||||
"accessSecret": await self._cfg(db, "platform_access_secret", ""),
|
||||
"clientCode": await self._cfg(db, "platform_client_code", ""),
|
||||
"orgId": await self._cfg(db, "platform_org_id", ""),
|
||||
}
|
||||
|
||||
# ─── 签名(完全对应 sign.js 逻辑) ─────────────────────────
|
||||
@staticmethod
|
||||
def _make_sign(params: dict, secret_key: str, sign_type: str = "MD5") -> str:
|
||||
"""
|
||||
完全对应 sign.js:
|
||||
1. 所有参数 key 排序
|
||||
2. 过滤掉 signature / accessSecret / 空值 / 空数组
|
||||
3. 拼接 key=value& ... accessSecret=secretKey
|
||||
4. MD5/SHA256 大写
|
||||
"""
|
||||
SIGN_KEY = "signature"
|
||||
SECRET_KEY = "accessSecret"
|
||||
|
||||
keys = sorted(params.keys())
|
||||
str_parts = []
|
||||
for k in keys:
|
||||
if k in (SIGN_KEY, SECRET_KEY):
|
||||
continue
|
||||
v = params.get(k)
|
||||
if v is None or v == "" or v == []:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
continue
|
||||
str_parts.append(f"{k}={v}")
|
||||
|
||||
sign_str = "&".join(str_parts) + f"&{SECRET_KEY}={secret_key}"
|
||||
|
||||
if sign_type.upper() == "SHA256":
|
||||
return hashlib.sha256(sign_str.encode("utf-8")).hexdigest().upper()
|
||||
else:
|
||||
return hashlib.md5(sign_str.encode("utf-8")).hexdigest().upper()
|
||||
|
||||
@staticmethod
|
||||
def _get_nonce() -> str:
|
||||
import random, math
|
||||
return str(random.random())[2:][: random.randint(8, 12)]
|
||||
|
||||
@staticmethod
|
||||
def _get_timestamp() -> str:
|
||||
"""yyyyMMddhhmmss — 注意 sign.js 用小写 hh(12小时制)"""
|
||||
now = datetime.now()
|
||||
# sign.js 用 yyyyMMddhhmmss (12小时制小写hh)
|
||||
return now.strftime("%Y%m%d%I%M%S")
|
||||
|
||||
def _build_form(self, extra: dict, cfg: dict) -> dict:
|
||||
"""构建带签名的 formData"""
|
||||
sign_type = "MD5"
|
||||
sign_version = "1.0"
|
||||
secret_key = cfg.get("accessSecret", "")
|
||||
|
||||
base = {
|
||||
"appId": cfg.get("appId", ""),
|
||||
"accessId": cfg.get("accessId", ""),
|
||||
"timestamp": self._get_timestamp(),
|
||||
"signType": sign_type,
|
||||
"signVersion": sign_version,
|
||||
"accessSecret": secret_key,
|
||||
"nonce": self._get_nonce(),
|
||||
}
|
||||
base.update(extra)
|
||||
|
||||
# 计算签名
|
||||
signature = self._make_sign(base, secret_key, sign_type) if secret_key else ""
|
||||
base["signature"] = signature
|
||||
|
||||
# 移除 accessSecret(不发送到服务器)
|
||||
base.pop("accessSecret", None)
|
||||
return base
|
||||
|
||||
# ─── 登录 ──────────────────────────────────────────────────
|
||||
async def login(self, db: AsyncSession, user: VirtualUser) -> bool:
|
||||
password = decrypt(user.password_enc)
|
||||
if not password:
|
||||
logger.error(f"[登录] {user.account} 密码解密失败")
|
||||
return False
|
||||
|
||||
auth = await self._auth_url(db)
|
||||
cfg = await self._client(db)
|
||||
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(status=1))
|
||||
await db.commit()
|
||||
|
||||
extra = {
|
||||
"username": user.account,
|
||||
"password": password,
|
||||
"loginType": "password",
|
||||
"grantType": "password",
|
||||
"isRegister": "false",
|
||||
}
|
||||
if cfg.get("clientCode"):
|
||||
extra["clientCode"] = cfg["clientCode"]
|
||||
|
||||
form = self._build_form(extra, cfg)
|
||||
|
||||
exc = None
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30,
|
||||
follow_redirects=True, # 自动跟随 HTTP 重定向
|
||||
) as c:
|
||||
# 登录接口路径:需要加 /usercenter 前缀(通过网关路由)
|
||||
# auth_base_url 配置为完整前缀,如 https://fat-open.99hui.com/api/usercenter
|
||||
login_url = f"{auth}/open/login/token"
|
||||
resp = await c.post(login_url, data=form)
|
||||
|
||||
# 详细记录原始响应,便于排查
|
||||
logger.info(
|
||||
f"[登录] {user.account} 原始响应: "
|
||||
f"status={resp.status_code} "
|
||||
f"content-type={resp.headers.get('content-type','')} "
|
||||
f"body={resp.text[:500]}"
|
||||
)
|
||||
|
||||
# 防止空响应体崩溃
|
||||
if not resp.text.strip():
|
||||
logger.warning(f"[登录] {user.account} 服务器返回空响应体,接口可能不存在或被重定向")
|
||||
raise ValueError(f"服务器返回空响应 HTTP={resp.status_code}")
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception as je:
|
||||
logger.warning(f"[登录] {user.account} 响应非JSON: {resp.text[:200]}")
|
||||
raise ValueError(f"响应非JSON: {resp.text[:100]}")
|
||||
|
||||
if resp.status_code == 200 and body.get("code") in [0, 200]:
|
||||
raw = body.get("data")
|
||||
access_token, platform_uid = self._extract_token(raw)
|
||||
|
||||
if access_token:
|
||||
sid = str(uuid.uuid4())
|
||||
# 登录成功后尝试获取用户组织信息
|
||||
org_id = await self._fetch_org_id(db, access_token, platform_uid, cfg)
|
||||
if org_id and not cfg.get("orgId"):
|
||||
# 如果系统没有配置 orgId,则自动保存用户所属组织
|
||||
await self._save_org_id(db, org_id)
|
||||
|
||||
# 从登录响应中提取用户信息
|
||||
user_info = raw.get("userInfo", {}) if isinstance(raw, dict) else {}
|
||||
sync_nickname = user_info.get("nickName") or user_info.get("username") or ""
|
||||
sync_real_name = user_info.get("realName") or ""
|
||||
sync_sex = int(user_info.get("sex") or 0)
|
||||
sync_avatar = user_info.get("avatar") or ""
|
||||
|
||||
await set_session(user.id, {
|
||||
"token": access_token,
|
||||
"session_id": sid,
|
||||
"platform_uid": platform_uid,
|
||||
"org_id": org_id or cfg.get("orgId", ""),
|
||||
"login_time": datetime.now().isoformat(),
|
||||
# 缓存用户信息供 sync 使用
|
||||
"nickname": sync_nickname,
|
||||
"real_name": sync_real_name,
|
||||
"sex": sync_sex,
|
||||
"avatar": sync_avatar,
|
||||
}, expire=86400)
|
||||
|
||||
# 更新本地数据库,同步平台用户信息
|
||||
update_vals = dict(
|
||||
status=2, session_token=access_token,
|
||||
session_expires_at=datetime.now() + timedelta(hours=24),
|
||||
last_login_at=datetime.now(),
|
||||
platform_uid=platform_uid,
|
||||
)
|
||||
if sync_nickname: update_vals["nickname"] = sync_nickname
|
||||
if sync_real_name: update_vals["real_name"] = sync_real_name
|
||||
if sync_sex: update_vals["sex"] = sync_sex
|
||||
if sync_avatar: update_vals["avatar_url"] = sync_avatar
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(**update_vals))
|
||||
await db.commit()
|
||||
await self._write_login_log(db, user, "login", sid)
|
||||
logger.info(f"✅ [登录] {user.account} 成功 orgId={org_id}")
|
||||
return True
|
||||
|
||||
logger.warning(f"[登录] {user.account} 无token: {body}")
|
||||
else:
|
||||
logger.warning(f"[登录] {user.account} 失败: HTTP={resp.status_code} {body}")
|
||||
|
||||
except Exception as e:
|
||||
exc = e
|
||||
logger.error(f"[登录] {user.account} 异常: {e}")
|
||||
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(status=3))
|
||||
await db.commit()
|
||||
await self._write_login_log(db, user, "fail", error_msg=str(exc or "登录失败"))
|
||||
return False
|
||||
|
||||
async def _fetch_org_id(
|
||||
self, db: AsyncSession, token: str, platform_uid: str, cfg: dict
|
||||
) -> str:
|
||||
"""登录成功后,调用接口获取用户所属组织 orgId"""
|
||||
biz = await self._biz_url(db)
|
||||
headers = self._bearer(token)
|
||||
# 尝试常见的用户信息接口
|
||||
endpoints = [
|
||||
f"/app/user/info",
|
||||
f"/open/user/info",
|
||||
f"/user/info",
|
||||
]
|
||||
for ep in endpoints:
|
||||
try:
|
||||
form = self._build_form({}, cfg)
|
||||
async with httpx.AsyncClient(timeout=8) as c:
|
||||
r = await c.get(
|
||||
f"{biz}{ep}",
|
||||
headers=headers,
|
||||
params={k: v for k, v in form.items() if k not in ("username","password")}
|
||||
)
|
||||
if r.status_code == 200:
|
||||
d = r.json()
|
||||
data = d.get("data") or {}
|
||||
# 从各种可能的字段中提取 orgId
|
||||
org_id = (
|
||||
data.get("orgId") or data.get("defaultOrgId") or
|
||||
data.get("org", {}).get("id") if isinstance(data.get("org"), dict) else None
|
||||
)
|
||||
if org_id:
|
||||
return str(org_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"获取orgId失败({ep}): {e}")
|
||||
return ""
|
||||
|
||||
async def _save_org_id(self, db: AsyncSession, org_id: str):
|
||||
"""自动保存获取到的 orgId 到系统配置"""
|
||||
result = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "platform_org_id")
|
||||
)
|
||||
cfg = result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.config_value = org_id
|
||||
else:
|
||||
db.add(SystemConfig(
|
||||
config_key="platform_org_id",
|
||||
config_value=org_id,
|
||||
description="平台组织Id(自动获取)"
|
||||
))
|
||||
await db.commit()
|
||||
logger.info(f"✅ 自动获取并保存 orgId={org_id}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_token(raw) -> tuple[str, str]:
|
||||
if isinstance(raw, str) and raw:
|
||||
return raw, ""
|
||||
if isinstance(raw, dict):
|
||||
token = (raw.get("access_token") or raw.get("accessToken") or raw.get("token") or "")
|
||||
# openid 是平台用户ID(登录响应里 data.openid = data.userInfo.id)
|
||||
uid = str(
|
||||
raw.get("openid") or
|
||||
raw.get("userId") or raw.get("user_id") or raw.get("id") or
|
||||
(raw.get("userInfo") or {}).get("id") or ""
|
||||
)
|
||||
return token, uid
|
||||
return "", ""
|
||||
|
||||
async def logout(self, db: AsyncSession, user_id: int):
|
||||
user_r = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
|
||||
user = user_r.scalar_one_or_none()
|
||||
if user:
|
||||
sess = await get_session(user_id)
|
||||
await self._write_login_log(db, user, "logout",
|
||||
sess.get("session_id") if sess else None)
|
||||
await delete_session(user_id)
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id == user_id).values(
|
||||
status=0, session_token=None, session_expires_at=None))
|
||||
await db.commit()
|
||||
|
||||
async def check_session(self, db: AsyncSession, user: VirtualUser) -> bool:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
try:
|
||||
params = self._build_form({}, cfg)
|
||||
params.update({"orgId": cfg["orgId"] or "1", "pageNum": 1, "pageSize": 1, "status": "approved"})
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
r = await c.get(f"{biz}/news/list", headers=self._bearer(sess["token"]), params=params)
|
||||
if r.status_code == 200 and r.json().get("code") in [0, 200]:
|
||||
return True
|
||||
await self.logout(db, user.id)
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# ─── 新闻列表 ──────────────────────────────────────────────
|
||||
async def get_news_list(self, db, user, count=5, interest_tags=None) -> list:
|
||||
"""
|
||||
GET /business/member/square/list 广场数据分页查询
|
||||
type=1 表示新闻,orgId 选填(不填则查全平台新闻,无需配置 orgId)
|
||||
返回字段:id(广场ID), recordId(新闻实际ID), title, orgId, orgName
|
||||
"""
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return []
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
org_id = sess.get("org_id") or cfg.get("orgId") or ""
|
||||
|
||||
# 先查总数,再随机翻页,避免每次都取第1页相同内容
|
||||
import math
|
||||
# 第一次查询获取总页数
|
||||
first_params = self._build_form({
|
||||
"pageNum": 1,
|
||||
"pageSize": 50,
|
||||
"type": "1",
|
||||
"isPlatformShow": "true",
|
||||
"isAdmin": "false",
|
||||
}, cfg)
|
||||
if org_id:
|
||||
first_params["orgId"] = org_id
|
||||
|
||||
total_pages = 1
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as _c:
|
||||
_r = await _c.get(
|
||||
f"{biz}/business/member/square/list",
|
||||
headers=self._bearer(sess["token"]),
|
||||
params=first_params
|
||||
)
|
||||
_d = _r.json()
|
||||
if _d.get("code") in [0, 200]:
|
||||
total_size = _d.get("data", {}).get("totalSize", 0)
|
||||
total_pages = max(1, math.ceil(total_size / 50))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 随机选择一页
|
||||
import random as _random
|
||||
rand_page = _random.randint(1, min(total_pages, 10)) # 最多取前10页随机
|
||||
|
||||
params = self._build_form({
|
||||
"pageNum": rand_page,
|
||||
"pageSize": 50,
|
||||
"type": "1",
|
||||
"isPlatformShow": "true",
|
||||
"isAdmin": "false",
|
||||
}, cfg)
|
||||
if org_id:
|
||||
params["orgId"] = org_id # 选填,有则按组织过滤
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as c:
|
||||
r = await c.get(
|
||||
f"{biz}/business/member/square/list",
|
||||
headers=self._bearer(sess["token"]),
|
||||
params=params
|
||||
)
|
||||
if r.status_code == 200:
|
||||
d = r.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
nd = d.get("data", {})
|
||||
items = nd.get("data") or nd.get("list") or nd.get("records") or []
|
||||
# 过滤本人发布的文章
|
||||
platform_uid = sess.get("platform_uid", "")
|
||||
if platform_uid:
|
||||
items = [x for x in items if x.get("createUser") != platform_uid]
|
||||
# 过滤已知无效新闻(详情为空或不存在)
|
||||
INVALID_IDS = {
|
||||
"1965670408480907266","2029092495693975554","1960652956793597953",
|
||||
"1960651987045347330","1960596408620838914","1960596083193180161",
|
||||
"1960595664341594113",
|
||||
}
|
||||
items = [x for x in items
|
||||
if (x.get("recordId") or x.get("id")) not in INVALID_IDS]
|
||||
logger.info(f"[广场新闻] {user.account} 获取到 {len(items)} 条(已过滤本人+无效文章)")
|
||||
import random as _rand
|
||||
return _rand.sample(items, min(count, len(items))) if items else []
|
||||
logger.warning(f"[广场新闻] {user.account} code={d.get('code')} msg={d.get('message')}")
|
||||
except Exception as e:
|
||||
logger.error(f"[广场新闻] {user.account}: {e}")
|
||||
return []
|
||||
|
||||
async def read_news(self, db, user, news_id: str) -> bool:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
r = await c.patch(
|
||||
f"{biz}/news/read/{news_id}",
|
||||
headers=self._bearer(sess["token"]),
|
||||
data=self._build_form({}, cfg),
|
||||
)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def post_comment(self, db, user, news_id, news_title, content, news_author_id="", org_id="") -> tuple[bool, str]:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
uid = sess.get("platform_uid", "")
|
||||
# org_id 优先取文章自带的(从广场数据获取),否则取 session/配置
|
||||
final_org_id = org_id or sess.get("org_id") or cfg.get("orgId") or ""
|
||||
body = {
|
||||
"module": "news", "topicId": news_id, "title": news_title,
|
||||
"content": content, "orgId": final_org_id,
|
||||
"toUserId": news_author_id or uid, "userId": uid,
|
||||
"userName": user.nickname, "avatar": user.avatar_url or "",
|
||||
}
|
||||
return await self._json_post(f"{biz}/message/comment", self._bearer(sess["token"]), body)
|
||||
|
||||
async def post_reply(self, db, user, news_id, comment_id, content) -> tuple[bool, str]:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
uid = sess.get("platform_uid", "")
|
||||
body = {
|
||||
"module": "news", "topicId": news_id, "commentId": comment_id,
|
||||
"commentUserId": uid, "content": content,
|
||||
"fromUserName": user.nickname, "avatar": user.avatar_url or "",
|
||||
}
|
||||
return await self._json_post(f"{biz}/message/comment/reply", self._bearer(sess["token"]), body)
|
||||
|
||||
async def get_comments(self, db, user, news_id) -> list:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return []
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
try:
|
||||
params = self._build_form({"module": "news", "topicId": news_id, "pageNum": 1, "pageSize": 20}, cfg)
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
r = await c.get(f"{biz}/message/comment", headers=self._bearer(sess["token"]), params=params)
|
||||
if r.status_code == 200:
|
||||
return r.json().get("data", {}).get("data") or []
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
async def like_news(self, db, user, news_id, org_id="", to_user_id="", title="") -> tuple[bool, str]:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
uid = sess.get("platform_uid", "")
|
||||
body = {
|
||||
"module": "news",
|
||||
"topicId": news_id,
|
||||
"userId": uid,
|
||||
"toUserId": to_user_id or uid,
|
||||
"orgId": org_id or sess.get("org_id", "") or "",
|
||||
"title": title,
|
||||
}
|
||||
return await self._json_post(f"{biz}/message/praise", self._bearer(sess["token"]), body)
|
||||
|
||||
async def collect_news(self, db, user, news_id, org_id="", to_user_id="", title="") -> tuple[bool, str]:
|
||||
"""收藏新闻:复用点赞接口(平台收藏=点赞同一接口)"""
|
||||
return await self.like_news(db, user, news_id, org_id=org_id, to_user_id=to_user_id, title=title)
|
||||
|
||||
async def forward_news(self, db, user, news_id) -> tuple[bool, str]:
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
org_id = sess.get("org_id") or cfg.get("orgId") or "1"
|
||||
headers = self._bearer(sess["token"])
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8) as c:
|
||||
await c.get(f"{biz}/news/share/wechat/{news_id}", headers=headers,
|
||||
params=self._build_form({}, cfg))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as c:
|
||||
r = await c.post(
|
||||
f"{biz}/points/forward/news/{org_id}",
|
||||
headers=headers,
|
||||
data=self._build_form({}, cfg),
|
||||
)
|
||||
return self._ok(r)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@staticmethod
|
||||
def _bearer(token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
@staticmethod
|
||||
def _ok(resp: httpx.Response) -> tuple[bool, str]:
|
||||
if resp.status_code in [200, 201]:
|
||||
try:
|
||||
d = resp.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
return True, ""
|
||||
return False, d.get("message") or "业务失败"
|
||||
except Exception:
|
||||
return True, ""
|
||||
return False, f"HTTP {resp.status_code}"
|
||||
|
||||
async def _json_post(self, url, headers, body) -> tuple[bool, str]:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as c:
|
||||
r = await c.post(url, json=body, headers=headers)
|
||||
return self._ok(r)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
async def _write_login_log(self, db, user, action, session_id=None, error_msg=None):
|
||||
try:
|
||||
log = LoginLog(
|
||||
user_id=user.id, user_account=user.account,
|
||||
action=action, session_id=session_id, error_msg=error_msg,
|
||||
)
|
||||
db.add(log)
|
||||
await db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── 取消互动 ────────────────────────────────────────────────────
|
||||
async def cancel_like(self, db, user, news_id: str, org_id: str = "", to_user_id: str = "", title: str = "") -> tuple[bool, str]:
|
||||
"""DELETE /message/praise/cancel 取消点赞"""
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
uid = sess.get("platform_uid", "")
|
||||
body = {
|
||||
"module": "news",
|
||||
"topicId": news_id,
|
||||
"userId": uid,
|
||||
"toUserId": to_user_id or uid,
|
||||
"orgId": org_id or sess.get("org_id", "") or "",
|
||||
"title": title,
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
r = await c.delete(
|
||||
f"{biz}/message/praise/cancel",
|
||||
json=body,
|
||||
headers=self._bearer(sess["token"])
|
||||
)
|
||||
d = r.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
return True, ""
|
||||
return False, d.get("message", "取消点赞失败")
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
async def cancel_comment(self, db, user, news_id: str, comment_id: str) -> tuple[bool, str]:
|
||||
"""DELETE /message/comment/{topicId}/{id} 删除评论"""
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "未登录"
|
||||
biz = await self._biz_url(db)
|
||||
cfg = await self._client(db)
|
||||
# 签名参数放 formData,路径里是 topicId 和 comment_id
|
||||
params = self._build_form({}, cfg)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as c:
|
||||
r = await c.delete(
|
||||
f"{biz}/message/comment/{news_id}/{comment_id}",
|
||||
headers=self._bearer(sess["token"]),
|
||||
params=params
|
||||
)
|
||||
d = r.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
return True, ""
|
||||
return False, d.get("message", "删除评论失败")
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
async def cancel_collect(self, db, user, news_id: str, org_id: str = "", to_user_id: str = "", title: str = "") -> tuple[bool, str]:
|
||||
"""取消收藏(复用取消点赞接口)"""
|
||||
return await self.cancel_like(db, user, news_id, org_id=org_id, to_user_id=to_user_id, title=title)
|
||||
|
||||
# ─── 修改目标系统用户信息 ─────────────────────────────────────
|
||||
async def update_user_profile(
|
||||
self, db: AsyncSession, user: VirtualUser,
|
||||
nick_name: str = None, real_name: str = None,
|
||||
sex: int = None, avatar: str = None,
|
||||
description: str = None, email: str = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
调用 POST /usercenter/user/change/userInfo 修改用户信息
|
||||
支持:昵称、真实姓名、性别、头像、简介、邮箱
|
||||
同时同步更新本地数据库
|
||||
"""
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "用户未登录,请先登录"
|
||||
|
||||
auth = await self._auth_url(db)
|
||||
cfg = await self._client(db)
|
||||
token = sess.get("token", "")
|
||||
platform_uid = sess.get("platform_uid", "")
|
||||
|
||||
if not platform_uid:
|
||||
return False, "缺少平台用户ID,请重新登录"
|
||||
|
||||
# 构建请求体(只传有值的字段)
|
||||
# 构建请求体,确保至少有 nickName 字段(平台 SQL 要求 SET 子句不为空)
|
||||
body = {
|
||||
"id": platform_uid,
|
||||
"nickName": nick_name if nick_name is not None else (user.nickname or ""),
|
||||
"realName": real_name if real_name is not None else (user.real_name or ""),
|
||||
"sex": sex if sex is not None else (user.sex or 0),
|
||||
}
|
||||
if avatar is not None: body["avatar"] = avatar
|
||||
if description is not None: body["description"] = description
|
||||
if email is not None: body["email"] = email
|
||||
|
||||
# 使用 PATCH /v2/users/current 接口(支持修改昵称)
|
||||
headers = dict(self._bearer(token))
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as c:
|
||||
r = await c.patch(
|
||||
f"{auth}/v2/users/current",
|
||||
json=body,
|
||||
headers=headers,
|
||||
)
|
||||
d = r.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
# 同步到本地数据库
|
||||
local_vals = {}
|
||||
if nick_name is not None: local_vals["nickname"] = nick_name
|
||||
if real_name is not None: local_vals["real_name"] = real_name
|
||||
if sex is not None: local_vals["sex"] = sex
|
||||
if avatar is not None: local_vals["avatar_url"] = avatar
|
||||
if local_vals:
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(VirtualUser).where(
|
||||
VirtualUser.id == user.id).values(**local_vals))
|
||||
await db.commit()
|
||||
logger.info(f"✅ 用户 {user.account} 信息已同步到目标系统")
|
||||
return True, ""
|
||||
err = d.get("message") or f"code={d.get('code')}"
|
||||
logger.warning(f"[修改用户信息] {user.account} 失败: {err} body={r.text[:200]}")
|
||||
return False, err
|
||||
except Exception as e:
|
||||
logger.warning(f"[修改用户信息] {user.account} 异常: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def upload_avatar(
|
||||
self, db: AsyncSession, user: VirtualUser, file_bytes: bytes, filename: str
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
上传头像图片到平台 filecenter,返回图片 URL
|
||||
POST /filecenter/fileUpload (multipart/form-data)
|
||||
"""
|
||||
sess = await get_session(user.id)
|
||||
if not sess:
|
||||
return False, "用户未登录"
|
||||
|
||||
cfg = await self._client(db)
|
||||
token = sess.get("token", "")
|
||||
|
||||
# filecenter 服务地址
|
||||
biz_base = await self._biz_url(db)
|
||||
# filecenter 与 huihuibusiness 同网关,替换服务名
|
||||
filecenter_url = biz_base.replace("/huihuibusiness", "/filecenter")
|
||||
|
||||
sign_params = self._build_form({"module": "userInfo", "service": "kccloud"}, cfg)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
try:
|
||||
import mimetypes
|
||||
mime = mimetypes.guess_type(filename)[0] or "image/jpeg"
|
||||
files = {"file": (filename, file_bytes, mime)}
|
||||
async with httpx.AsyncClient(timeout=30) as c:
|
||||
r = await c.post(
|
||||
f"{filecenter_url}/fileUpload",
|
||||
files=files,
|
||||
data=sign_params,
|
||||
headers=headers,
|
||||
)
|
||||
d = r.json()
|
||||
if d.get("code") in [0, 200]:
|
||||
url = d.get("data") or d.get("url") or ""
|
||||
if isinstance(url, dict):
|
||||
url = url.get("url") or url.get("path") or ""
|
||||
logger.info(f"✅ 头像上传成功: {url}")
|
||||
return True, url
|
||||
return False, d.get("message") or "上传失败"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
news_service = NewsPlatformService()
|
||||
402
backend/app/services/scheduler.py
Normal file
402
backend/app/services/scheduler.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""调度服务 - 定时自动互动、会话校验"""
|
||||
import random
|
||||
import asyncio
|
||||
from datetime import datetime, date
|
||||
from typing import Optional
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy import select, update, func
|
||||
|
||||
from app.core.database import AsyncSessionLocal
|
||||
from app.core.logger import logger
|
||||
from app.models import VirtualUser, UserPersonality, InteractionRecord, SystemConfig
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
def __init__(self):
|
||||
self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
self._running = False
|
||||
|
||||
async def run_once_now(self, db=None):
|
||||
"""立即执行一次互动,不受时间段限制"""
|
||||
from sqlalchemy import select
|
||||
from app.core.database import AsyncSessionLocal
|
||||
logger.info("⚡ 立即触发互动任务")
|
||||
async with AsyncSessionLocal() as session:
|
||||
result_r = await session.execute(
|
||||
select(VirtualUser).where(
|
||||
VirtualUser.status == 2,
|
||||
VirtualUser.is_enabled == 1,
|
||||
)
|
||||
)
|
||||
users = result_r.scalars().all()
|
||||
if not users:
|
||||
return {"message": "没有已登录的用户", "triggered": 0}
|
||||
import random
|
||||
selected = random.sample(users, min(5, len(users)))
|
||||
import asyncio
|
||||
tasks = [self._execute_user_interaction(u.id) for u in selected]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
return {"triggered": len(selected), "users": [u.account for u in selected]}
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
# 会话校验:每10分钟
|
||||
self.scheduler.add_job(
|
||||
self._check_sessions, IntervalTrigger(minutes=10),
|
||||
id="check_sessions", replace_existing=True
|
||||
)
|
||||
# 互动任务:每5分钟检查一次(内部判断是否在活跃时间段)
|
||||
self.scheduler.add_job(
|
||||
self._run_interactions, IntervalTrigger(minutes=5),
|
||||
id="run_interactions", replace_existing=True
|
||||
)
|
||||
# 每日零点重置计数
|
||||
self.scheduler.add_job(
|
||||
self._daily_reset, "cron", hour=16, minute=0, # 北京时间 00:00 = UTC 16:00
|
||||
id="daily_reset", replace_existing=True
|
||||
)
|
||||
self.scheduler.start()
|
||||
self._running = True
|
||||
logger.info("调度器已启动")
|
||||
# 记录启动时间
|
||||
async with AsyncSessionLocal() as db:
|
||||
await self._set_config(db, "system_start_time", datetime.now().isoformat())
|
||||
|
||||
async def stop(self):
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
self._running = False
|
||||
|
||||
async def _get_config(self, db, key: str, default=None):
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
|
||||
cfg = result.scalar_one_or_none()
|
||||
return cfg.config_value if cfg else default
|
||||
|
||||
async def _set_config(self, db, key: str, value: str):
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
|
||||
cfg = result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.config_value = value
|
||||
else:
|
||||
db.add(SystemConfig(config_key=key, config_value=value))
|
||||
await db.commit()
|
||||
|
||||
async def _check_sessions(self):
|
||||
"""定时校验登录状态"""
|
||||
from app.services.news_service import news_service
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(VirtualUser).where(VirtualUser.status == 2, VirtualUser.is_enabled == 1)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
for user in users:
|
||||
try:
|
||||
valid = await news_service.check_session(db, user)
|
||||
if not valid:
|
||||
logger.warning(f"用户 {user.account} 会话失效,尝试重登")
|
||||
await news_service.login(db, user)
|
||||
except Exception as e:
|
||||
logger.error(f"会话校验异常 {user.account}: {e}")
|
||||
|
||||
async def _run_interactions(self):
|
||||
"""执行互动任务"""
|
||||
async with AsyncSessionLocal() as db:
|
||||
# 检查调度器开关
|
||||
enabled = await self._get_config(db, "scheduler_enabled", "true")
|
||||
if enabled != "true":
|
||||
return
|
||||
|
||||
# 检查Token限额
|
||||
token_limited = await self._get_config(db, "token_limit_reached", "false")
|
||||
if token_limited == "true":
|
||||
return
|
||||
|
||||
# 检查互动时间段(北京时间 UTC+8)
|
||||
from datetime import timezone, timedelta
|
||||
tz_beijing = timezone(timedelta(hours=8))
|
||||
now_bj = datetime.now(tz_beijing)
|
||||
now_time = now_bj.strftime("%H:%M")
|
||||
start_str = await self._get_config(db, "interact_time_start", "08:00")
|
||||
end_str = await self._get_config(db, "interact_time_end", "22:00")
|
||||
if not (start_str <= now_time <= end_str):
|
||||
logger.debug(f"[调度] 当前北京时间 {now_time} 不在互动时段 {start_str}-{end_str}")
|
||||
return
|
||||
|
||||
# 获取最小互动间隔(秒)
|
||||
min_interval = int(await self._get_config(db, "interact_min_interval", "300"))
|
||||
|
||||
# 获取最大并发
|
||||
max_concurrent = int(await self._get_config(db, "max_concurrent_users", "5"))
|
||||
|
||||
# 获取已登录、启用的用户
|
||||
query = select(VirtualUser).where(
|
||||
VirtualUser.status == 2,
|
||||
VirtualUser.is_enabled == 1,
|
||||
)
|
||||
if max_concurrent > 0:
|
||||
query = query.limit(max_concurrent)
|
||||
result = await db.execute(query)
|
||||
all_users = result.scalars().all()
|
||||
|
||||
# 没有已登录用户时,尝试登录未登录用户
|
||||
if not all_users:
|
||||
await self._try_login_users(db)
|
||||
return
|
||||
|
||||
# 检查互动间隔:过滤掉最近 min_interval 秒内已互动的用户
|
||||
now_utc = datetime.utcnow()
|
||||
eligible = []
|
||||
for u in all_users:
|
||||
if u.last_interact_at is None:
|
||||
eligible.append(u)
|
||||
else:
|
||||
elapsed = (now_utc - u.last_interact_at).total_seconds()
|
||||
if elapsed >= min_interval:
|
||||
eligible.append(u)
|
||||
|
||||
if not eligible:
|
||||
logger.debug(f"[调度] 所有用户在 {min_interval}s 内已互动,跳过本次")
|
||||
return
|
||||
|
||||
logger.info(f"[调度] {len(eligible)}/{len(all_users)} 个用户满足间隔要求,开始互动")
|
||||
|
||||
# 随机选取用户执行互动
|
||||
for user in eligible:
|
||||
if random.random() < 0.6:
|
||||
asyncio.create_task(self._execute_user_interaction(user.id))
|
||||
|
||||
async def _try_login_users(self, db):
|
||||
"""尝试登录未登录的用户"""
|
||||
from app.services.news_service import news_service
|
||||
result = await db.execute(
|
||||
select(VirtualUser).where(
|
||||
VirtualUser.status.in_([0, 3]),
|
||||
VirtualUser.is_enabled == 1
|
||||
).limit(3)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
for user in users:
|
||||
try:
|
||||
await news_service.login(db, user)
|
||||
await asyncio.sleep(2)
|
||||
except Exception as e:
|
||||
logger.error(f"自动登录失败 {user.account}: {e}")
|
||||
|
||||
async def _execute_user_interaction(self, user_id: int):
|
||||
"""执行单用户互动 - 基于真实接口"""
|
||||
from app.services.news_service import news_service
|
||||
from app.services.ai_service import ai_service
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
user_result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user or user.status != 2:
|
||||
return
|
||||
|
||||
# 检查今日评论限额
|
||||
can_comment = True
|
||||
if user.today_comment_count >= user.daily_comment_limit:
|
||||
can_comment = False
|
||||
logger.info(f'用户 ' + user.account + ' 今日评论已达上限,仍执行点赞/收藏/转发')
|
||||
|
||||
# 获取人格
|
||||
from app.models import UserPersonality
|
||||
p_result = await db.execute(
|
||||
select(UserPersonality).where(UserPersonality.user_id == user_id)
|
||||
)
|
||||
personality = p_result.scalar_one_or_none()
|
||||
interest_tags = personality.interest_tags if personality else []
|
||||
|
||||
# 获取新闻列表(基于接口 GET /news/list)
|
||||
articles = await news_service.get_news_list(
|
||||
db, user, count=5, interest_tags=interest_tags
|
||||
)
|
||||
if not articles:
|
||||
# 尝试从 session 获取 org_id 再试一次
|
||||
from app.core.redis_client import get_session as _get_sess
|
||||
sess = await _get_sess(user.id)
|
||||
org_from_sess = sess.get("org_id", "") if sess else ""
|
||||
if org_from_sess:
|
||||
articles = await news_service.get_news_list(
|
||||
db, user, count=5, interest_tags=interest_tags
|
||||
)
|
||||
if not articles:
|
||||
logger.warning(
|
||||
f"用户 {user.account} 获取新闻列表为空 "
|
||||
f"(orgId={await news_service._cfg(db, 'platform_org_id', '')})"
|
||||
)
|
||||
return
|
||||
|
||||
article = random.choice(articles)
|
||||
# 接口返回字段: id/newsTitle/content/digest/createUser
|
||||
# 广场接口字段:recordId=新闻实际ID, id=广场记录ID, title=标题
|
||||
news_id = str(article.get("recordId") or article.get("id", ""))
|
||||
news_title = article.get("title") or article.get("newsTitle") or "未知文章"
|
||||
news_content = article.get("content") or article.get("digest") or news_title
|
||||
news_author = str(article.get("createUser") or "")
|
||||
# 从广场数据中顺带获取 orgId
|
||||
article_org_id = str(article.get("orgId") or "")
|
||||
|
||||
if not news_id:
|
||||
return
|
||||
|
||||
# 读取互动概率
|
||||
comment_prob = float(await self._get_config_from_db(db, "comment_probability", "0.4"))
|
||||
reply_prob = float(await self._get_config_from_db(db, "reply_probability", "0.2"))
|
||||
like_prob = float(await self._get_config_from_db(db, "like_probability", "0.6"))
|
||||
collect_prob = float(await self._get_config_from_db(db, "collect_probability", "0.3"))
|
||||
forward_prob = float(await self._get_config_from_db(db, "forward_probability", "0.15"))
|
||||
|
||||
interactions_done = []
|
||||
|
||||
# ① 先记录阅读(每次必做,模拟真实用户打开文章)
|
||||
await news_service.read_news(db, user, news_id)
|
||||
|
||||
# ② 点赞
|
||||
if random.random() < like_prob:
|
||||
success, err = await news_service.like_news(db, user, news_id, org_id=article_org_id, to_user_id=news_author, title=news_title)
|
||||
await self._save_record(db, user, news_id, news_title, "like", None, 0, success, err)
|
||||
if success:
|
||||
interactions_done.append("like")
|
||||
await self._incr_total(db, user_id)
|
||||
|
||||
# ③ 收藏(阅读+点赞组合模拟)
|
||||
if random.random() < collect_prob:
|
||||
success, err = await news_service.collect_news(db, user, news_id, org_id=article_org_id, to_user_id=news_author, title=news_title)
|
||||
await self._save_record(db, user, news_id, news_title, "collect", None, 0, success, err)
|
||||
if success:
|
||||
interactions_done.append("collect")
|
||||
|
||||
# ④ 转发(调用 /points/forward/news/{orgId})
|
||||
if random.random() < forward_prob:
|
||||
success, err = await news_service.forward_news(db, user, news_id)
|
||||
await self._save_record(db, user, news_id, news_title, "forward", None, 0, success, err)
|
||||
if success:
|
||||
interactions_done.append("forward")
|
||||
await self._incr_total(db, user_id)
|
||||
|
||||
# ⑤ 评论(AI生成内容,调用 POST /message/comment)
|
||||
if can_comment and random.random() < comment_prob and personality:
|
||||
style_prompt = personality.comment_style_prompt or ""
|
||||
# 字数上限最多80字,避免超出 max_tokens 被截断
|
||||
safe_word_max = min(personality.word_count_max, 80)
|
||||
comment_text, tokens = await ai_service.generate_comment(
|
||||
db, news_title, news_content,
|
||||
style_prompt, personality.word_count_min, safe_word_max
|
||||
)
|
||||
if comment_text:
|
||||
success, err = await news_service.post_comment(
|
||||
db, user, news_id, news_title, comment_text,
|
||||
news_author_id=news_author, org_id=article_org_id
|
||||
)
|
||||
await self._save_record(
|
||||
db, user, news_id, news_title, "comment",
|
||||
comment_text, tokens, success, err
|
||||
)
|
||||
if success:
|
||||
interactions_done.append("comment")
|
||||
await db.execute(
|
||||
update(VirtualUser).where(VirtualUser.id == user_id).values(
|
||||
today_comment_count=VirtualUser.today_comment_count + 1,
|
||||
total_interactions=VirtualUser.total_interactions + 1,
|
||||
last_interact_at=datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
# ⑥ 回复评论(评论成功后,随机回复别人的评论)
|
||||
if random.random() < reply_prob:
|
||||
existing = await news_service.get_comments(db, user, news_id)
|
||||
if existing:
|
||||
target = random.choice(existing)
|
||||
cid = str(target.get("id") or target.get("commentId") or "")
|
||||
parent_content = target.get("content") or ""
|
||||
if cid:
|
||||
reply_text, r_tokens = await ai_service.generate_reply(
|
||||
db, news_title, parent_content,
|
||||
style_prompt,
|
||||
personality.word_count_min,
|
||||
personality.word_count_max
|
||||
)
|
||||
if reply_text:
|
||||
r_ok, r_err = await news_service.post_reply(
|
||||
db, user, news_id, cid, reply_text
|
||||
)
|
||||
await self._save_record(
|
||||
db, user, news_id, news_title, "reply",
|
||||
reply_text, r_tokens, r_ok, r_err,
|
||||
parent_comment_id=cid
|
||||
)
|
||||
if r_ok:
|
||||
interactions_done.append("reply")
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"👤 {user.account} 互动完成: {interactions_done} [新闻: {news_title[:20]}]")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {user_id} 互动异常: {e}")
|
||||
|
||||
async def _incr_total(self, db, user_id: int):
|
||||
await db.execute(
|
||||
update(VirtualUser).where(VirtualUser.id == user_id).values(
|
||||
total_interactions=VirtualUser.total_interactions + 1,
|
||||
last_interact_at=datetime.utcnow()
|
||||
)
|
||||
)
|
||||
|
||||
async def _save_record(
|
||||
self, db, user: VirtualUser, article_id: str, article_title: str,
|
||||
interact_type: str, content: Optional[str], tokens: int,
|
||||
success: bool, error_msg: str, parent_comment_id: str = None,
|
||||
platform_record_id: str = None
|
||||
):
|
||||
from app.core.redis_client import get_session
|
||||
session = await get_session(user.id)
|
||||
session_id = session.get("session_id") if session else None
|
||||
|
||||
record = InteractionRecord(
|
||||
user_id=user.id,
|
||||
user_nickname=user.nickname,
|
||||
user_account=user.account,
|
||||
article_id=article_id,
|
||||
article_title=article_title,
|
||||
interact_type=interact_type,
|
||||
content=content,
|
||||
parent_comment_id=parent_comment_id,
|
||||
platform_record_id=platform_record_id,
|
||||
session_id=session_id,
|
||||
token_consumed=tokens,
|
||||
status=1 if success else 2,
|
||||
error_msg=error_msg or None,
|
||||
executed_at=datetime.now(),
|
||||
)
|
||||
db.add(record)
|
||||
|
||||
async def _get_config_from_db(self, db, key: str, default: str = "") -> str:
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
|
||||
cfg = result.scalar_one_or_none()
|
||||
return cfg.config_value if cfg else default
|
||||
|
||||
async def _daily_reset(self):
|
||||
"""每日零点重置计数"""
|
||||
async with AsyncSessionLocal() as db:
|
||||
await db.execute(
|
||||
update(VirtualUser).values(
|
||||
today_comment_count=0,
|
||||
today_like_count=0
|
||||
)
|
||||
)
|
||||
# 重置Token限额标志
|
||||
result = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "token_limit_reached")
|
||||
)
|
||||
cfg = result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.config_value = "false"
|
||||
await db.commit()
|
||||
logger.info("每日计数重置完成")
|
||||
|
||||
|
||||
scheduler_service = SchedulerService()
|
||||
251
backend/app/services/stats_service.py
Normal file
251
backend/app/services/stats_service.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""数据统计服务"""
|
||||
from datetime import datetime, date, timedelta, timezone
|
||||
|
||||
def _fmt_dt(dt):
|
||||
"""统一输出 UTC 时间,带时区标识,让前端正确解析为 +8"""
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
# 数据库存的是 UTC,补上时区信息
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.isoformat()
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
|
||||
from app.models import VirtualUser, InteractionRecord, TokenStat, SystemConfig
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
class StatsService:
|
||||
|
||||
async def get_dashboard(self, db: AsyncSession) -> dict:
|
||||
"""获取控制台数据"""
|
||||
today = date.today()
|
||||
now = datetime.now()
|
||||
month_start = today.replace(day=1)
|
||||
|
||||
# 用户统计
|
||||
user_stats = await self._get_user_stats(db)
|
||||
|
||||
# 今日互动统计
|
||||
today_stats = await self._get_today_stats(db, today)
|
||||
|
||||
# 本月互动统计
|
||||
monthly_stats = await self._get_monthly_stats(db, month_start, today)
|
||||
|
||||
# Token统计
|
||||
token_stats = await self._get_token_stats(db, today)
|
||||
|
||||
# 系统状态
|
||||
system_status = await self._get_system_status(db, now)
|
||||
|
||||
# 在线用户数
|
||||
online_count_result = await db.execute(
|
||||
select(func.count()).where(VirtualUser.status == 2)
|
||||
)
|
||||
online_count = online_count_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"user_stats": user_stats,
|
||||
"today_interactions": today_stats,
|
||||
"monthly_stats": monthly_stats,
|
||||
"token_stats": token_stats,
|
||||
"system_status": system_status,
|
||||
"online_users": online_count,
|
||||
}
|
||||
|
||||
async def _get_user_stats(self, db: AsyncSession) -> dict:
|
||||
total = await db.execute(select(func.count()).select_from(VirtualUser))
|
||||
normal = await db.execute(select(func.count()).where(VirtualUser.is_enabled == 1))
|
||||
banned = await db.execute(select(func.count()).where(VirtualUser.status == 4))
|
||||
abnormal = await db.execute(select(func.count()).where(VirtualUser.status == 3))
|
||||
return {
|
||||
"total": total.scalar() or 0,
|
||||
"normal": normal.scalar() or 0,
|
||||
"banned": banned.scalar() or 0,
|
||||
"abnormal": abnormal.scalar() or 0,
|
||||
}
|
||||
|
||||
async def _get_today_stats(self, db: AsyncSession, today: date) -> dict:
|
||||
result = await db.execute(
|
||||
select(
|
||||
InteractionRecord.interact_type,
|
||||
func.count().label("cnt"),
|
||||
).where(
|
||||
func.date(InteractionRecord.executed_at) == today,
|
||||
InteractionRecord.status == 1,
|
||||
).group_by(InteractionRecord.interact_type)
|
||||
)
|
||||
rows = result.all()
|
||||
stats = {"comment": 0, "reply": 0, "like": 0, "collect": 0, "forward": 0, "total": 0}
|
||||
for row in rows:
|
||||
if row.interact_type in stats:
|
||||
stats[row.interact_type] = row.cnt
|
||||
stats["total"] += row.cnt
|
||||
return stats
|
||||
|
||||
async def _get_monthly_stats(self, db: AsyncSession, month_start: date, today: date) -> dict:
|
||||
result = await db.execute(
|
||||
select(func.count()).where(
|
||||
InteractionRecord.executed_at >= month_start,
|
||||
InteractionRecord.status == 1,
|
||||
)
|
||||
)
|
||||
return {"total": result.scalar() or 0, "month_start": month_start.isoformat()}
|
||||
|
||||
async def _get_token_stats(self, db: AsyncSession, today: date) -> dict:
|
||||
# 今日
|
||||
today_stat = await db.execute(select(TokenStat).where(TokenStat.stat_date == today))
|
||||
today_row = today_stat.scalar_one_or_none()
|
||||
|
||||
# 每日限额
|
||||
limit_cfg = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "daily_token_limit")
|
||||
)
|
||||
limit_row = limit_cfg.scalar_one_or_none()
|
||||
daily_limit = int(limit_row.config_value) if limit_row else 100000
|
||||
|
||||
today_used = today_row.total_tokens if today_row else 0
|
||||
return {
|
||||
"today_used": today_used,
|
||||
"daily_limit": daily_limit,
|
||||
"remaining": max(0, daily_limit - today_used),
|
||||
"today_calls": today_row.call_count if today_row else 0,
|
||||
}
|
||||
|
||||
async def _get_system_status(self, db: AsyncSession, now: datetime) -> dict:
|
||||
start_cfg = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "system_start_time")
|
||||
)
|
||||
start_row = start_cfg.scalar_one_or_none()
|
||||
uptime = ""
|
||||
if start_row and start_row.config_value:
|
||||
try:
|
||||
start_time = datetime.fromisoformat(start_row.config_value)
|
||||
delta = now - start_time
|
||||
hours, rem = divmod(int(delta.total_seconds()), 3600)
|
||||
mins = rem // 60
|
||||
uptime = f"{hours}小时{mins}分钟"
|
||||
except Exception:
|
||||
uptime = "未知"
|
||||
|
||||
scheduler_cfg = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "scheduler_enabled")
|
||||
)
|
||||
scheduler_row = scheduler_cfg.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"uptime": uptime,
|
||||
"scheduler_enabled": (scheduler_row.config_value == "true") if scheduler_row else True,
|
||||
"current_time": now.isoformat(),
|
||||
}
|
||||
|
||||
async def get_token_trend(self, db: AsyncSession, days: int = 30) -> list:
|
||||
"""Token消耗趋势(近N天)"""
|
||||
end_date = date.today()
|
||||
start_date = end_date - timedelta(days=days - 1)
|
||||
result = await db.execute(
|
||||
select(TokenStat).where(
|
||||
TokenStat.stat_date >= start_date,
|
||||
TokenStat.stat_date <= end_date,
|
||||
).order_by(TokenStat.stat_date)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
stat_map = {r.stat_date.isoformat(): r.total_tokens for r in rows}
|
||||
|
||||
trend = []
|
||||
for i in range(days):
|
||||
d = (start_date + timedelta(days=i)).isoformat()
|
||||
trend.append({"date": d, "tokens": stat_map.get(d, 0)})
|
||||
return trend
|
||||
|
||||
async def get_monthly_token_trend(self, db: AsyncSession) -> list:
|
||||
"""近12个月Token消耗"""
|
||||
today = date.today()
|
||||
months = []
|
||||
for i in range(11, -1, -1):
|
||||
if today.month - i <= 0:
|
||||
year = today.year - 1
|
||||
month = today.month - i + 12
|
||||
else:
|
||||
year = today.year
|
||||
month = today.month - i
|
||||
months.append((year, month))
|
||||
|
||||
trend = []
|
||||
for year, month in months:
|
||||
start = date(year, month, 1)
|
||||
if month == 12:
|
||||
end = date(year + 1, 1, 1) - timedelta(days=1)
|
||||
else:
|
||||
end = date(year, month + 1, 1) - timedelta(days=1)
|
||||
result = await db.execute(
|
||||
select(func.sum(TokenStat.total_tokens)).where(
|
||||
TokenStat.stat_date >= start, TokenStat.stat_date <= end
|
||||
)
|
||||
)
|
||||
total = result.scalar() or 0
|
||||
trend.append({"month": f"{year}-{month:02d}", "tokens": total})
|
||||
return trend
|
||||
|
||||
async def get_interaction_records(
|
||||
self, db: AsyncSession,
|
||||
page: int = 1, page_size: int = 20,
|
||||
user_id: int = None, interact_type: str = None,
|
||||
status: int = None, start_date: str = None,
|
||||
end_date: str = None, keyword: str = None
|
||||
) -> dict:
|
||||
query = select(InteractionRecord)
|
||||
conditions = []
|
||||
if user_id:
|
||||
conditions.append(InteractionRecord.user_id == user_id)
|
||||
if interact_type:
|
||||
conditions.append(InteractionRecord.interact_type == interact_type)
|
||||
if status is not None:
|
||||
conditions.append(InteractionRecord.status == status)
|
||||
if start_date:
|
||||
conditions.append(InteractionRecord.executed_at >= start_date)
|
||||
if end_date:
|
||||
conditions.append(InteractionRecord.executed_at <= end_date + " 23:59:59")
|
||||
if keyword:
|
||||
from sqlalchemy import or_
|
||||
conditions.append(
|
||||
or_(InteractionRecord.article_title.like(f"%{keyword}%"),
|
||||
InteractionRecord.content.like(f"%{keyword}%"),
|
||||
InteractionRecord.user_nickname.like(f"%{keyword}%"))
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
count_q = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_q)).scalar()
|
||||
|
||||
query = query.order_by(InteractionRecord.executed_at.desc()).offset(
|
||||
(page - 1) * page_size
|
||||
).limit(page_size)
|
||||
result = await db.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
INTERACT_LABELS = {
|
||||
"comment": "评论", "reply": "回复", "like": "点赞",
|
||||
"collect": "收藏", "forward": "转发"
|
||||
}
|
||||
STATUS_LABELS = {0: "执行中", 1: "成功", 2: "失败"}
|
||||
|
||||
items = []
|
||||
for r in records:
|
||||
items.append({
|
||||
"id": r.id, "user_id": r.user_id,
|
||||
"user_nickname": r.user_nickname, "user_account": r.user_account,
|
||||
"article_id": r.article_id, "article_title": r.article_title,
|
||||
"interact_type": r.interact_type,
|
||||
"interact_type_label": INTERACT_LABELS.get(r.interact_type, r.interact_type),
|
||||
"content": r.content, "token_consumed": r.token_consumed,
|
||||
"status": r.status, "status_label": STATUS_LABELS.get(r.status, "未知"),
|
||||
"error_msg": r.error_msg, "retry_count": r.retry_count,
|
||||
"executed_at": _fmt_dt(r.executed_at),
|
||||
})
|
||||
return {"total": total, "page": page, "page_size": page_size, "items": items}
|
||||
|
||||
|
||||
stats_service = StatsService()
|
||||
358
backend/app/services/user_service.py
Normal file
358
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""虚拟用户业务服务"""
|
||||
import io
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
def _fmt_dt(dt):
|
||||
if dt is None: return None
|
||||
if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.isoformat()
|
||||
from typing import List, Optional, Tuple
|
||||
import pandas as pd
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete, func, and_, or_
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.models import VirtualUser, UserPersonality
|
||||
from app.schemas import UserCreateRequest, UserUpdateRequest
|
||||
from app.utils.crypto import encrypt, decrypt
|
||||
from app.services.ai_service import ai_service
|
||||
from app.core.logger import logger
|
||||
|
||||
STATUS_LABELS = {0: "未登录", 1: "登录中", 2: "已登录", 3: "登录失效", 4: "封禁"}
|
||||
ACTIVITY_LABELS = {0: "低", 1: "中", 2: "高"}
|
||||
ACTIVITY_COMMENT_LIMITS = {0: (3, 5), 1: (8, 15), 2: (20, 30)}
|
||||
|
||||
|
||||
class UserService:
|
||||
|
||||
async def get_users(
|
||||
self, db: AsyncSession,
|
||||
page: int = 1, page_size: int = 20,
|
||||
keyword: str = None, status: int = None,
|
||||
is_enabled: int = None
|
||||
) -> Tuple[int, List[dict]]:
|
||||
query = select(VirtualUser)
|
||||
conditions = []
|
||||
if keyword:
|
||||
conditions.append(
|
||||
or_(VirtualUser.nickname.like(f"%{keyword}%"),
|
||||
VirtualUser.account.like(f"%{keyword}%"))
|
||||
)
|
||||
if status is not None:
|
||||
conditions.append(VirtualUser.status == status)
|
||||
if is_enabled is not None:
|
||||
conditions.append(VirtualUser.is_enabled == is_enabled)
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
)
|
||||
total = count_result.scalar()
|
||||
|
||||
query = query.offset((page - 1) * page_size).limit(page_size).order_by(VirtualUser.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
users = result.scalars().all()
|
||||
|
||||
items = []
|
||||
for u in users:
|
||||
# 获取人格
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == u.id))
|
||||
personality = p_result.scalar_one_or_none()
|
||||
items.append(self._format_user(u, personality))
|
||||
return total, items
|
||||
|
||||
async def create_user(self, db: AsyncSession, req: UserCreateRequest) -> dict:
|
||||
# 检查账号重复
|
||||
existing = await db.execute(select(VirtualUser).where(VirtualUser.account == req.account))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="账号已存在")
|
||||
# 昵称选填:为空则自动生成
|
||||
nickname = req.nickname or f"用户{req.account[-4:]}"
|
||||
# 检查昵称重复(自动生成的若冲突则加随机后缀)
|
||||
existing_nick = await db.execute(select(VirtualUser).where(VirtualUser.nickname == nickname))
|
||||
if existing_nick.scalar_one_or_none():
|
||||
import random, string
|
||||
nickname = nickname + "_" + "".join(random.choices(string.digits, k=4))
|
||||
|
||||
user = VirtualUser(
|
||||
nickname=nickname,
|
||||
account=req.account,
|
||||
password_enc=encrypt(req.password),
|
||||
avatar_url=req.avatar_url,
|
||||
activity_level=req.activity_level,
|
||||
daily_comment_limit=req.daily_comment_limit,
|
||||
daily_like_limit=req.daily_like_limit,
|
||||
remark=req.remark,
|
||||
status=0,
|
||||
is_enabled=1,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
# 自动生成AI人格
|
||||
try:
|
||||
await self._generate_personality(db, user)
|
||||
except Exception as e:
|
||||
logger.warning(f"人格生成失败,跳过: {e}")
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user.id))
|
||||
personality = p_result.scalar_one_or_none()
|
||||
return self._format_user(user, personality)
|
||||
|
||||
async def update_user(self, db: AsyncSession, user_id: int, req: UserUpdateRequest) -> dict:
|
||||
user = await self._get_or_404(db, user_id)
|
||||
if req.nickname and req.nickname != user.nickname:
|
||||
existing = await db.execute(
|
||||
select(VirtualUser).where(VirtualUser.nickname == req.nickname, VirtualUser.id != user_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="昵称已被使用")
|
||||
user.nickname = req.nickname
|
||||
if req.password:
|
||||
user.password_enc = encrypt(req.password)
|
||||
if req.avatar_url is not None:
|
||||
user.avatar_url = req.avatar_url
|
||||
if req.activity_level is not None:
|
||||
user.activity_level = req.activity_level
|
||||
if req.daily_comment_limit is not None:
|
||||
user.daily_comment_limit = req.daily_comment_limit
|
||||
if req.daily_like_limit is not None:
|
||||
user.daily_like_limit = req.daily_like_limit
|
||||
if req.remark is not None:
|
||||
user.remark = req.remark
|
||||
if req.is_enabled is not None:
|
||||
user.is_enabled = req.is_enabled
|
||||
if req.is_enabled == 0:
|
||||
user.status = 0 # 禁用后重置状态
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user.id))
|
||||
personality = p_result.scalar_one_or_none()
|
||||
return self._format_user(user, personality)
|
||||
|
||||
async def delete_user(self, db: AsyncSession, user_id: int):
|
||||
user = await self._get_or_404(db, user_id)
|
||||
await db.execute(delete(UserPersonality).where(UserPersonality.user_id == user_id))
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
async def batch_action(self, db: AsyncSession, user_ids: List[int], action: str):
|
||||
"""批量操作"""
|
||||
if action == "enable":
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(is_enabled=1))
|
||||
elif action == "disable":
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(is_enabled=0, status=0))
|
||||
elif action == "logout":
|
||||
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(status=0, session_token=None))
|
||||
elif action == "delete":
|
||||
await db.execute(delete(UserPersonality).where(UserPersonality.user_id.in_(user_ids)))
|
||||
await db.execute(delete(VirtualUser).where(VirtualUser.id.in_(user_ids)))
|
||||
await db.commit()
|
||||
return {"affected": len(user_ids)}
|
||||
|
||||
async def generate_personality(self, db: AsyncSession, user_id: int) -> dict:
|
||||
"""为用户生成/重新生成AI人格"""
|
||||
user = await self._get_or_404(db, user_id)
|
||||
# 删除旧人格
|
||||
await db.execute(delete(UserPersonality).where(UserPersonality.user_id == user_id))
|
||||
personality = await self._generate_personality(db, user)
|
||||
await db.commit()
|
||||
return self._format_personality(personality)
|
||||
|
||||
async def update_personality(self, db: AsyncSession, user_id: int, req) -> dict:
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user_id))
|
||||
personality = p_result.scalar_one_or_none()
|
||||
if not personality:
|
||||
raise HTTPException(status_code=404, detail="人格不存在")
|
||||
for field, val in req.model_dump(exclude_none=True).items():
|
||||
setattr(personality, field, val)
|
||||
# 重新生成提示词
|
||||
personality.comment_style_prompt = self._build_style_prompt(personality)
|
||||
await db.commit()
|
||||
await db.refresh(personality)
|
||||
return self._format_personality(personality)
|
||||
|
||||
async def import_from_excel(self, db: AsyncSession, file_content: bytes) -> dict:
|
||||
"""Excel批量导入 - 每行独立事务,互不影响"""
|
||||
try:
|
||||
df = pd.read_excel(io.BytesIO(file_content), engine='openpyxl')
|
||||
except Exception:
|
||||
df = pd.read_excel(io.BytesIO(file_content))
|
||||
|
||||
df.columns = [str(c).strip() for c in df.columns]
|
||||
required_cols = {"新闻平台账号", "登录密码", "昵称"}
|
||||
if not required_cols.issubset(set(df.columns)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"缺少必填列: {required_cols - set(df.columns)},当前列: {list(df.columns)}"
|
||||
)
|
||||
|
||||
success_count = 0
|
||||
error_list = []
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
row_num = idx + 2
|
||||
row_account = ""
|
||||
try:
|
||||
# 账号可能是数字类型(手机号),统一转为字符串
|
||||
account = str(row.get("新闻平台账号", "") or "").strip().split(".")[0] # 去掉 .0 后缀
|
||||
password = str(row.get("登录密码", "") or "").strip()
|
||||
nickname = str(row.get("昵称", "") or "").strip()
|
||||
row_account = account
|
||||
|
||||
if account.lower() in ("nan", "none", ""):
|
||||
error_list.append({"row": row_num, "error": "账号为空"}); continue
|
||||
if password.lower() in ("nan", "none", ""):
|
||||
error_list.append({"row": row_num, "account": account, "error": "密码为空"}); continue
|
||||
if len(password) < 6:
|
||||
error_list.append({"row": row_num, "account": account, "error": "密码不足6位"}); continue
|
||||
# 昵称选填:为空时自动用账号末4位生成
|
||||
if nickname.lower() in ("nan", "none", ""):
|
||||
nickname = f"用户{account[-4:]}"
|
||||
|
||||
existing = await db.execute(select(VirtualUser).where(VirtualUser.account == account))
|
||||
if existing.scalar_one_or_none():
|
||||
error_list.append({"row": row_num, "account": account, "error": "账号已存在"}); continue
|
||||
existing_nick = await db.execute(select(VirtualUser).where(VirtualUser.nickname == nickname))
|
||||
if existing_nick.scalar_one_or_none():
|
||||
error_list.append({"row": row_num, "account": account, "error": "昵称已被使用"}); continue
|
||||
|
||||
avatar = str(row.get("头像链接", "") or "").strip()
|
||||
remark = str(row.get("备注", "") or "").strip()
|
||||
|
||||
user = VirtualUser(
|
||||
nickname=nickname, account=account,
|
||||
password_enc=encrypt(password),
|
||||
avatar_url=avatar if avatar.lower() not in ("nan","none","") else None,
|
||||
remark=remark if remark.lower() not in ("nan","none","") else None,
|
||||
status=0, is_enabled=1, activity_level=1,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
await db.refresh(user)
|
||||
await self._generate_personality(db, user)
|
||||
await db.commit()
|
||||
except Exception as pe:
|
||||
logger.warning(f"第{row_num}行人格生成跳过: {pe}")
|
||||
await db.rollback()
|
||||
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
error_list.append({"row": row_num, "account": row_account, "error": str(e)})
|
||||
logger.warning(f"导入第{row_num}行失败: {e}")
|
||||
|
||||
return {"success": success_count, "failed": len(error_list), "errors": error_list}
|
||||
|
||||
async def export_to_excel(self, db: AsyncSession) -> bytes:
|
||||
"""导出全量用户数据(不含密码)"""
|
||||
result = await db.execute(select(VirtualUser).order_by(VirtualUser.created_at.desc()))
|
||||
users = result.scalars().all()
|
||||
rows = []
|
||||
for u in users:
|
||||
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == u.id))
|
||||
p = p_result.scalar_one_or_none()
|
||||
rows.append({
|
||||
"ID": u.id, "昵称": u.nickname, "账号": u.account,
|
||||
"状态": STATUS_LABELS.get(u.status, "未知"),
|
||||
"活跃度": ACTIVITY_LABELS.get(u.activity_level, "中"),
|
||||
"性格": p.character_type if p else "", "语言风格": p.language_style if p else "",
|
||||
"兴趣偏好": ",".join(p.interest_tags or []) if p else "",
|
||||
"互动倾向": p.interact_tendency if p else "",
|
||||
"累计互动": u.total_interactions, "今日评论": u.today_comment_count,
|
||||
"今日点赞": u.today_like_count, "最后登录": u.last_login_at,
|
||||
"最后互动": u.last_interact_at, "备注": u.remark,
|
||||
"是否启用": "是" if u.is_enabled else "否", "创建时间": u.created_at,
|
||||
})
|
||||
df = pd.DataFrame(rows)
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False, sheet_name="虚拟用户")
|
||||
buf.seek(0)
|
||||
return buf.read()
|
||||
|
||||
async def get_excel_template(self) -> bytes:
|
||||
"""获取导入模板(账号+密码必填,其他选填)"""
|
||||
df = pd.DataFrame(columns=["新闻平台账号", "登录密码", "昵称(选填)", "头像链接(选填)", "备注(选填)"])
|
||||
df.loc[0] = ["13800138000", "password123", "(留空自动生成)", "", ""]
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False, sheet_name="导入模板")
|
||||
buf.seek(0)
|
||||
return buf.read()
|
||||
|
||||
async def _generate_personality(self, db: AsyncSession, user: VirtualUser) -> UserPersonality:
|
||||
"""调用AI生成人格"""
|
||||
result = await ai_service.generate_personality(user.nickname, user.account)
|
||||
personality = UserPersonality(
|
||||
user_id=user.id,
|
||||
character_type=result.get("character_type", "温和"),
|
||||
language_style=result.get("language_style", "幽默"),
|
||||
interest_tags=result.get("interest_tags", ["科技"]),
|
||||
interact_tendency=result.get("interact_tendency", "爱评论"),
|
||||
word_count_min=result.get("word_count_min", 20),
|
||||
word_count_max=result.get("word_count_max", 80),
|
||||
personality_desc=result.get("personality_desc", ""),
|
||||
)
|
||||
personality.comment_style_prompt = self._build_style_prompt(personality)
|
||||
db.add(personality)
|
||||
await db.flush()
|
||||
return personality
|
||||
|
||||
def _build_style_prompt(self, p: UserPersonality) -> str:
|
||||
interests = "、".join(p.interest_tags or []) if p.interest_tags else "综合"
|
||||
return (
|
||||
f"你是一个{p.character_type}性格、{p.language_style}语言风格的新闻读者,"
|
||||
f"主要对{interests}类内容感兴趣,互动倾向是{p.interact_tendency}。"
|
||||
f"评论字数控制在{p.word_count_min}~{p.word_count_max}字。"
|
||||
f"个人简介:{p.personality_desc}"
|
||||
)
|
||||
|
||||
def _format_user(self, u: VirtualUser, p: Optional[UserPersonality]) -> dict:
|
||||
return {
|
||||
"id": u.id, "nickname": u.nickname, "account": u.account,
|
||||
"avatar_url": u.avatar_url,
|
||||
"real_name": getattr(u, "real_name", None),
|
||||
"sex": getattr(u, "sex", 0),
|
||||
"platform_uid": getattr(u, "platform_uid", None),
|
||||
"status": u.status,
|
||||
"status_label": STATUS_LABELS.get(u.status, "未知"),
|
||||
"activity_level": u.activity_level,
|
||||
"activity_label": ACTIVITY_LABELS.get(u.activity_level, "中"),
|
||||
"daily_comment_limit": u.daily_comment_limit,
|
||||
"daily_like_limit": u.daily_like_limit,
|
||||
"today_comment_count": u.today_comment_count,
|
||||
"today_like_count": u.today_like_count,
|
||||
"total_interactions": u.total_interactions,
|
||||
"last_login_at": _fmt_dt(u.last_login_at),
|
||||
"last_interact_at": _fmt_dt(u.last_interact_at),
|
||||
"remark": u.remark, "is_enabled": u.is_enabled,
|
||||
"created_at": _fmt_dt(u.created_at),
|
||||
"personality": self._format_personality(p) if p else None,
|
||||
}
|
||||
|
||||
def _format_personality(self, p: Optional[UserPersonality]) -> Optional[dict]:
|
||||
if not p:
|
||||
return None
|
||||
return {
|
||||
"id": p.id, "user_id": p.user_id,
|
||||
"character_type": p.character_type, "language_style": p.language_style,
|
||||
"interest_tags": p.interest_tags or [], "interact_tendency": p.interact_tendency,
|
||||
"word_count_min": p.word_count_min, "word_count_max": p.word_count_max,
|
||||
"personality_desc": p.personality_desc,
|
||||
"updated_at": _fmt_dt(p.updated_at),
|
||||
}
|
||||
|
||||
async def _get_or_404(self, db: AsyncSession, user_id: int) -> VirtualUser:
|
||||
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
return user
|
||||
|
||||
|
||||
user_service = UserService()
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
49
backend/app/utils/crypto.py
Normal file
49
backend/app/utils/crypto.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""AES加密工具 - 用于密码和API Key加密存储"""
|
||||
import base64
|
||||
import hashlib
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _get_key() -> bytes:
|
||||
"""获取32字节AES密钥"""
|
||||
key = settings.AES_KEY.encode("utf-8")
|
||||
return hashlib.sha256(key).digest()
|
||||
|
||||
|
||||
def encrypt(plaintext: str) -> str:
|
||||
"""AES-CBC加密"""
|
||||
if not plaintext:
|
||||
return ""
|
||||
key = _get_key()
|
||||
cipher = AES.new(key, AES.MODE_CBC)
|
||||
ct_bytes = cipher.encrypt(pad(plaintext.encode("utf-8"), AES.block_size))
|
||||
iv = base64.b64encode(cipher.iv).decode("utf-8")
|
||||
ct = base64.b64encode(ct_bytes).decode("utf-8")
|
||||
return f"{iv}:{ct}"
|
||||
|
||||
|
||||
def decrypt(ciphertext: str) -> str:
|
||||
"""AES-CBC解密"""
|
||||
if not ciphertext or ":" not in ciphertext:
|
||||
return ""
|
||||
try:
|
||||
iv_str, ct_str = ciphertext.split(":", 1)
|
||||
key = _get_key()
|
||||
iv = base64.b64decode(iv_str)
|
||||
ct = base64.b64decode(ct_str)
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
pt = unpad(cipher.decrypt(ct), AES.block_size)
|
||||
return pt.decode("utf-8")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def mask_password(password: str) -> str:
|
||||
"""密码脱敏显示"""
|
||||
if not password:
|
||||
return ""
|
||||
if len(password) <= 2:
|
||||
return "*" * len(password)
|
||||
return password[0] + "*" * (len(password) - 2) + password[-1]
|
||||
24
backend/requirements.txt
Normal file
24
backend/requirements.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
sqlalchemy==2.0.36
|
||||
pymysql==1.1.1
|
||||
cryptography==44.0.0
|
||||
redis==5.2.1
|
||||
apscheduler==3.10.4
|
||||
pandas==2.2.3
|
||||
openpyxl==3.1.5
|
||||
passlib[bcrypt]==1.7.4
|
||||
pycryptodome==3.21.0
|
||||
httpx==0.28.1
|
||||
python-multipart==0.0.20
|
||||
python-jose[cryptography]==3.3.0
|
||||
pydantic==2.10.4
|
||||
pydantic-settings==2.7.0
|
||||
openai==1.59.6
|
||||
langchain==0.3.13
|
||||
langchain-openai==0.3.0
|
||||
aiofiles==24.1.0
|
||||
loguru==0.7.3
|
||||
alembic==1.14.0
|
||||
aiomysql==0.2.0
|
||||
greenlet==3.1.1
|
||||
Reference in New Issue
Block a user