feat: AI虚拟用户新闻互动系统 v1.3.0 初始提交

- 虚拟用户管理(昵称/头像/性别/简介/邮箱同步到目标平台)
- AI互动调度(点赞/收藏/评论/转发)
- 日志时间改为北京时间
- 评论达上限后继续执行点赞收藏转发
- 一键登出全部功能
- 浅色主题UI
This commit is contained in:
stefanfeng
2026-03-31 10:20:57 +08:00
commit 0cfc9bf9c8
53 changed files with 8457 additions and 0 deletions

0
backend/app/__init__.py Normal file
View File

View File

@@ -0,0 +1,12 @@
"""API路由汇总"""
from fastapi import APIRouter
from app.api.endpoints import users, interactions, ai_models, dashboard, system, logs
router = APIRouter()
router.include_router(users.router, prefix="/users", tags=["虚拟用户管理"])
router.include_router(interactions.router, prefix="/interactions", tags=["互动记录"])
router.include_router(ai_models.router, prefix="/ai-models", tags=["AI模型配置"])
router.include_router(dashboard.router, prefix="/dashboard", tags=["数据看板"])
router.include_router(system.router, prefix="/system", tags=["系统设置"])
router.include_router(logs.router, prefix="/logs", tags=["日志管理"])

View File

View File

@@ -0,0 +1,87 @@
"""AI模型配置接口"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, update
from app.core.database import get_db
from app.schemas import ApiResponse, AIModelCreateRequest, AIModelUpdateRequest, AIModelTestRequest
from app.models import AIModelConfig
from app.utils.crypto import encrypt, decrypt
from app.services.ai_service import ai_service
router = APIRouter()
@router.get("")
async def list_models(db=Depends(get_db)):
result = await db.execute(select(AIModelConfig).order_by(AIModelConfig.created_at.desc()))
models = result.scalars().all()
items = [_format_model(m) for m in models]
return ApiResponse(data=items)
@router.post("")
async def create_model(req: AIModelCreateRequest, db=Depends(get_db)):
if req.is_default:
await db.execute(update(AIModelConfig).values(is_default=0))
model = AIModelConfig(
model_name=req.model_name,
provider=req.provider,
api_base_url=req.api_base_url,
api_key_enc=encrypt(req.api_key) if req.api_key else None,
model_version=req.model_version,
temperature=req.temperature,
max_tokens=req.max_tokens,
timeout_seconds=req.timeout_seconds,
is_default=req.is_default,
is_enabled=1,
)
db.add(model)
await db.commit()
await db.refresh(model)
return ApiResponse(data=_format_model(model), message="模型添加成功")
@router.put("/{model_id}")
async def update_model(model_id: int, req: AIModelUpdateRequest, db=Depends(get_db)):
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
model = result.scalar_one_or_none()
if not model:
raise HTTPException(status_code=404, detail="模型不存在")
if req.is_default:
await db.execute(update(AIModelConfig).where(AIModelConfig.id != model_id).values(is_default=0))
for field, val in req.model_dump(exclude_none=True).items():
if field == "api_key":
model.api_key_enc = encrypt(val) if val else None
else:
setattr(model, field, val)
await db.commit()
await db.refresh(model)
return ApiResponse(data=_format_model(model), message="更新成功")
@router.delete("/{model_id}")
async def delete_model(model_id: int, db=Depends(get_db)):
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
model = result.scalar_one_or_none()
if not model:
raise HTTPException(status_code=404, detail="模型不存在")
await db.delete(model)
await db.commit()
return ApiResponse(message="删除成功")
@router.post("/test")
async def test_model(req: AIModelTestRequest, db=Depends(get_db)):
result = await ai_service.test_model(db, req.model_id, req.test_prompt)
return ApiResponse(data=result)
def _format_model(m: AIModelConfig) -> dict:
return {
"id": m.id, "model_name": m.model_name, "provider": m.provider,
"api_base_url": m.api_base_url, "has_api_key": bool(m.api_key_enc),
"model_version": m.model_version, "temperature": m.temperature,
"max_tokens": m.max_tokens, "timeout_seconds": m.timeout_seconds,
"is_default": m.is_default, "is_enabled": m.is_enabled,
"created_at": m.created_at.isoformat(),
}

View File

@@ -0,0 +1,25 @@
"""数据看板接口"""
from fastapi import APIRouter, Depends, Query
from app.core.database import get_db
from app.schemas import ApiResponse
from app.services.stats_service import stats_service
router = APIRouter()
@router.get("")
async def get_dashboard(db=Depends(get_db)):
data = await stats_service.get_dashboard(db)
return ApiResponse(data=data)
@router.get("/token-trend")
async def get_token_trend(days: int = Query(default=30, ge=7, le=90), db=Depends(get_db)):
trend = await stats_service.get_token_trend(db, days)
return ApiResponse(data=trend)
@router.get("/monthly-token-trend")
async def get_monthly_token_trend(db=Depends(get_db)):
trend = await stats_service.get_monthly_token_trend(db)
return ApiResponse(data=trend)

View File

@@ -0,0 +1,170 @@
"""互动记录接口"""
from typing import Optional
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.responses import StreamingResponse
import io, pandas as pd
from app.core.database import get_db
from app.schemas import ApiResponse
from app.services.stats_service import stats_service
from app.models import InteractionRecord
from sqlalchemy import select, update
router = APIRouter()
@router.get("")
async def list_interactions(
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=100),
user_id: Optional[int] = None,
interact_type: Optional[str] = None,
status: Optional[int] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
keyword: Optional[str] = None,
db=Depends(get_db)
):
result = await stats_service.get_interaction_records(
db, page, page_size, user_id, interact_type, status, start_date, end_date, keyword
)
return ApiResponse(data=result)
@router.post("/{record_id}/retry")
async def retry_interaction(record_id: int, db=Depends(get_db)):
"""手动重试失败任务"""
result = await db.execute(select(InteractionRecord).where(InteractionRecord.id == record_id))
record = result.scalar_one_or_none()
if not record:
raise HTTPException(status_code=404, detail="记录不存在")
if record.status != 2:
raise HTTPException(status_code=400, detail="只能重试失败的任务")
if record.retry_count >= 3:
raise HTTPException(status_code=400, detail="已超过最大重试次数(3次)")
from app.services.news_service import news_service
from app.services.ai_service import ai_service
from app.models import VirtualUser, UserPersonality
user_result = await db.execute(select(VirtualUser).where(VirtualUser.id == record.user_id))
user = user_result.scalar_one_or_none()
if not user or user.status != 2:
raise HTTPException(status_code=400, detail="用户未登录,无法重试")
success, err = False, "未知类型"
if record.interact_type == "comment" and record.content:
success, err = await news_service.post_comment(db, user, record.article_id, record.article_title or "", record.content)
elif record.interact_type == "like":
success, err = await news_service.like_news(db, user, record.article_id, org_id="", title=record.article_title or "")
elif record.interact_type == "collect":
success, err = await news_service.collect_news(db, user, record.article_id, title=record.article_title or "")
elif record.interact_type == "forward":
success, err = await news_service.forward_news(db, user, record.article_id)
await db.execute(
update(InteractionRecord).where(InteractionRecord.id == record_id).values(
status=1 if success else 2,
error_msg=None if success else err,
retry_count=record.retry_count + 1,
)
)
await db.commit()
return ApiResponse(message="重试成功" if success else f"重试失败: {err}")
@router.get("/export")
async def export_interactions(
user_id: Optional[int] = None,
interact_type: Optional[str] = None,
status: Optional[int] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
db=Depends(get_db)
):
"""导出互动记录"""
data = await stats_service.get_interaction_records(
db, 1, 10000, user_id, interact_type, status, start_date, end_date
)
rows = [{
"ID": r["id"], "用户昵称": r["user_nickname"], "用户账号": r["user_account"],
"文章标题": r["article_title"], "互动类型": r["interact_type_label"],
"内容": r["content"] or "", "Token消耗": r["token_consumed"],
"状态": r["status_label"], "失败原因": r["error_msg"] or "",
"重试次数": r["retry_count"], "执行时间": r["executed_at"],
} for r in data["items"]]
df = pd.DataFrame(rows)
buf = io.BytesIO()
df.to_excel(buf, index=False, sheet_name="互动记录")
buf.seek(0)
return StreamingResponse(
buf,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": "attachment; filename=interactions_export.xlsx"}
)
@router.post("/{record_id}/cancel")
async def cancel_interaction(record_id: int, db=Depends(get_db)):
"""取消互动(取消点赞/收藏/删除评论),转发不支持取消"""
from sqlalchemy import select, update
from app.models import InteractionRecord, VirtualUser
from app.services.news_service import news_service
# 查找互动记录
r = await db.execute(select(InteractionRecord).where(InteractionRecord.id == record_id))
record = r.scalar_one_or_none()
if not record:
return ApiResponse(code=404, message="记录不存在")
if record.status != 1:
return ApiResponse(code=400, message="只能取消成功的互动")
if record.interact_type == "forward":
return ApiResponse(code=400, message="转发互动无法取消")
if record.interact_type == "read":
return ApiResponse(code=400, message="阅读记录无法取消")
# 查找对应用户
ur = await db.execute(select(VirtualUser).where(VirtualUser.id == record.user_id))
user = ur.scalar_one_or_none()
if not user:
return ApiResponse(code=404, message="用户不存在")
# 执行取消
ok = False
err = ""
if record.interact_type in ("like",):
ok, err = await news_service.cancel_like(
db, user,
news_id=record.article_id or "",
org_id=record.session_id or "", # session_id 字段暂存 org_id
title=record.article_title or "",
)
elif record.interact_type == "collect":
ok, err = await news_service.cancel_collect(
db, user,
news_id=record.article_id or "",
title=record.article_title or "",
)
elif record.interact_type in ("comment", "reply"):
comment_id = record.platform_record_id or ""
if not comment_id:
return ApiResponse(code=400, message="缺少评论ID无法删除")
ok, err = await news_service.cancel_comment(
db, user,
news_id=record.article_id or "",
comment_id=comment_id,
)
if ok:
# 更新状态为手动取消status=3
await db.execute(
update(InteractionRecord).where(InteractionRecord.id == record_id).values(
status=3, error_msg="手动取消"
)
)
await db.commit()
return ApiResponse(message="取消成功")
else:
return ApiResponse(code=500, message=f"取消失败: {err}")

View File

@@ -0,0 +1,83 @@
"""日志管理接口"""
import os
from typing import Optional
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy import select, func, and_
from app.core.database import get_db
from app.schemas import ApiResponse
from app.models import LoginLog
from app.core.config import settings
router = APIRouter()
@router.get("/login")
async def get_login_logs(
page: int = Query(default=1, ge=1),
page_size: int = Query(default=50, ge=1, le=200),
user_id: Optional[int] = None,
action: Optional[str] = None,
db=Depends(get_db)
):
query = select(LoginLog)
conditions = []
if user_id:
conditions.append(LoginLog.user_id == user_id)
if action:
conditions.append(LoginLog.action == action)
if conditions:
query = query.where(and_(*conditions))
total = (await db.execute(select(func.count()).select_from(query.subquery()))).scalar()
query = query.order_by(LoginLog.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
result = await db.execute(query)
logs = result.scalars().all()
items = [{
"id": l.id, "user_id": l.user_id, "user_account": l.user_account,
"action": l.action, "session_id": l.session_id,
"error_msg": l.error_msg, "created_at": l.created_at.isoformat()
} for l in logs]
return ApiResponse(data={"total": total, "page": page, "page_size": page_size, "items": items})
@router.get("/files")
async def list_log_files():
"""列出日志文件"""
log_dir = settings.LOG_DIR
files = []
if os.path.exists(log_dir):
for fname in sorted(os.listdir(log_dir), reverse=True):
if fname.endswith(".log"):
fpath = os.path.join(log_dir, fname)
size = os.path.getsize(fpath)
files.append({"name": fname, "size": size,
"size_kb": round(size / 1024, 1)})
return ApiResponse(data=files)
@router.get("/files/{filename}/tail")
async def tail_log_file(filename: str, lines: int = Query(default=100, ge=10, le=1000)):
"""读取日志文件末尾"""
# 安全校验
if ".." in filename or "/" in filename:
raise HTTPException(status_code=400, detail="非法文件名")
fpath = os.path.join(settings.LOG_DIR, filename)
if not os.path.exists(fpath):
raise HTTPException(status_code=404, detail="文件不存在")
with open(fpath, "r", encoding="utf-8", errors="replace") as f:
all_lines = f.readlines()
tail = all_lines[-lines:]
return ApiResponse(data={"filename": filename, "lines": tail, "total_lines": len(all_lines)})
@router.get("/files/{filename}/download")
async def download_log_file(filename: str):
"""下载日志文件"""
if ".." in filename or "/" in filename:
raise HTTPException(status_code=400, detail="非法文件名")
fpath = os.path.join(settings.LOG_DIR, filename)
if not os.path.exists(fpath):
raise HTTPException(status_code=404, detail="文件不存在")
return FileResponse(fpath, filename=filename, media_type="text/plain")

View File

@@ -0,0 +1,115 @@
"""系统设置接口"""
from fastapi import APIRouter, Depends
from sqlalchemy import select, update as sql_update
from app.core.database import get_db
from app.schemas import ApiResponse
from app.models import SystemConfig
router = APIRouter()
@router.get("/configs")
async def get_configs(db=Depends(get_db)):
result = await db.execute(select(SystemConfig).order_by(SystemConfig.config_key))
configs = result.scalars().all()
data = {c.config_key: {"value": c.config_value, "type": c.config_type, "desc": c.description}
for c in configs}
return ApiResponse(data=data)
@router.put("/configs")
async def update_configs(body: dict, db=Depends(get_db)):
"""批量更新配置"""
for key, value in body.items():
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
cfg = result.scalar_one_or_none()
if cfg:
cfg.config_value = str(value)
else:
db.add(SystemConfig(config_key=key, config_value=str(value)))
await db.commit()
return ApiResponse(message="配置已保存")
@router.post("/scheduler/toggle")
async def toggle_scheduler(body: dict, db=Depends(get_db)):
enabled = body.get("enabled", True)
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == "scheduler_enabled"))
cfg = result.scalar_one_or_none()
if cfg:
cfg.config_value = "true" if enabled else "false"
await db.commit()
return ApiResponse(message=f"调度器已{'启用' if enabled else '暂停'}")
@router.post("/sessions/reset-all")
async def reset_all_sessions(db=Depends(get_db)):
"""重置所有用户会话"""
from app.models import VirtualUser
await db.execute(
sql_update(VirtualUser).values(status=0, session_token=None, session_expires_at=None)
)
await db.commit()
return ApiResponse(message="所有会话已重置")
@router.post("/login/diagnose")
async def diagnose_login(body: dict, db=Depends(get_db)):
"""
诊断登录接口原始响应 - 临时调试用
传入: {"username": "xxx", "password": "xxx"}
"""
import httpx, hashlib, uuid
from datetime import datetime
from app.services.news_service import news_service
cfg = await news_service._client(db)
auth = await news_service._auth_url(db)
username = body.get("username", "")
password = body.get("password", "")
# 构建 formData与真实登录完全一致
extra = {
"username": username,
"password": password,
"loginType": "password",
"grantType": "password",
"isRegister": "false",
}
if cfg.get("clientCode"):
extra["clientCode"] = cfg["clientCode"]
form = news_service._build_form(extra, cfg)
try:
async with httpx.AsyncClient(timeout=15) as c:
resp = await c.post(f"{auth}/open/login/token", data=form)
# 返回完整诊断信息
try:
resp_json = resp.json()
except Exception:
resp_json = None
return ApiResponse(data={
"status_code": resp.status_code,
"response_text": resp.text[:2000],
"response_json": resp_json,
"request_url": f"{auth}/open/login/token",
"request_form": {k: v if k not in ("password","accessSecret") else "***" for k, v in form.items()},
"content_type": resp.headers.get("content-type", ""),
})
except Exception as e:
return ApiResponse(code=500, message=str(e), data={"error": str(e)})
@router.post("/interaction/run-now")
async def run_interaction_now(db=Depends(get_db)):
"""立即触发一次互动任务(不受时间段限制)"""
from app.services.scheduler import scheduler_service
try:
result = await scheduler_service.run_once_now(db)
return ApiResponse(data=result, message="互动任务已触发")
except Exception as e:
return ApiResponse(code=500, message=f"触发失败: {e}")

View File

@@ -0,0 +1,370 @@
"""虚拟用户管理接口"""
from typing import Optional
from fastapi import APIRouter, Depends, Query, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse
import io
from app.core.database import get_db
from app.schemas import ApiResponse, UserCreateRequest, UserUpdateRequest, UserBatchRequest, PersonalityUpdateRequest
from app.services.user_service import user_service
router = APIRouter()
@router.get("")
async def list_users(
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=100),
keyword: Optional[str] = Query(default=None),
status: Optional[int] = Query(default=None),
is_enabled: Optional[int] = Query(default=None),
db=Depends(get_db)
):
"""获取虚拟用户列表"""
total, items = await user_service.get_users(db, page, page_size, keyword, status, is_enabled)
return ApiResponse(data={"total": total, "page": page, "page_size": page_size, "items": items})
@router.post("")
async def create_user(req: UserCreateRequest, db=Depends(get_db)):
"""创建虚拟用户"""
user = await user_service.create_user(db, req)
return ApiResponse(data=user, message="用户创建成功")
@router.get("/{user_id}")
async def get_user(user_id: int, db=Depends(get_db)):
"""获取单个用户详情"""
total, items = await user_service.get_users(db, 1, 1)
from sqlalchemy import select
from app.models import VirtualUser, UserPersonality
from app.services.user_service import user_service as svc
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user_id))
personality = p_result.scalar_one_or_none()
return ApiResponse(data=svc._format_user(user, personality))
@router.put("/{user_id}")
async def update_user(user_id: int, req: UserUpdateRequest, db=Depends(get_db)):
"""更新用户信息sync_to_platform=true 时同步到目标平台)"""
result = await user_service.update_user(db, user_id, req)
if req.sync_to_platform:
from sqlalchemy import select
from app.models import VirtualUser as _VU
from app.services.news_service import news_service
ur = await db.execute(select(_VU).where(_VU.id == user_id))
user = ur.scalar_one_or_none()
if user and user.status == 2:
ok, err = await news_service.update_user_profile(
db, user,
nick_name=req.nickname,
real_name=req.real_name,
sex=req.sex,
description=req.description,
email=req.email,
)
if not ok:
return ApiResponse(data=result,
message=f"本地已保存,同步到平台失败: {err}", code=206)
return ApiResponse(data=result, message="更新成功")
@router.delete("/{user_id}")
async def delete_user(user_id: int, db=Depends(get_db)):
"""删除用户"""
await user_service.delete_user(db, user_id)
return ApiResponse(message="删除成功")
@router.post("/batch/action")
async def batch_action(req: UserBatchRequest, db=Depends(get_db)):
"""批量操作用户"""
result = await user_service.batch_action(db, req.user_ids, req.action)
return ApiResponse(data=result, message="批量操作成功")
@router.post("/{user_id}/login")
async def manual_login(user_id: int, db=Depends(get_db)):
"""手动触发用户登录"""
from app.services.news_service import news_service
from sqlalchemy import select
from app.models import VirtualUser
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
success = await news_service.login(db, user)
if success:
return ApiResponse(message="登录成功")
raise HTTPException(status_code=400, detail="登录失败,请检查账号密码")
@router.post("/{user_id}/logout")
async def manual_logout(user_id: int, db=Depends(get_db)):
"""手动登出"""
from app.services.news_service import news_service
await news_service.logout(db, user_id)
return ApiResponse(message="已登出")
@router.post("/{user_id}/personality/generate")
async def generate_personality(user_id: int, db=Depends(get_db)):
"""重新生成AI人格"""
personality = await user_service.generate_personality(db, user_id)
return ApiResponse(data=personality, message="人格生成成功")
@router.put("/{user_id}/personality")
async def update_personality(user_id: int, req: PersonalityUpdateRequest, db=Depends(get_db)):
"""手动编辑人格属性"""
personality = await user_service.update_personality(db, user_id, req)
return ApiResponse(data=personality, message="人格更新成功")
@router.get("/excel/template")
async def download_template():
"""下载Excel导入模板"""
content = await user_service.get_excel_template()
return StreamingResponse(
io.BytesIO(content),
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": "attachment; filename=virtual_users_template.xlsx"}
)
@router.post("/excel/import")
async def import_excel(file: UploadFile = File(...), db=Depends(get_db)):
"""Excel批量导入"""
if not file.filename.endswith((".xlsx", ".xls")):
raise HTTPException(status_code=400, detail="仅支持Excel文件(.xlsx/.xls)")
content = await file.read()
result = await user_service.import_from_excel(db, content)
return ApiResponse(data=result, message=f"导入完成:成功{result['success']}条,失败{result['failed']}")
@router.get("/excel/export")
async def export_excel(db=Depends(get_db)):
"""导出用户数据Excel"""
content = await user_service.export_to_excel(db)
return StreamingResponse(
io.BytesIO(content),
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": "attachment; filename=virtual_users_export.xlsx"}
)
@router.post("/deduplicate")
async def deduplicate_users(db=Depends(get_db)):
"""删除重复用户(保留最早创建的一条)"""
from sqlalchemy import text
# 找出重复的账号,保留 id 最小的,删除其他的
result = await db.execute(
text("""
DELETE FROM virtual_users
WHERE id NOT IN (
SELECT MIN(id) FROM virtual_users GROUP BY account
)
""")
)
await db.commit()
deleted = result.rowcount
return ApiResponse(data={"deleted": deleted}, message=f"已清理 {deleted} 条重复数据")
@router.post("/clear-all")
async def clear_all_users(db=Depends(get_db)):
"""清空所有用户(慎用)"""
from sqlalchemy import text
from app.core.redis_client import get_redis
await db.execute(text("DELETE FROM user_personalities"))
await db.execute(text("DELETE FROM virtual_users"))
await db.commit()
return ApiResponse(message="已清空所有用户数据")
@router.post("/login-all")
async def batch_login_all(db=Depends(get_db)):
"""一键登录所有未登录/登录失效的用户"""
from sqlalchemy import select
from app.services.news_service import news_service
from app.models import VirtualUser as _VU
from app.core.database import AsyncSessionLocal
import asyncio
# 先用当前 session 查出所有待登录用户 ID
result_r = await db.execute(
select(_VU.id, _VU.account).where(
_VU.is_enabled == 1,
_VU.status.in_([0, 3])
)
)
rows = result_r.all()
if not rows:
return ApiResponse(message="没有需要登录的用户", data={"count": 0})
user_ids = [r[0] for r in rows]
total = len(user_ids)
success = failed = 0
# 每个用户独立 session避免事务污染
async def login_one(uid: int):
async with AsyncSessionLocal() as s:
try:
ur = await s.execute(select(_VU).where(_VU.id == uid))
u = ur.scalar_one_or_none()
if u:
return await news_service.login(s, u)
except Exception as e:
logger.warning(f"login_one {uid} 异常: {e}")
return False
return False
batch_size = 5
for i in range(0, total, batch_size):
batch_ids = user_ids[i:i+batch_size]
results = await asyncio.gather(*[login_one(uid) for uid in batch_ids], return_exceptions=True)
for r in results:
if r is True: success += 1
else: failed += 1
if i + batch_size < total:
await asyncio.sleep(1) # 批次间隔避免过于集中
return ApiResponse(
message=f"登录完成:成功 {success} 个,失败 {failed}",
data={"success": success, "failed": failed, "total": total}
)
@router.post("/sync-all-profiles")
async def sync_all_profiles(db=Depends(get_db)):
"""
同步所有已登录用户的平台信息(昵称/真实姓名/性别/头像)到本系统
从登录 session 中的 token 调用目标平台接口获取最新用户信息
"""
from sqlalchemy import select, update
from app.models import VirtualUser as _VU
from app.core.database import AsyncSessionLocal
from app.core.redis_client import get_session
import httpx, asyncio
# 查出所有已登录用户
result_r = await db.execute(select(_VU).where(_VU.status == 2, _VU.is_enabled == 1))
users = result_r.scalars().all()
if not users:
return ApiResponse(message="没有已登录的用户", data={"synced": 0})
synced = failed = 0
async def sync_one(uid: int):
"""从登录 session 中提取已缓存的用户信息,直接写入数据库,无需调用外部接口"""
async with AsyncSessionLocal() as s:
try:
sess = await get_session(uid)
if not sess:
return False
platform_uid = sess.get("platform_uid", "")
# 登录成功时 session 里已存有用户信息
vals = {}
if platform_uid: vals["platform_uid"] = platform_uid
# session 里的字段(登录时写入)
if sess.get("nickname"): vals["nickname"] = sess["nickname"]
if sess.get("real_name"): vals["real_name"] = sess["real_name"]
if sess.get("sex"): vals["sex"] = int(sess["sex"])
if sess.get("avatar"): vals["avatar_url"] = sess["avatar"]
if vals:
await s.execute(update(_VU).where(_VU.id == uid).values(**vals))
await s.commit()
return True
except Exception as e:
logger.warning(f"sync_one {uid} 失败: {e}")
return False
results = await asyncio.gather(*[sync_one(u.id) for u in users], return_exceptions=True)
for r in results:
if r is True: synced += 1
else: failed += 1
return ApiResponse(
message=f"同步完成:成功 {synced} 个,失败/跳过 {failed}",
data={"synced": synced, "failed": failed, "total": len(users)}
)
@router.post("/{user_id}/upload-avatar")
async def upload_avatar(
user_id: int,
file: UploadFile = File(...),
sync_to_platform: bool = Query(default=True),
db=Depends(get_db)
):
"""上传头像并可选同步到目标平台"""
from sqlalchemy import select, update
from app.models import VirtualUser as _VU
from app.services.news_service import news_service
ur = await db.execute(select(_VU).where(_VU.id == user_id))
user = ur.scalar_one_or_none()
if not user:
return ApiResponse(code=404, message="用户不存在")
# 读取文件内容
file_bytes = await file.read()
if len(file_bytes) > 5 * 1024 * 1024:
return ApiResponse(code=400, message="头像文件不能超过5MB")
avatar_url = None
if sync_to_platform and user.status == 2:
# 上传到目标平台
ok, result = await news_service.upload_avatar(db, user, file_bytes, file.filename)
if ok:
avatar_url = result
else:
return ApiResponse(code=500, message=f"头像上传到平台失败: {result}")
else:
# 仅本地存储(转 base64 或存储到本地)
import base64
avatar_url = f"data:{file.content_type};base64,{base64.b64encode(file_bytes).decode()}"
# 更新数据库
await db.execute(update(_VU).where(_VU.id == user_id).values(avatar_url=avatar_url))
await db.commit()
# 如果已同步到平台,再调用 update_user_profile 更新头像字段
if sync_to_platform and user.status == 2 and avatar_url:
await news_service.update_user_profile(db, user, avatar=avatar_url)
return ApiResponse(data={"avatar_url": avatar_url}, message="头像更新成功")
@router.post("/logout-all")
async def batch_logout_all(db=Depends(get_db)):
"""一键登出所有已登录用户"""
from sqlalchemy import select, update
from app.models import VirtualUser as _VU
from app.core.redis_client import delete_session
result_r = await db.execute(
select(_VU.id).where(_VU.status == 2, _VU.is_enabled == 1)
)
rows = result_r.all()
if not rows:
return ApiResponse(message="没有已登录的用户", data={"count": 0})
count = 0
for row in rows:
try:
await delete_session(row[0])
count += 1
except Exception:
pass
# 更新所有用户状态为未登录
await db.execute(
update(_VU).where(_VU.status == 2).values(status=0)
)
await db.commit()
return ApiResponse(message=f"已登出 {count} 个用户", data={"count": count})

View File

@@ -0,0 +1 @@
# app.core package

View File

@@ -0,0 +1,46 @@
"""系统配置"""
import os
from urllib.parse import quote_plus
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
# 数据库
DB_HOST: str = os.getenv("DB_HOST", "localhost")
DB_PORT: int = int(os.getenv("DB_PORT", "3306"))
DB_USER: str = os.getenv("DB_USER", "aivirtual")
DB_PASSWORD: str = os.getenv("DB_PASSWORD", "AiVirtual2024")
DB_NAME: str = os.getenv("DB_NAME", "ai_virtual_news")
# Redis
REDIS_HOST: str = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
# 安全
SECRET_KEY: str = os.getenv("SECRET_KEY", "dev-secret-key-change-in-prod")
AES_KEY: str = os.getenv("AES_KEY", "your-aes-key-32-chars-change-now!")
# 新闻平台
NEWS_PLATFORM_BASE_URL: str = os.getenv(
"NEWS_PLATFORM_BASE_URL", "http://192.168.1.200:63120"
)
# 日志目录
LOG_DIR: str = "/app/logs"
@property
def DATABASE_URL(self) -> str:
# 对密码做 URL 编码,防止 @ # ! 等特殊字符破坏连接字符串
pwd = quote_plus(self.DB_PASSWORD)
return f"mysql+aiomysql://{self.DB_USER}:{pwd}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}?charset=utf8mb4"
@property
def SYNC_DATABASE_URL(self) -> str:
pwd = quote_plus(self.DB_PASSWORD)
return f"mysql+pymysql://{self.DB_USER}:{pwd}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}?charset=utf8mb4"
class Config:
env_file = ".env"
settings = Settings()

View File

@@ -0,0 +1,72 @@
"""数据库连接管理"""
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
from app.core.logger import logger
class Base(DeclarativeBase):
pass
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=10,
max_overflow=20,
)
AsyncSessionLocal = async_sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def get_db():
"""获取数据库会话"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def wait_for_db(max_retries: int = 30, interval: int = 2):
"""等待 MySQL 就绪,最多重试 max_retries 次"""
for attempt in range(1, max_retries + 1):
try:
async with engine.begin() as conn:
await conn.execute(__import__("sqlalchemy").text("SELECT 1"))
logger.info(f"✅ 数据库连接成功(第 {attempt} 次尝试)")
return
except Exception as e:
if attempt == max_retries:
logger.error(f"数据库连接失败,已重试 {max_retries} 次: {e}")
raise
logger.warning(f"数据库未就绪,{interval}s 后重试({attempt}/{max_retries}: {e}")
await asyncio.sleep(interval)
async def init_db():
"""初始化数据库 - 等待 MySQL 就绪并注册所有模型"""
try:
# 等待 MySQL 容器真正就绪
await wait_for_db(max_retries=30, interval=2)
# 导入所有模型类,确保 SQLAlchemy ORM 元数据注册
from app.models import (
VirtualUser, UserPersonality, InteractionRecord,
TokenStat, AIModelConfig, SystemConfig, LoginLog
)
logger.info("✅ 数据库模型注册成功")
logger.info("✅ 数据库初始化完成")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
raise

View File

@@ -0,0 +1,48 @@
"""日志配置"""
import sys
import os
from loguru import logger
LOG_DIR = os.getenv("LOG_DIR", "./logs")
os.makedirs(LOG_DIR, exist_ok=True)
# 移除默认处理器
logger.remove()
# 控制台输出
logger.add(
sys.stdout,
level="INFO",
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
)
# 通用日志文件
logger.add(
f"{LOG_DIR}/app_{{time:YYYY-MM-DD}}.log",
rotation="00:00",
retention="30 days",
level="INFO",
encoding="utf-8",
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
)
# 错误日志文件
logger.add(
f"{LOG_DIR}/error_{{time:YYYY-MM-DD}}.log",
rotation="00:00",
retention="30 days",
level="ERROR",
encoding="utf-8",
)
# AI调用日志
logger.add(
f"{LOG_DIR}/ai_{{time:YYYY-MM-DD}}.log",
rotation="00:00",
retention="30 days",
level="INFO",
encoding="utf-8",
filter=lambda record: "ai_call" in record["extra"],
)
__all__ = ["logger"]

View File

@@ -0,0 +1,81 @@
"""Redis缓存客户端"""
import json
import redis.asyncio as aioredis
from app.core.config import settings
from app.core.logger import logger
_redis_client = None
async def get_redis() -> aioredis.Redis:
global _redis_client
if _redis_client is None:
_redis_client = aioredis.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
encoding="utf-8",
decode_responses=True,
)
return _redis_client
# Session键前缀
SESSION_PREFIX = "session:"
LOCK_PREFIX = "lock:"
RATE_PREFIX = "rate:"
async def set_session(user_id: int, session_data: dict, expire: int = 86400):
"""存储用户会话"""
r = await get_redis()
key = f"{SESSION_PREFIX}{user_id}"
await r.setex(key, expire, json.dumps(session_data, ensure_ascii=False))
async def get_session(user_id: int) -> dict | None:
"""获取用户会话"""
r = await get_redis()
key = f"{SESSION_PREFIX}{user_id}"
data = await r.get(key)
if data:
return json.loads(data)
return None
async def delete_session(user_id: int):
"""删除用户会话"""
r = await get_redis()
key = f"{SESSION_PREFIX}{user_id}"
await r.delete(key)
async def acquire_lock(name: str, expire: int = 60) -> bool:
"""获取分布式锁"""
r = await get_redis()
key = f"{LOCK_PREFIX}{name}"
result = await r.set(key, "1", nx=True, ex=expire)
return result is True
async def release_lock(name: str):
"""释放分布式锁"""
r = await get_redis()
key = f"{LOCK_PREFIX}{name}"
await r.delete(key)
async def incr_rate(key: str, expire: int = 86400) -> int:
"""限流计数"""
r = await get_redis()
rate_key = f"{RATE_PREFIX}{key}"
count = await r.incr(rate_key)
if count == 1:
await r.expire(rate_key, expire)
return count
async def get_counter(key: str) -> int:
"""获取计数"""
r = await get_redis()
rate_key = f"{RATE_PREFIX}{key}"
val = await r.get(rate_key)
return int(val) if val else 0

65
backend/app/main.py Normal file
View File

@@ -0,0 +1,65 @@
"""
AI虚拟用户新闻互动系统 - 后端主入口
"""
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.core.config import settings
from app.core.database import init_db
from app.core.logger import logger
from app.api import router
from app.services.scheduler import scheduler_service
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info("🚀 AI虚拟用户新闻互动系统启动中...")
# 初始化数据库
await init_db()
# 启动调度器
await scheduler_service.start()
logger.info("✅ 系统启动完成")
yield
# 关闭调度器
await scheduler_service.stop()
logger.info("🛑 系统已关闭")
app = FastAPI(
title="AI虚拟用户新闻互动系统",
description="基于AI驱动的虚拟用户新闻互动自动化平台",
version="1.0.0",
lifespan=lifespan,
docs_url="/api/docs",
redoc_url="/api/redoc",
)
# CORS配置
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(router, prefix="/api")
@app.get("/health")
async def health_check():
return {"status": "ok", "service": "ai-virtual-news-backend"}
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
logger.error(f"全局异常: {exc}")
return JSONResponse(
status_code=500,
content={"code": 500, "message": f"服务器内部错误: {str(exc)}"},
)

View File

@@ -0,0 +1,132 @@
"""SQLAlchemy ORM 模型"""
from datetime import datetime
from sqlalchemy import (
BigInteger, Integer, SmallInteger, String, Text, DateTime,
Boolean, Float, Date, JSON, func
)
from sqlalchemy.orm import Mapped, mapped_column
from app.core.database import Base
class VirtualUser(Base):
__tablename__ = "virtual_users"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
nickname: Mapped[str] = mapped_column(String(64), nullable=False)
account: Mapped[str] = mapped_column(String(128), nullable=False, unique=True)
password_enc: Mapped[str] = mapped_column(String(512), nullable=False)
avatar_url: Mapped[str | None] = mapped_column(String(512))
status: Mapped[int] = mapped_column(SmallInteger, default=0)
activity_level: Mapped[int] = mapped_column(SmallInteger, default=1)
daily_comment_limit: Mapped[int] = mapped_column(Integer, default=10)
daily_like_limit: Mapped[int] = mapped_column(Integer, default=30)
today_comment_count: Mapped[int] = mapped_column(Integer, default=0)
today_like_count: Mapped[int] = mapped_column(Integer, default=0)
total_interactions: Mapped[int] = mapped_column(Integer, default=0)
session_token: Mapped[str | None] = mapped_column(Text)
session_expires_at: Mapped[datetime | None] = mapped_column(DateTime)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime)
last_interact_at: Mapped[datetime | None] = mapped_column(DateTime)
real_name: Mapped[str | None] = mapped_column(String(64)) # 真实姓名(从平台同步)
sex: Mapped[int] = mapped_column(SmallInteger, default=0) # 性别 0未知 1男 2女
platform_uid: Mapped[str | None] = mapped_column(String(64)) # 平台用户ID
remark: Mapped[str | None] = mapped_column(String(256))
is_enabled: Mapped[int] = mapped_column(SmallInteger, default=1)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
class UserPersonality(Base):
__tablename__ = "user_personalities"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, unique=True)
character_type: Mapped[str | None] = mapped_column(String(32))
language_style: Mapped[str | None] = mapped_column(String(32))
interest_tags: Mapped[dict | None] = mapped_column(JSON)
interact_tendency: Mapped[str | None] = mapped_column(String(32))
word_count_min: Mapped[int] = mapped_column(Integer, default=20)
word_count_max: Mapped[int] = mapped_column(Integer, default=100)
personality_desc: Mapped[str | None] = mapped_column(Text)
comment_style_prompt: Mapped[str | None] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
class InteractionRecord(Base):
__tablename__ = "interaction_records"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
user_nickname: Mapped[str | None] = mapped_column(String(64))
user_account: Mapped[str | None] = mapped_column(String(128))
article_id: Mapped[str | None] = mapped_column(String(64))
article_title: Mapped[str | None] = mapped_column(String(256))
interact_type: Mapped[str] = mapped_column(String(16), nullable=False, index=True)
content: Mapped[str | None] = mapped_column(Text)
platform_record_id: Mapped[str | None] = mapped_column(String(64)) # 平台返回的记录ID用于取消互动
parent_comment_id: Mapped[str | None] = mapped_column(String(64))
session_id: Mapped[str | None] = mapped_column(String(128))
token_consumed: Mapped[int] = mapped_column(Integer, default=0)
status: Mapped[int] = mapped_column(SmallInteger, default=0)
error_msg: Mapped[str | None] = mapped_column(String(512))
retry_count: Mapped[int] = mapped_column(SmallInteger, default=0)
executed_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), index=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
class TokenStat(Base):
__tablename__ = "token_stats"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
stat_date: Mapped[datetime] = mapped_column(Date, nullable=False, unique=True)
model_name: Mapped[str | None] = mapped_column(String(64))
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
prompt_tokens: Mapped[int] = mapped_column(Integer, default=0)
completion_tokens: Mapped[int] = mapped_column(Integer, default=0)
call_count: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
class AIModelConfig(Base):
__tablename__ = "ai_model_configs"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(String(64), nullable=False)
provider: Mapped[str] = mapped_column(String(32), nullable=False)
api_base_url: Mapped[str | None] = mapped_column(String(256))
api_key_enc: Mapped[str | None] = mapped_column(String(512))
model_version: Mapped[str | None] = mapped_column(String(64))
temperature: Mapped[float] = mapped_column(Float, default=0.7)
max_tokens: Mapped[int] = mapped_column(Integer, default=1000)
timeout_seconds: Mapped[int] = mapped_column(Integer, default=30)
is_default: Mapped[int] = mapped_column(SmallInteger, default=0)
is_enabled: Mapped[int] = mapped_column(SmallInteger, default=1)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
class SystemConfig(Base):
__tablename__ = "system_configs"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
config_key: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
config_value: Mapped[str | None] = mapped_column(Text)
config_type: Mapped[str] = mapped_column(String(16), default="string")
description: Mapped[str | None] = mapped_column(String(256))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now())
class LoginLog(Base):
__tablename__ = "login_logs"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
user_account: Mapped[str | None] = mapped_column(String(128))
action: Mapped[str] = mapped_column(String(16), nullable=False)
session_id: Mapped[str | None] = mapped_column(String(128))
ip_address: Mapped[str | None] = mapped_column(String(64))
error_msg: Mapped[str | None] = mapped_column(String(512))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), index=True)

View File

@@ -0,0 +1,19 @@
# models package - re-export all models
from app.models import (
VirtualUser, UserPersonality, InteractionRecord,
TokenStat, AIModelConfig, SystemConfig, LoginLog
)
# Aliases for import compatibility
virtual_user = VirtualUser
personality = UserPersonality
interaction = InteractionRecord
token_stat = TokenStat
ai_model = AIModelConfig
system_config = SystemConfig
login_log = LoginLog
__all__ = [
"VirtualUser", "UserPersonality", "InteractionRecord",
"TokenStat", "AIModelConfig", "SystemConfig", "LoginLog",
]

220
backend/app/schemas/__init__.py Executable file
View File

@@ -0,0 +1,220 @@
"""Pydantic数据模型 - 请求/响应模式"""
from datetime import datetime
from typing import Optional, List, Any
from pydantic import BaseModel, Field
# ===== 通用响应 =====
class ApiResponse(BaseModel):
code: int = 200
message: str = "success"
data: Any = None
class PageResult(BaseModel):
total: int
page: int
page_size: int
items: List[Any]
# ===== 虚拟用户 =====
class UserCreateRequest(BaseModel):
# 必填
account: str = Field(..., min_length=1, max_length=128, description="新闻平台账号(必填)")
password: str = Field(..., min_length=6, max_length=64, description="登录密码(必填)")
# 选填
nickname: Optional[str] = Field(None, max_length=64, description="昵称(选填,为空自动生成)")
avatar_url: Optional[str] = None
activity_level: int = Field(default=1, ge=0, le=2)
daily_comment_limit: int = Field(default=10, ge=1, le=100)
daily_like_limit: int = Field(default=30, ge=1, le=200)
remark: Optional[str] = None
class UserUpdateRequest(BaseModel):
nickname: Optional[str] = Field(None, min_length=1, max_length=64)
password: Optional[str] = Field(None, min_length=6, max_length=64)
avatar_url: Optional[str] = None
real_name: Optional[str] = None
sex: Optional[int] = None
description: Optional[str] = None
email: Optional[str] = None
activity_level: Optional[int] = Field(None, ge=0, le=2)
daily_comment_limit: Optional[int] = Field(None, ge=1, le=100)
daily_like_limit: Optional[int] = Field(None, ge=1, le=200)
remark: Optional[str] = None
is_enabled: Optional[int] = None
sync_to_platform: bool = False
class UserResponse(BaseModel):
id: int
nickname: str
account: str
avatar_url: Optional[str]
real_name: Optional[str] = None
sex: int = 0
platform_uid: Optional[str] = None
status: int
status_label: str
activity_level: int
activity_label: str
daily_comment_limit: int
daily_like_limit: int
today_comment_count: int
today_like_count: int
total_interactions: int
last_login_at: Optional[datetime]
last_interact_at: Optional[datetime]
remark: Optional[str]
is_enabled: int
created_at: datetime
personality: Optional[dict] = None
class Config:
from_attributes = True
class UserBatchRequest(BaseModel):
user_ids: List[int]
action: str # enable/disable/logout/delete
# ===== 人格 =====
class PersonalityUpdateRequest(BaseModel):
character_type: Optional[str] = None
language_style: Optional[str] = None
interest_tags: Optional[List[str]] = None
interact_tendency: Optional[str] = None
word_count_min: Optional[int] = Field(None, ge=10, le=500)
word_count_max: Optional[int] = Field(None, ge=10, le=1000)
personality_desc: Optional[str] = None
class PersonalityResponse(BaseModel):
id: int
user_id: int
character_type: Optional[str]
language_style: Optional[str]
interest_tags: Optional[List[str]]
interact_tendency: Optional[str]
word_count_min: int
word_count_max: int
personality_desc: Optional[str]
updated_at: datetime
class Config:
from_attributes = True
# ===== 互动记录 =====
class InteractionQueryParams(BaseModel):
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
user_id: Optional[int] = None
interact_type: Optional[str] = None
status: Optional[int] = None
start_date: Optional[str] = None
end_date: Optional[str] = None
keyword: Optional[str] = None
class InteractionResponse(BaseModel):
id: int
user_id: int
user_nickname: Optional[str]
user_account: Optional[str]
article_id: Optional[str]
article_title: Optional[str]
interact_type: str
interact_type_label: str
content: Optional[str]
token_consumed: int
status: int
status_label: str
error_msg: Optional[str]
retry_count: int
executed_at: datetime
class Config:
from_attributes = True
# ===== AI模型配置 =====
class AIModelCreateRequest(BaseModel):
model_name: str = Field(..., min_length=1, max_length=64)
provider: str = Field(..., pattern="^(openai|zhipu|wenxin|qianwen|local)$")
api_base_url: Optional[str] = None
api_key: Optional[str] = None
model_version: Optional[str] = None
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
max_tokens: int = Field(default=1000, ge=1, le=32000)
timeout_seconds: int = Field(default=30, ge=5, le=300)
is_default: int = Field(default=0, ge=0, le=1)
class AIModelUpdateRequest(BaseModel):
model_name: Optional[str] = None
api_base_url: Optional[str] = None
api_key: Optional[str] = None
model_version: Optional[str] = None
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(None, ge=1, le=32000)
timeout_seconds: Optional[int] = Field(None, ge=5, le=300)
is_default: Optional[int] = None
is_enabled: Optional[int] = None
class AIModelResponse(BaseModel):
id: int
model_name: str
provider: str
api_base_url: Optional[str]
has_api_key: bool
model_version: Optional[str]
temperature: float
max_tokens: int
timeout_seconds: int
is_default: int
is_enabled: int
created_at: datetime
class Config:
from_attributes = True
class AIModelTestRequest(BaseModel):
model_id: int
test_prompt: str = "你好,请简单介绍一下自己。"
# ===== 系统配置 =====
class SystemConfigUpdateRequest(BaseModel):
configs: dict
# ===== 数据统计 =====
class DashboardResponse(BaseModel):
user_stats: dict
today_interactions: dict
monthly_stats: dict
token_stats: dict
system_status: dict
online_users: int
# ===== 调度配置 =====
class SchedulerConfigRequest(BaseModel):
interact_time_start: Optional[str] = None
interact_time_end: Optional[str] = None
interact_interval_min: Optional[int] = None
interact_interval_max: Optional[int] = None
max_concurrent_users: Optional[int] = None
daily_token_limit: Optional[int] = None
comment_probability: Optional[float] = None
reply_probability: Optional[float] = None
like_probability: Optional[float] = None
collect_probability: Optional[float] = None
forward_probability: Optional[float] = None
scheduler_enabled: Optional[bool] = None

View File

@@ -0,0 +1,215 @@
"""Pydantic数据模型 - 请求/响应模式"""
from datetime import datetime
from typing import Optional, List, Any
from pydantic import BaseModel, Field
# ===== 通用响应 =====
class ApiResponse(BaseModel):
code: int = 200
message: str = "success"
data: Any = None
class PageResult(BaseModel):
total: int
page: int
page_size: int
items: List[Any]
# ===== 虚拟用户 =====
class UserCreateRequest(BaseModel):
# 必填
account: str = Field(..., min_length=1, max_length=128, description="新闻平台账号(必填)")
password: str = Field(..., min_length=6, max_length=64, description="登录密码(必填)")
# 选填
nickname: Optional[str] = Field(None, max_length=64, description="昵称(选填,为空自动生成)")
avatar_url: Optional[str] = None
activity_level: int = Field(default=1, ge=0, le=2)
daily_comment_limit: int = Field(default=10, ge=1, le=100)
daily_like_limit: int = Field(default=30, ge=1, le=200)
remark: Optional[str] = None
class UserUpdateRequest(BaseModel):
nickname: Optional[str] = Field(None, min_length=1, max_length=64)
password: Optional[str] = Field(None, min_length=6, max_length=64)
avatar_url: Optional[str] = None
activity_level: Optional[int] = Field(None, ge=0, le=2)
daily_comment_limit: Optional[int] = Field(None, ge=1, le=100)
daily_like_limit: Optional[int] = Field(None, ge=1, le=200)
remark: Optional[str] = None
is_enabled: Optional[int] = None
class UserResponse(BaseModel):
id: int
nickname: str
account: str
avatar_url: Optional[str]
real_name: Optional[str] = None
sex: int = 0
platform_uid: Optional[str] = None
status: int
status_label: str
activity_level: int
activity_label: str
daily_comment_limit: int
daily_like_limit: int
today_comment_count: int
today_like_count: int
total_interactions: int
last_login_at: Optional[datetime]
last_interact_at: Optional[datetime]
remark: Optional[str]
is_enabled: int
created_at: datetime
personality: Optional[dict] = None
class Config:
from_attributes = True
class UserBatchRequest(BaseModel):
user_ids: List[int]
action: str # enable/disable/logout/delete
# ===== 人格 =====
class PersonalityUpdateRequest(BaseModel):
character_type: Optional[str] = None
language_style: Optional[str] = None
interest_tags: Optional[List[str]] = None
interact_tendency: Optional[str] = None
word_count_min: Optional[int] = Field(None, ge=10, le=500)
word_count_max: Optional[int] = Field(None, ge=10, le=1000)
personality_desc: Optional[str] = None
class PersonalityResponse(BaseModel):
id: int
user_id: int
character_type: Optional[str]
language_style: Optional[str]
interest_tags: Optional[List[str]]
interact_tendency: Optional[str]
word_count_min: int
word_count_max: int
personality_desc: Optional[str]
updated_at: datetime
class Config:
from_attributes = True
# ===== 互动记录 =====
class InteractionQueryParams(BaseModel):
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
user_id: Optional[int] = None
interact_type: Optional[str] = None
status: Optional[int] = None
start_date: Optional[str] = None
end_date: Optional[str] = None
keyword: Optional[str] = None
class InteractionResponse(BaseModel):
id: int
user_id: int
user_nickname: Optional[str]
user_account: Optional[str]
article_id: Optional[str]
article_title: Optional[str]
interact_type: str
interact_type_label: str
content: Optional[str]
token_consumed: int
status: int
status_label: str
error_msg: Optional[str]
retry_count: int
executed_at: datetime
class Config:
from_attributes = True
# ===== AI模型配置 =====
class AIModelCreateRequest(BaseModel):
model_name: str = Field(..., min_length=1, max_length=64)
provider: str = Field(..., pattern="^(openai|zhipu|wenxin|qianwen|local)$")
api_base_url: Optional[str] = None
api_key: Optional[str] = None
model_version: Optional[str] = None
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
max_tokens: int = Field(default=1000, ge=1, le=32000)
timeout_seconds: int = Field(default=30, ge=5, le=300)
is_default: int = Field(default=0, ge=0, le=1)
class AIModelUpdateRequest(BaseModel):
model_name: Optional[str] = None
api_base_url: Optional[str] = None
api_key: Optional[str] = None
model_version: Optional[str] = None
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(None, ge=1, le=32000)
timeout_seconds: Optional[int] = Field(None, ge=5, le=300)
is_default: Optional[int] = None
is_enabled: Optional[int] = None
class AIModelResponse(BaseModel):
id: int
model_name: str
provider: str
api_base_url: Optional[str]
has_api_key: bool
model_version: Optional[str]
temperature: float
max_tokens: int
timeout_seconds: int
is_default: int
is_enabled: int
created_at: datetime
class Config:
from_attributes = True
class AIModelTestRequest(BaseModel):
model_id: int
test_prompt: str = "你好,请简单介绍一下自己。"
# ===== 系统配置 =====
class SystemConfigUpdateRequest(BaseModel):
configs: dict
# ===== 数据统计 =====
class DashboardResponse(BaseModel):
user_stats: dict
today_interactions: dict
monthly_stats: dict
token_stats: dict
system_status: dict
online_users: int
# ===== 调度配置 =====
class SchedulerConfigRequest(BaseModel):
interact_time_start: Optional[str] = None
interact_time_end: Optional[str] = None
interact_interval_min: Optional[int] = None
interact_interval_max: Optional[int] = None
max_concurrent_users: Optional[int] = None
daily_token_limit: Optional[int] = None
comment_probability: Optional[float] = None
reply_probability: Optional[float] = None
like_probability: Optional[float] = None
collect_probability: Optional[float] = None
forward_probability: Optional[float] = None
scheduler_enabled: Optional[bool] = None

View File

View File

@@ -0,0 +1,258 @@
"""AI服务 - 人格生成、内容创作"""
import json
import random
import re
from typing import Optional
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.models import AIModelConfig, TokenStat
from app.utils.crypto import decrypt
from app.core.logger import logger
from datetime import date
class AIService:
"""AI大模型服务"""
# 人格候选池
CHARACTER_TYPES = ["开朗", "内敛", "毒舌", "温和", "理性", "感性", "幽默", "严谨"]
LANGUAGE_STYLES = ["严肃", "幽默", "文艺", "吐槽", "口语化", "学术", "简洁", "丰富"]
INTEREST_TAGS_POOL = [
"科技", "财经", "娱乐", "体育", "政治", "文化", "教育", "医疗",
"汽车", "房产", "旅游", "美食", "军事", "国际", "环保", "农业"
]
INTERACT_TENDENCIES = ["爱评论", "爱点赞", "爱收藏", "潜水", "爱转发", "爱回复"]
async def _get_default_model(self, db: AsyncSession) -> Optional[AIModelConfig]:
result = await db.execute(
select(AIModelConfig).where(
AIModelConfig.is_default == 1, AIModelConfig.is_enabled == 1
)
)
return result.scalar_one_or_none()
async def _call_api(
self, db: AsyncSession, prompt: str, system_prompt: str = None,
max_tokens: int = None
) -> tuple[str, int]:
"""调用AI接口返回(内容, token数)"""
model = await self._get_default_model(db)
if not model:
# 无模型配置时返回随机预设
return "", 0
api_key = decrypt(model.api_key_enc) if model.api_key_enc else ""
base_url = model.api_base_url or "https://api.openai.com/v1"
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
payload = {
"model": model.model_version or "gpt-3.5-turbo",
"messages": messages,
"temperature": model.temperature,
"max_tokens": max_tokens or model.max_tokens,
}
import asyncio as _asyncio
last_err = None
for attempt in range(3): # 最多重试3次
try:
async with httpx.AsyncClient(timeout=model.timeout_seconds) as client:
resp = await client.post(
f"{base_url}/chat/completions",
headers=headers,
json=payload,
)
# 429 限流:等待后重试
if resp.status_code == 429:
wait = 30 * (attempt + 1) # 30s, 60s, 90s
logger.warning(f"AI接口限流(429){wait}s后重试({attempt+1}/3)")
await _asyncio.sleep(wait)
continue
resp.raise_for_status()
data = resp.json()
text = data["choices"][0]["message"]["content"].strip()
tokens = data.get("usage", {}).get("total_tokens", 0)
await self._record_token_usage(db, tokens, data.get("usage", {}), model.model_name)
logger.bind(ai_call=True).info(
f"AI调用成功 model={model.model_name} tokens={tokens}"
)
return text, tokens
except Exception as e:
last_err = e
if attempt < 2:
await _asyncio.sleep(5 * (attempt + 1))
logger.error(f"AI调用失败(已重试3次): {last_err}")
return "", 0
async def generate_personality(self, nickname: str, account: str) -> dict:
"""生成用户人格含fallback随机生成"""
# 如无AI配置使用随机生成
from app.core.database import AsyncSessionLocal
try:
async with AsyncSessionLocal() as db:
model = await self._get_default_model(db)
if not model:
return self._random_personality()
prompt = f"""请为以下虚拟新闻读者生成一个独特的人格档案,要求真实自然、贴合中国用户特征。
用户昵称:{nickname}
请严格以JSON格式返回不要有其他内容
{{
"character_type": "从[开朗/内敛/毒舌/温和/理性/感性/幽默/严谨]中选一个",
"language_style": "从[严肃/幽默/文艺/吐槽/口语化/学术/简洁/丰富]中选一个",
"interest_tags": ["兴趣1", "兴趣2", "兴趣3"],
"interact_tendency": "从[爱评论/爱点赞/爱收藏/潜水/爱转发/爱回复]中选一个",
"word_count_min": 最少字数(10-50整数),
"word_count_max": 最多字数(50-200整数),
"personality_desc": "一句话描述此人的性格特点30字以内"
}}"""
content, _ = await self._call_api(db, prompt, max_tokens=300)
if content:
try:
# 提取JSON
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
return json.loads(json_match.group())
except Exception:
pass
return self._random_personality()
except Exception as e:
logger.error(f"人格生成失败: {e}")
return self._random_personality()
def _random_personality(self) -> dict:
"""随机生成人格无AI时的备用方案"""
interests = random.sample(self.INTEREST_TAGS_POOL, random.randint(2, 4))
char = random.choice(self.CHARACTER_TYPES)
style = random.choice(self.LANGUAGE_STYLES)
tendency = random.choice(self.INTERACT_TENDENCIES)
w_min = random.randint(15, 40)
w_max = random.randint(60, 150)
return {
"character_type": char,
"language_style": style,
"interest_tags": interests,
"interact_tendency": tendency,
"word_count_min": w_min,
"word_count_max": w_max,
"personality_desc": f"一个{char}性格、{tendency}的新闻读者",
}
async def generate_comment(
self, db: AsyncSession, article_title: str, article_content: str,
personality_prompt: str, word_min: int = 20, word_max: int = 80
) -> tuple[str, int]:
"""生成文章评论"""
system_prompt = f"""你是一名真实的社区用户,正在阅读新闻后发表评论。{personality_prompt}
重要规则:
- 评论必须积极正面、文明友善,绝对不包含任何政治敏感、色情、暴力、侮辱、歧视内容
- 不要提及具体政治人物、党派、政策批评、社会矛盾等敏感话题
- 内容围绕文章本身展开,表达个人感受、分享观点、提出建设性问题
- 语言朴实自然,像普通网友留言,不夸张不煽情"""
prompt = f"""请根据以下新闻文章写一条评论。
文章标题:{article_title}
文章摘要:{article_content[:200] if article_content else '(无摘要)'}
要求:
1. 评论字数 {word_min}~{word_max}
2. 内容积极正面,贴近文章主题
3. 语气自然真实,符合普通读者口吻
4. 必须是完整的句子,不能被截断,以句号/感叹号/问号结尾
5. 只输出评论正文,不要加任何前缀或解释
评论:"""
return await self._call_api(db, prompt, system_prompt, max_tokens=300)
async def generate_reply(
self, db: AsyncSession, article_title: str, parent_comment: str,
personality_prompt: str, word_min: int = 15, word_max: int = 60
) -> tuple[str, int]:
"""生成回复"""
system_prompt = f"""你是一名真实的社区用户。{personality_prompt}
重要规则:回复必须积极正面、文明友善,不含任何敏感违规内容。"""
prompt = f"""文章:{article_title}
原评论:{parent_comment}
请对上面的评论写一条友善自然的回复,{word_min}~{word_max}字,直接输出回复内容。"""
return await self._call_api(db, prompt, system_prompt, max_tokens=150)
async def test_model(self, db: AsyncSession, model_id: int, test_prompt: str) -> dict:
"""测试模型可用性"""
result = await db.execute(select(AIModelConfig).where(AIModelConfig.id == model_id))
model = result.scalar_one_or_none()
if not model:
return {"success": False, "error": "模型不存在"}
api_key = decrypt(model.api_key_enc) if model.api_key_enc else ""
base_url = model.api_base_url or "https://api.openai.com/v1"
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
payload = {
"model": model.model_version or "gpt-3.5-turbo",
"messages": [{"role": "user", "content": test_prompt}],
"max_tokens": 200,
}
try:
import time
start = time.time()
async with httpx.AsyncClient(timeout=model.timeout_seconds) as client:
resp = await client.post(f"{base_url}/chat/completions", headers=headers, json=payload)
resp.raise_for_status()
data = resp.json()
elapsed = round(time.time() - start, 2)
content = data["choices"][0]["message"]["content"]
tokens = data.get("usage", {}).get("total_tokens", 0)
return {
"success": True, "content": content,
"tokens": tokens, "elapsed_seconds": elapsed,
}
except Exception as e:
return {"success": False, "error": str(e)}
async def _record_token_usage(
self, db: AsyncSession, total: int, usage: dict, model_name: str
):
"""记录Token消耗"""
today = date.today()
from sqlalchemy.dialects.mysql import insert as mysql_insert
try:
existing = await db.execute(
select(TokenStat).where(TokenStat.stat_date == today)
)
stat = existing.scalar_one_or_none()
if stat:
stat.total_tokens += total
stat.prompt_tokens += usage.get("prompt_tokens", 0)
stat.completion_tokens += usage.get("completion_tokens", 0)
stat.call_count += 1
else:
stat = TokenStat(
stat_date=today,
model_name=model_name,
total_tokens=total,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
call_count=1,
)
db.add(stat)
except Exception as e:
logger.error(f"记录Token消耗失败: {e}")
ai_service = AIService()

View File

@@ -0,0 +1,729 @@
"""
新闻平台对接服务
登录: POST {auth}/open/login/token (formData)
签名: 完全对应 sign.js 的实现
"""
import uuid
import hashlib
import hmac
from datetime import datetime, timedelta
from typing import Optional
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.models import VirtualUser, SystemConfig, LoginLog
from app.utils.crypto import decrypt
from app.core.redis_client import set_session, get_session, delete_session
from app.core.logger import logger
class NewsPlatformService:
# ─── 配置读取 ──────────────────────────────────────────────
async def _cfg(self, db: AsyncSession, key: str, default: str = "") -> str:
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
row = result.scalar_one_or_none()
return row.config_value if row else default
async def _biz_url(self, db: AsyncSession) -> str:
return await self._cfg(db, "news_platform_base_url", "http://192.168.1.200:63120")
async def _auth_url(self, db: AsyncSession) -> str:
return await self._cfg(db, "auth_base_url", "http://192.168.1.200:60040")
async def _client(self, db: AsyncSession) -> dict:
return {
"appId": await self._cfg(db, "platform_app_id", ""),
"accessId": await self._cfg(db, "platform_access_id", ""),
"accessSecret": await self._cfg(db, "platform_access_secret", ""),
"clientCode": await self._cfg(db, "platform_client_code", ""),
"orgId": await self._cfg(db, "platform_org_id", ""),
}
# ─── 签名(完全对应 sign.js 逻辑) ─────────────────────────
@staticmethod
def _make_sign(params: dict, secret_key: str, sign_type: str = "MD5") -> str:
"""
完全对应 sign.js:
1. 所有参数 key 排序
2. 过滤掉 signature / accessSecret / 空值 / 空数组
3. 拼接 key=value& ... accessSecret=secretKey
4. MD5/SHA256 大写
"""
SIGN_KEY = "signature"
SECRET_KEY = "accessSecret"
keys = sorted(params.keys())
str_parts = []
for k in keys:
if k in (SIGN_KEY, SECRET_KEY):
continue
v = params.get(k)
if v is None or v == "" or v == []:
continue
if isinstance(v, list):
continue
str_parts.append(f"{k}={v}")
sign_str = "&".join(str_parts) + f"&{SECRET_KEY}={secret_key}"
if sign_type.upper() == "SHA256":
return hashlib.sha256(sign_str.encode("utf-8")).hexdigest().upper()
else:
return hashlib.md5(sign_str.encode("utf-8")).hexdigest().upper()
@staticmethod
def _get_nonce() -> str:
import random, math
return str(random.random())[2:][: random.randint(8, 12)]
@staticmethod
def _get_timestamp() -> str:
"""yyyyMMddhhmmss — 注意 sign.js 用小写 hh(12小时制)"""
now = datetime.now()
# sign.js 用 yyyyMMddhhmmss (12小时制小写hh)
return now.strftime("%Y%m%d%I%M%S")
def _build_form(self, extra: dict, cfg: dict) -> dict:
"""构建带签名的 formData"""
sign_type = "MD5"
sign_version = "1.0"
secret_key = cfg.get("accessSecret", "")
base = {
"appId": cfg.get("appId", ""),
"accessId": cfg.get("accessId", ""),
"timestamp": self._get_timestamp(),
"signType": sign_type,
"signVersion": sign_version,
"accessSecret": secret_key,
"nonce": self._get_nonce(),
}
base.update(extra)
# 计算签名
signature = self._make_sign(base, secret_key, sign_type) if secret_key else ""
base["signature"] = signature
# 移除 accessSecret不发送到服务器
base.pop("accessSecret", None)
return base
# ─── 登录 ──────────────────────────────────────────────────
async def login(self, db: AsyncSession, user: VirtualUser) -> bool:
password = decrypt(user.password_enc)
if not password:
logger.error(f"[登录] {user.account} 密码解密失败")
return False
auth = await self._auth_url(db)
cfg = await self._client(db)
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(status=1))
await db.commit()
extra = {
"username": user.account,
"password": password,
"loginType": "password",
"grantType": "password",
"isRegister": "false",
}
if cfg.get("clientCode"):
extra["clientCode"] = cfg["clientCode"]
form = self._build_form(extra, cfg)
exc = None
try:
async with httpx.AsyncClient(
timeout=30,
follow_redirects=True, # 自动跟随 HTTP 重定向
) as c:
# 登录接口路径:需要加 /usercenter 前缀(通过网关路由)
# auth_base_url 配置为完整前缀,如 https://fat-open.99hui.com/api/usercenter
login_url = f"{auth}/open/login/token"
resp = await c.post(login_url, data=form)
# 详细记录原始响应,便于排查
logger.info(
f"[登录] {user.account} 原始响应: "
f"status={resp.status_code} "
f"content-type={resp.headers.get('content-type','')} "
f"body={resp.text[:500]}"
)
# 防止空响应体崩溃
if not resp.text.strip():
logger.warning(f"[登录] {user.account} 服务器返回空响应体,接口可能不存在或被重定向")
raise ValueError(f"服务器返回空响应 HTTP={resp.status_code}")
# 尝试解析 JSON
try:
body = resp.json()
except Exception as je:
logger.warning(f"[登录] {user.account} 响应非JSON: {resp.text[:200]}")
raise ValueError(f"响应非JSON: {resp.text[:100]}")
if resp.status_code == 200 and body.get("code") in [0, 200]:
raw = body.get("data")
access_token, platform_uid = self._extract_token(raw)
if access_token:
sid = str(uuid.uuid4())
# 登录成功后尝试获取用户组织信息
org_id = await self._fetch_org_id(db, access_token, platform_uid, cfg)
if org_id and not cfg.get("orgId"):
# 如果系统没有配置 orgId则自动保存用户所属组织
await self._save_org_id(db, org_id)
# 从登录响应中提取用户信息
user_info = raw.get("userInfo", {}) if isinstance(raw, dict) else {}
sync_nickname = user_info.get("nickName") or user_info.get("username") or ""
sync_real_name = user_info.get("realName") or ""
sync_sex = int(user_info.get("sex") or 0)
sync_avatar = user_info.get("avatar") or ""
await set_session(user.id, {
"token": access_token,
"session_id": sid,
"platform_uid": platform_uid,
"org_id": org_id or cfg.get("orgId", ""),
"login_time": datetime.now().isoformat(),
# 缓存用户信息供 sync 使用
"nickname": sync_nickname,
"real_name": sync_real_name,
"sex": sync_sex,
"avatar": sync_avatar,
}, expire=86400)
# 更新本地数据库,同步平台用户信息
update_vals = dict(
status=2, session_token=access_token,
session_expires_at=datetime.now() + timedelta(hours=24),
last_login_at=datetime.now(),
platform_uid=platform_uid,
)
if sync_nickname: update_vals["nickname"] = sync_nickname
if sync_real_name: update_vals["real_name"] = sync_real_name
if sync_sex: update_vals["sex"] = sync_sex
if sync_avatar: update_vals["avatar_url"] = sync_avatar
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(**update_vals))
await db.commit()
await self._write_login_log(db, user, "login", sid)
logger.info(f"✅ [登录] {user.account} 成功 orgId={org_id}")
return True
logger.warning(f"[登录] {user.account} 无token: {body}")
else:
logger.warning(f"[登录] {user.account} 失败: HTTP={resp.status_code} {body}")
except Exception as e:
exc = e
logger.error(f"[登录] {user.account} 异常: {e}")
await db.execute(update(VirtualUser).where(VirtualUser.id == user.id).values(status=3))
await db.commit()
await self._write_login_log(db, user, "fail", error_msg=str(exc or "登录失败"))
return False
async def _fetch_org_id(
self, db: AsyncSession, token: str, platform_uid: str, cfg: dict
) -> str:
"""登录成功后,调用接口获取用户所属组织 orgId"""
biz = await self._biz_url(db)
headers = self._bearer(token)
# 尝试常见的用户信息接口
endpoints = [
f"/app/user/info",
f"/open/user/info",
f"/user/info",
]
for ep in endpoints:
try:
form = self._build_form({}, cfg)
async with httpx.AsyncClient(timeout=8) as c:
r = await c.get(
f"{biz}{ep}",
headers=headers,
params={k: v for k, v in form.items() if k not in ("username","password")}
)
if r.status_code == 200:
d = r.json()
data = d.get("data") or {}
# 从各种可能的字段中提取 orgId
org_id = (
data.get("orgId") or data.get("defaultOrgId") or
data.get("org", {}).get("id") if isinstance(data.get("org"), dict) else None
)
if org_id:
return str(org_id)
except Exception as e:
logger.debug(f"获取orgId失败({ep}): {e}")
return ""
async def _save_org_id(self, db: AsyncSession, org_id: str):
"""自动保存获取到的 orgId 到系统配置"""
result = await db.execute(
select(SystemConfig).where(SystemConfig.config_key == "platform_org_id")
)
cfg = result.scalar_one_or_none()
if cfg:
cfg.config_value = org_id
else:
db.add(SystemConfig(
config_key="platform_org_id",
config_value=org_id,
description="平台组织Id自动获取"
))
await db.commit()
logger.info(f"✅ 自动获取并保存 orgId={org_id}")
@staticmethod
def _extract_token(raw) -> tuple[str, str]:
if isinstance(raw, str) and raw:
return raw, ""
if isinstance(raw, dict):
token = (raw.get("access_token") or raw.get("accessToken") or raw.get("token") or "")
# openid 是平台用户ID登录响应里 data.openid = data.userInfo.id
uid = str(
raw.get("openid") or
raw.get("userId") or raw.get("user_id") or raw.get("id") or
(raw.get("userInfo") or {}).get("id") or ""
)
return token, uid
return "", ""
async def logout(self, db: AsyncSession, user_id: int):
user_r = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
user = user_r.scalar_one_or_none()
if user:
sess = await get_session(user_id)
await self._write_login_log(db, user, "logout",
sess.get("session_id") if sess else None)
await delete_session(user_id)
await db.execute(update(VirtualUser).where(VirtualUser.id == user_id).values(
status=0, session_token=None, session_expires_at=None))
await db.commit()
async def check_session(self, db: AsyncSession, user: VirtualUser) -> bool:
sess = await get_session(user.id)
if not sess:
return False
biz = await self._biz_url(db)
cfg = await self._client(db)
try:
params = self._build_form({}, cfg)
params.update({"orgId": cfg["orgId"] or "1", "pageNum": 1, "pageSize": 1, "status": "approved"})
async with httpx.AsyncClient(timeout=10) as c:
r = await c.get(f"{biz}/news/list", headers=self._bearer(sess["token"]), params=params)
if r.status_code == 200 and r.json().get("code") in [0, 200]:
return True
await self.logout(db, user.id)
return False
except Exception:
return False
# ─── 新闻列表 ──────────────────────────────────────────────
async def get_news_list(self, db, user, count=5, interest_tags=None) -> list:
"""
GET /business/member/square/list 广场数据分页查询
type=1 表示新闻orgId 选填(不填则查全平台新闻,无需配置 orgId
返回字段id(广场ID), recordId(新闻实际ID), title, orgId, orgName
"""
sess = await get_session(user.id)
if not sess:
return []
biz = await self._biz_url(db)
cfg = await self._client(db)
org_id = sess.get("org_id") or cfg.get("orgId") or ""
# 先查总数再随机翻页避免每次都取第1页相同内容
import math
# 第一次查询获取总页数
first_params = self._build_form({
"pageNum": 1,
"pageSize": 50,
"type": "1",
"isPlatformShow": "true",
"isAdmin": "false",
}, cfg)
if org_id:
first_params["orgId"] = org_id
total_pages = 1
try:
async with httpx.AsyncClient(timeout=10) as _c:
_r = await _c.get(
f"{biz}/business/member/square/list",
headers=self._bearer(sess["token"]),
params=first_params
)
_d = _r.json()
if _d.get("code") in [0, 200]:
total_size = _d.get("data", {}).get("totalSize", 0)
total_pages = max(1, math.ceil(total_size / 50))
except Exception:
pass
# 随机选择一页
import random as _random
rand_page = _random.randint(1, min(total_pages, 10)) # 最多取前10页随机
params = self._build_form({
"pageNum": rand_page,
"pageSize": 50,
"type": "1",
"isPlatformShow": "true",
"isAdmin": "false",
}, cfg)
if org_id:
params["orgId"] = org_id # 选填,有则按组织过滤
try:
async with httpx.AsyncClient(timeout=15) as c:
r = await c.get(
f"{biz}/business/member/square/list",
headers=self._bearer(sess["token"]),
params=params
)
if r.status_code == 200:
d = r.json()
if d.get("code") in [0, 200]:
nd = d.get("data", {})
items = nd.get("data") or nd.get("list") or nd.get("records") or []
# 过滤本人发布的文章
platform_uid = sess.get("platform_uid", "")
if platform_uid:
items = [x for x in items if x.get("createUser") != platform_uid]
# 过滤已知无效新闻(详情为空或不存在)
INVALID_IDS = {
"1965670408480907266","2029092495693975554","1960652956793597953",
"1960651987045347330","1960596408620838914","1960596083193180161",
"1960595664341594113",
}
items = [x for x in items
if (x.get("recordId") or x.get("id")) not in INVALID_IDS]
logger.info(f"[广场新闻] {user.account} 获取到 {len(items)} 条(已过滤本人+无效文章)")
import random as _rand
return _rand.sample(items, min(count, len(items))) if items else []
logger.warning(f"[广场新闻] {user.account} code={d.get('code')} msg={d.get('message')}")
except Exception as e:
logger.error(f"[广场新闻] {user.account}: {e}")
return []
async def read_news(self, db, user, news_id: str) -> bool:
sess = await get_session(user.id)
if not sess:
return False
biz = await self._biz_url(db)
cfg = await self._client(db)
try:
async with httpx.AsyncClient(timeout=10) as c:
r = await c.patch(
f"{biz}/news/read/{news_id}",
headers=self._bearer(sess["token"]),
data=self._build_form({}, cfg),
)
return r.status_code == 200
except Exception:
return False
async def post_comment(self, db, user, news_id, news_title, content, news_author_id="", org_id="") -> tuple[bool, str]:
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
cfg = await self._client(db)
uid = sess.get("platform_uid", "")
# org_id 优先取文章自带的(从广场数据获取),否则取 session/配置
final_org_id = org_id or sess.get("org_id") or cfg.get("orgId") or ""
body = {
"module": "news", "topicId": news_id, "title": news_title,
"content": content, "orgId": final_org_id,
"toUserId": news_author_id or uid, "userId": uid,
"userName": user.nickname, "avatar": user.avatar_url or "",
}
return await self._json_post(f"{biz}/message/comment", self._bearer(sess["token"]), body)
async def post_reply(self, db, user, news_id, comment_id, content) -> tuple[bool, str]:
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
uid = sess.get("platform_uid", "")
body = {
"module": "news", "topicId": news_id, "commentId": comment_id,
"commentUserId": uid, "content": content,
"fromUserName": user.nickname, "avatar": user.avatar_url or "",
}
return await self._json_post(f"{biz}/message/comment/reply", self._bearer(sess["token"]), body)
async def get_comments(self, db, user, news_id) -> list:
sess = await get_session(user.id)
if not sess:
return []
biz = await self._biz_url(db)
cfg = await self._client(db)
try:
params = self._build_form({"module": "news", "topicId": news_id, "pageNum": 1, "pageSize": 20}, cfg)
async with httpx.AsyncClient(timeout=10) as c:
r = await c.get(f"{biz}/message/comment", headers=self._bearer(sess["token"]), params=params)
if r.status_code == 200:
return r.json().get("data", {}).get("data") or []
except Exception:
pass
return []
async def like_news(self, db, user, news_id, org_id="", to_user_id="", title="") -> tuple[bool, str]:
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
uid = sess.get("platform_uid", "")
body = {
"module": "news",
"topicId": news_id,
"userId": uid,
"toUserId": to_user_id or uid,
"orgId": org_id or sess.get("org_id", "") or "",
"title": title,
}
return await self._json_post(f"{biz}/message/praise", self._bearer(sess["token"]), body)
async def collect_news(self, db, user, news_id, org_id="", to_user_id="", title="") -> tuple[bool, str]:
"""收藏新闻:复用点赞接口(平台收藏=点赞同一接口)"""
return await self.like_news(db, user, news_id, org_id=org_id, to_user_id=to_user_id, title=title)
async def forward_news(self, db, user, news_id) -> tuple[bool, str]:
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
cfg = await self._client(db)
org_id = sess.get("org_id") or cfg.get("orgId") or "1"
headers = self._bearer(sess["token"])
try:
async with httpx.AsyncClient(timeout=8) as c:
await c.get(f"{biz}/news/share/wechat/{news_id}", headers=headers,
params=self._build_form({}, cfg))
except Exception:
pass
try:
async with httpx.AsyncClient(timeout=15) as c:
r = await c.post(
f"{biz}/points/forward/news/{org_id}",
headers=headers,
data=self._build_form({}, cfg),
)
return self._ok(r)
except Exception as e:
return False, str(e)
@staticmethod
def _bearer(token: str) -> dict:
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
@staticmethod
def _ok(resp: httpx.Response) -> tuple[bool, str]:
if resp.status_code in [200, 201]:
try:
d = resp.json()
if d.get("code") in [0, 200]:
return True, ""
return False, d.get("message") or "业务失败"
except Exception:
return True, ""
return False, f"HTTP {resp.status_code}"
async def _json_post(self, url, headers, body) -> tuple[bool, str]:
try:
async with httpx.AsyncClient(timeout=15) as c:
r = await c.post(url, json=body, headers=headers)
return self._ok(r)
except Exception as e:
return False, str(e)
async def _write_login_log(self, db, user, action, session_id=None, error_msg=None):
try:
log = LoginLog(
user_id=user.id, user_account=user.account,
action=action, session_id=session_id, error_msg=error_msg,
)
db.add(log)
await db.commit()
except Exception:
pass
# ─── 取消互动 ────────────────────────────────────────────────────
async def cancel_like(self, db, user, news_id: str, org_id: str = "", to_user_id: str = "", title: str = "") -> tuple[bool, str]:
"""DELETE /message/praise/cancel 取消点赞"""
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
uid = sess.get("platform_uid", "")
body = {
"module": "news",
"topicId": news_id,
"userId": uid,
"toUserId": to_user_id or uid,
"orgId": org_id or sess.get("org_id", "") or "",
"title": title,
}
try:
async with httpx.AsyncClient(timeout=10) as c:
r = await c.delete(
f"{biz}/message/praise/cancel",
json=body,
headers=self._bearer(sess["token"])
)
d = r.json()
if d.get("code") in [0, 200]:
return True, ""
return False, d.get("message", "取消点赞失败")
except Exception as e:
return False, str(e)
async def cancel_comment(self, db, user, news_id: str, comment_id: str) -> tuple[bool, str]:
"""DELETE /message/comment/{topicId}/{id} 删除评论"""
sess = await get_session(user.id)
if not sess:
return False, "未登录"
biz = await self._biz_url(db)
cfg = await self._client(db)
# 签名参数放 formData路径里是 topicId 和 comment_id
params = self._build_form({}, cfg)
try:
async with httpx.AsyncClient(timeout=10) as c:
r = await c.delete(
f"{biz}/message/comment/{news_id}/{comment_id}",
headers=self._bearer(sess["token"]),
params=params
)
d = r.json()
if d.get("code") in [0, 200]:
return True, ""
return False, d.get("message", "删除评论失败")
except Exception as e:
return False, str(e)
async def cancel_collect(self, db, user, news_id: str, org_id: str = "", to_user_id: str = "", title: str = "") -> tuple[bool, str]:
"""取消收藏(复用取消点赞接口)"""
return await self.cancel_like(db, user, news_id, org_id=org_id, to_user_id=to_user_id, title=title)
# ─── 修改目标系统用户信息 ─────────────────────────────────────
async def update_user_profile(
self, db: AsyncSession, user: VirtualUser,
nick_name: str = None, real_name: str = None,
sex: int = None, avatar: str = None,
description: str = None, email: str = None,
) -> tuple[bool, str]:
"""
调用 POST /usercenter/user/change/userInfo 修改用户信息
支持:昵称、真实姓名、性别、头像、简介、邮箱
同时同步更新本地数据库
"""
sess = await get_session(user.id)
if not sess:
return False, "用户未登录,请先登录"
auth = await self._auth_url(db)
cfg = await self._client(db)
token = sess.get("token", "")
platform_uid = sess.get("platform_uid", "")
if not platform_uid:
return False, "缺少平台用户ID请重新登录"
# 构建请求体(只传有值的字段)
# 构建请求体,确保至少有 nickName 字段(平台 SQL 要求 SET 子句不为空)
body = {
"id": platform_uid,
"nickName": nick_name if nick_name is not None else (user.nickname or ""),
"realName": real_name if real_name is not None else (user.real_name or ""),
"sex": sex if sex is not None else (user.sex or 0),
}
if avatar is not None: body["avatar"] = avatar
if description is not None: body["description"] = description
if email is not None: body["email"] = email
# 使用 PATCH /v2/users/current 接口(支持修改昵称)
headers = dict(self._bearer(token))
headers["Content-Type"] = "application/json"
try:
async with httpx.AsyncClient(timeout=15) as c:
r = await c.patch(
f"{auth}/v2/users/current",
json=body,
headers=headers,
)
d = r.json()
if d.get("code") in [0, 200]:
# 同步到本地数据库
local_vals = {}
if nick_name is not None: local_vals["nickname"] = nick_name
if real_name is not None: local_vals["real_name"] = real_name
if sex is not None: local_vals["sex"] = sex
if avatar is not None: local_vals["avatar_url"] = avatar
if local_vals:
from sqlalchemy import update
await db.execute(update(VirtualUser).where(
VirtualUser.id == user.id).values(**local_vals))
await db.commit()
logger.info(f"✅ 用户 {user.account} 信息已同步到目标系统")
return True, ""
err = d.get("message") or f"code={d.get('code')}"
logger.warning(f"[修改用户信息] {user.account} 失败: {err} body={r.text[:200]}")
return False, err
except Exception as e:
logger.warning(f"[修改用户信息] {user.account} 异常: {e}")
return False, str(e)
async def upload_avatar(
self, db: AsyncSession, user: VirtualUser, file_bytes: bytes, filename: str
) -> tuple[bool, str]:
"""
上传头像图片到平台 filecenter返回图片 URL
POST /filecenter/fileUpload (multipart/form-data)
"""
sess = await get_session(user.id)
if not sess:
return False, "用户未登录"
cfg = await self._client(db)
token = sess.get("token", "")
# filecenter 服务地址
biz_base = await self._biz_url(db)
# filecenter 与 huihuibusiness 同网关,替换服务名
filecenter_url = biz_base.replace("/huihuibusiness", "/filecenter")
sign_params = self._build_form({"module": "userInfo", "service": "kccloud"}, cfg)
headers = {"Authorization": f"Bearer {token}"}
try:
import mimetypes
mime = mimetypes.guess_type(filename)[0] or "image/jpeg"
files = {"file": (filename, file_bytes, mime)}
async with httpx.AsyncClient(timeout=30) as c:
r = await c.post(
f"{filecenter_url}/fileUpload",
files=files,
data=sign_params,
headers=headers,
)
d = r.json()
if d.get("code") in [0, 200]:
url = d.get("data") or d.get("url") or ""
if isinstance(url, dict):
url = url.get("url") or url.get("path") or ""
logger.info(f"✅ 头像上传成功: {url}")
return True, url
return False, d.get("message") or "上传失败"
except Exception as e:
return False, str(e)
news_service = NewsPlatformService()

View File

@@ -0,0 +1,402 @@
"""调度服务 - 定时自动互动、会话校验"""
import random
import asyncio
from datetime import datetime, date
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import select, update, func
from app.core.database import AsyncSessionLocal
from app.core.logger import logger
from app.models import VirtualUser, UserPersonality, InteractionRecord, SystemConfig
class SchedulerService:
def __init__(self):
self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
self._running = False
async def run_once_now(self, db=None):
"""立即执行一次互动,不受时间段限制"""
from sqlalchemy import select
from app.core.database import AsyncSessionLocal
logger.info("⚡ 立即触发互动任务")
async with AsyncSessionLocal() as session:
result_r = await session.execute(
select(VirtualUser).where(
VirtualUser.status == 2,
VirtualUser.is_enabled == 1,
)
)
users = result_r.scalars().all()
if not users:
return {"message": "没有已登录的用户", "triggered": 0}
import random
selected = random.sample(users, min(5, len(users)))
import asyncio
tasks = [self._execute_user_interaction(u.id) for u in selected]
await asyncio.gather(*tasks, return_exceptions=True)
return {"triggered": len(selected), "users": [u.account for u in selected]}
async def start(self):
if self._running:
return
# 会话校验每10分钟
self.scheduler.add_job(
self._check_sessions, IntervalTrigger(minutes=10),
id="check_sessions", replace_existing=True
)
# 互动任务每5分钟检查一次内部判断是否在活跃时间段
self.scheduler.add_job(
self._run_interactions, IntervalTrigger(minutes=5),
id="run_interactions", replace_existing=True
)
# 每日零点重置计数
self.scheduler.add_job(
self._daily_reset, "cron", hour=16, minute=0, # 北京时间 00:00 = UTC 16:00
id="daily_reset", replace_existing=True
)
self.scheduler.start()
self._running = True
logger.info("调度器已启动")
# 记录启动时间
async with AsyncSessionLocal() as db:
await self._set_config(db, "system_start_time", datetime.now().isoformat())
async def stop(self):
if self.scheduler.running:
self.scheduler.shutdown(wait=False)
self._running = False
async def _get_config(self, db, key: str, default=None):
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
cfg = result.scalar_one_or_none()
return cfg.config_value if cfg else default
async def _set_config(self, db, key: str, value: str):
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
cfg = result.scalar_one_or_none()
if cfg:
cfg.config_value = value
else:
db.add(SystemConfig(config_key=key, config_value=value))
await db.commit()
async def _check_sessions(self):
"""定时校验登录状态"""
from app.services.news_service import news_service
async with AsyncSessionLocal() as db:
result = await db.execute(
select(VirtualUser).where(VirtualUser.status == 2, VirtualUser.is_enabled == 1)
)
users = result.scalars().all()
for user in users:
try:
valid = await news_service.check_session(db, user)
if not valid:
logger.warning(f"用户 {user.account} 会话失效,尝试重登")
await news_service.login(db, user)
except Exception as e:
logger.error(f"会话校验异常 {user.account}: {e}")
async def _run_interactions(self):
"""执行互动任务"""
async with AsyncSessionLocal() as db:
# 检查调度器开关
enabled = await self._get_config(db, "scheduler_enabled", "true")
if enabled != "true":
return
# 检查Token限额
token_limited = await self._get_config(db, "token_limit_reached", "false")
if token_limited == "true":
return
# 检查互动时间段(北京时间 UTC+8
from datetime import timezone, timedelta
tz_beijing = timezone(timedelta(hours=8))
now_bj = datetime.now(tz_beijing)
now_time = now_bj.strftime("%H:%M")
start_str = await self._get_config(db, "interact_time_start", "08:00")
end_str = await self._get_config(db, "interact_time_end", "22:00")
if not (start_str <= now_time <= end_str):
logger.debug(f"[调度] 当前北京时间 {now_time} 不在互动时段 {start_str}-{end_str}")
return
# 获取最小互动间隔(秒)
min_interval = int(await self._get_config(db, "interact_min_interval", "300"))
# 获取最大并发
max_concurrent = int(await self._get_config(db, "max_concurrent_users", "5"))
# 获取已登录、启用的用户
query = select(VirtualUser).where(
VirtualUser.status == 2,
VirtualUser.is_enabled == 1,
)
if max_concurrent > 0:
query = query.limit(max_concurrent)
result = await db.execute(query)
all_users = result.scalars().all()
# 没有已登录用户时,尝试登录未登录用户
if not all_users:
await self._try_login_users(db)
return
# 检查互动间隔:过滤掉最近 min_interval 秒内已互动的用户
now_utc = datetime.utcnow()
eligible = []
for u in all_users:
if u.last_interact_at is None:
eligible.append(u)
else:
elapsed = (now_utc - u.last_interact_at).total_seconds()
if elapsed >= min_interval:
eligible.append(u)
if not eligible:
logger.debug(f"[调度] 所有用户在 {min_interval}s 内已互动,跳过本次")
return
logger.info(f"[调度] {len(eligible)}/{len(all_users)} 个用户满足间隔要求,开始互动")
# 随机选取用户执行互动
for user in eligible:
if random.random() < 0.6:
asyncio.create_task(self._execute_user_interaction(user.id))
async def _try_login_users(self, db):
"""尝试登录未登录的用户"""
from app.services.news_service import news_service
result = await db.execute(
select(VirtualUser).where(
VirtualUser.status.in_([0, 3]),
VirtualUser.is_enabled == 1
).limit(3)
)
users = result.scalars().all()
for user in users:
try:
await news_service.login(db, user)
await asyncio.sleep(2)
except Exception as e:
logger.error(f"自动登录失败 {user.account}: {e}")
async def _execute_user_interaction(self, user_id: int):
"""执行单用户互动 - 基于真实接口"""
from app.services.news_service import news_service
from app.services.ai_service import ai_service
async with AsyncSessionLocal() as db:
try:
user_result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
user = user_result.scalar_one_or_none()
if not user or user.status != 2:
return
# 检查今日评论限额
can_comment = True
if user.today_comment_count >= user.daily_comment_limit:
can_comment = False
logger.info(f'用户 ' + user.account + ' 今日评论已达上限,仍执行点赞/收藏/转发')
# 获取人格
from app.models import UserPersonality
p_result = await db.execute(
select(UserPersonality).where(UserPersonality.user_id == user_id)
)
personality = p_result.scalar_one_or_none()
interest_tags = personality.interest_tags if personality else []
# 获取新闻列表(基于接口 GET /news/list
articles = await news_service.get_news_list(
db, user, count=5, interest_tags=interest_tags
)
if not articles:
# 尝试从 session 获取 org_id 再试一次
from app.core.redis_client import get_session as _get_sess
sess = await _get_sess(user.id)
org_from_sess = sess.get("org_id", "") if sess else ""
if org_from_sess:
articles = await news_service.get_news_list(
db, user, count=5, interest_tags=interest_tags
)
if not articles:
logger.warning(
f"用户 {user.account} 获取新闻列表为空 "
f"(orgId={await news_service._cfg(db, 'platform_org_id', '')})"
)
return
article = random.choice(articles)
# 接口返回字段: id/newsTitle/content/digest/createUser
# 广场接口字段recordId=新闻实际ID, id=广场记录ID, title=标题
news_id = str(article.get("recordId") or article.get("id", ""))
news_title = article.get("title") or article.get("newsTitle") or "未知文章"
news_content = article.get("content") or article.get("digest") or news_title
news_author = str(article.get("createUser") or "")
# 从广场数据中顺带获取 orgId
article_org_id = str(article.get("orgId") or "")
if not news_id:
return
# 读取互动概率
comment_prob = float(await self._get_config_from_db(db, "comment_probability", "0.4"))
reply_prob = float(await self._get_config_from_db(db, "reply_probability", "0.2"))
like_prob = float(await self._get_config_from_db(db, "like_probability", "0.6"))
collect_prob = float(await self._get_config_from_db(db, "collect_probability", "0.3"))
forward_prob = float(await self._get_config_from_db(db, "forward_probability", "0.15"))
interactions_done = []
# ① 先记录阅读(每次必做,模拟真实用户打开文章)
await news_service.read_news(db, user, news_id)
# ② 点赞
if random.random() < like_prob:
success, err = await news_service.like_news(db, user, news_id, org_id=article_org_id, to_user_id=news_author, title=news_title)
await self._save_record(db, user, news_id, news_title, "like", None, 0, success, err)
if success:
interactions_done.append("like")
await self._incr_total(db, user_id)
# ③ 收藏(阅读+点赞组合模拟)
if random.random() < collect_prob:
success, err = await news_service.collect_news(db, user, news_id, org_id=article_org_id, to_user_id=news_author, title=news_title)
await self._save_record(db, user, news_id, news_title, "collect", None, 0, success, err)
if success:
interactions_done.append("collect")
# ④ 转发(调用 /points/forward/news/{orgId}
if random.random() < forward_prob:
success, err = await news_service.forward_news(db, user, news_id)
await self._save_record(db, user, news_id, news_title, "forward", None, 0, success, err)
if success:
interactions_done.append("forward")
await self._incr_total(db, user_id)
# ⑤ 评论AI生成内容调用 POST /message/comment
if can_comment and random.random() < comment_prob and personality:
style_prompt = personality.comment_style_prompt or ""
# 字数上限最多80字避免超出 max_tokens 被截断
safe_word_max = min(personality.word_count_max, 80)
comment_text, tokens = await ai_service.generate_comment(
db, news_title, news_content,
style_prompt, personality.word_count_min, safe_word_max
)
if comment_text:
success, err = await news_service.post_comment(
db, user, news_id, news_title, comment_text,
news_author_id=news_author, org_id=article_org_id
)
await self._save_record(
db, user, news_id, news_title, "comment",
comment_text, tokens, success, err
)
if success:
interactions_done.append("comment")
await db.execute(
update(VirtualUser).where(VirtualUser.id == user_id).values(
today_comment_count=VirtualUser.today_comment_count + 1,
total_interactions=VirtualUser.total_interactions + 1,
last_interact_at=datetime.utcnow()
)
)
# ⑥ 回复评论(评论成功后,随机回复别人的评论)
if random.random() < reply_prob:
existing = await news_service.get_comments(db, user, news_id)
if existing:
target = random.choice(existing)
cid = str(target.get("id") or target.get("commentId") or "")
parent_content = target.get("content") or ""
if cid:
reply_text, r_tokens = await ai_service.generate_reply(
db, news_title, parent_content,
style_prompt,
personality.word_count_min,
personality.word_count_max
)
if reply_text:
r_ok, r_err = await news_service.post_reply(
db, user, news_id, cid, reply_text
)
await self._save_record(
db, user, news_id, news_title, "reply",
reply_text, r_tokens, r_ok, r_err,
parent_comment_id=cid
)
if r_ok:
interactions_done.append("reply")
await db.commit()
logger.info(f"👤 {user.account} 互动完成: {interactions_done} [新闻: {news_title[:20]}]")
except Exception as e:
logger.error(f"用户 {user_id} 互动异常: {e}")
async def _incr_total(self, db, user_id: int):
await db.execute(
update(VirtualUser).where(VirtualUser.id == user_id).values(
total_interactions=VirtualUser.total_interactions + 1,
last_interact_at=datetime.utcnow()
)
)
async def _save_record(
self, db, user: VirtualUser, article_id: str, article_title: str,
interact_type: str, content: Optional[str], tokens: int,
success: bool, error_msg: str, parent_comment_id: str = None,
platform_record_id: str = None
):
from app.core.redis_client import get_session
session = await get_session(user.id)
session_id = session.get("session_id") if session else None
record = InteractionRecord(
user_id=user.id,
user_nickname=user.nickname,
user_account=user.account,
article_id=article_id,
article_title=article_title,
interact_type=interact_type,
content=content,
parent_comment_id=parent_comment_id,
platform_record_id=platform_record_id,
session_id=session_id,
token_consumed=tokens,
status=1 if success else 2,
error_msg=error_msg or None,
executed_at=datetime.now(),
)
db.add(record)
async def _get_config_from_db(self, db, key: str, default: str = "") -> str:
result = await db.execute(select(SystemConfig).where(SystemConfig.config_key == key))
cfg = result.scalar_one_or_none()
return cfg.config_value if cfg else default
async def _daily_reset(self):
"""每日零点重置计数"""
async with AsyncSessionLocal() as db:
await db.execute(
update(VirtualUser).values(
today_comment_count=0,
today_like_count=0
)
)
# 重置Token限额标志
result = await db.execute(
select(SystemConfig).where(SystemConfig.config_key == "token_limit_reached")
)
cfg = result.scalar_one_or_none()
if cfg:
cfg.config_value = "false"
await db.commit()
logger.info("每日计数重置完成")
scheduler_service = SchedulerService()

View File

@@ -0,0 +1,251 @@
"""数据统计服务"""
from datetime import datetime, date, timedelta, timezone
def _fmt_dt(dt):
"""统一输出 UTC 时间,带时区标识,让前端正确解析为 +8"""
if dt is None:
return None
if dt.tzinfo is None:
# 数据库存的是 UTC补上时区信息
dt = dt.replace(tzinfo=timezone.utc)
return dt.isoformat()
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_
from app.models import VirtualUser, InteractionRecord, TokenStat, SystemConfig
from app.core.logger import logger
class StatsService:
async def get_dashboard(self, db: AsyncSession) -> dict:
"""获取控制台数据"""
today = date.today()
now = datetime.now()
month_start = today.replace(day=1)
# 用户统计
user_stats = await self._get_user_stats(db)
# 今日互动统计
today_stats = await self._get_today_stats(db, today)
# 本月互动统计
monthly_stats = await self._get_monthly_stats(db, month_start, today)
# Token统计
token_stats = await self._get_token_stats(db, today)
# 系统状态
system_status = await self._get_system_status(db, now)
# 在线用户数
online_count_result = await db.execute(
select(func.count()).where(VirtualUser.status == 2)
)
online_count = online_count_result.scalar() or 0
return {
"user_stats": user_stats,
"today_interactions": today_stats,
"monthly_stats": monthly_stats,
"token_stats": token_stats,
"system_status": system_status,
"online_users": online_count,
}
async def _get_user_stats(self, db: AsyncSession) -> dict:
total = await db.execute(select(func.count()).select_from(VirtualUser))
normal = await db.execute(select(func.count()).where(VirtualUser.is_enabled == 1))
banned = await db.execute(select(func.count()).where(VirtualUser.status == 4))
abnormal = await db.execute(select(func.count()).where(VirtualUser.status == 3))
return {
"total": total.scalar() or 0,
"normal": normal.scalar() or 0,
"banned": banned.scalar() or 0,
"abnormal": abnormal.scalar() or 0,
}
async def _get_today_stats(self, db: AsyncSession, today: date) -> dict:
result = await db.execute(
select(
InteractionRecord.interact_type,
func.count().label("cnt"),
).where(
func.date(InteractionRecord.executed_at) == today,
InteractionRecord.status == 1,
).group_by(InteractionRecord.interact_type)
)
rows = result.all()
stats = {"comment": 0, "reply": 0, "like": 0, "collect": 0, "forward": 0, "total": 0}
for row in rows:
if row.interact_type in stats:
stats[row.interact_type] = row.cnt
stats["total"] += row.cnt
return stats
async def _get_monthly_stats(self, db: AsyncSession, month_start: date, today: date) -> dict:
result = await db.execute(
select(func.count()).where(
InteractionRecord.executed_at >= month_start,
InteractionRecord.status == 1,
)
)
return {"total": result.scalar() or 0, "month_start": month_start.isoformat()}
async def _get_token_stats(self, db: AsyncSession, today: date) -> dict:
# 今日
today_stat = await db.execute(select(TokenStat).where(TokenStat.stat_date == today))
today_row = today_stat.scalar_one_or_none()
# 每日限额
limit_cfg = await db.execute(
select(SystemConfig).where(SystemConfig.config_key == "daily_token_limit")
)
limit_row = limit_cfg.scalar_one_or_none()
daily_limit = int(limit_row.config_value) if limit_row else 100000
today_used = today_row.total_tokens if today_row else 0
return {
"today_used": today_used,
"daily_limit": daily_limit,
"remaining": max(0, daily_limit - today_used),
"today_calls": today_row.call_count if today_row else 0,
}
async def _get_system_status(self, db: AsyncSession, now: datetime) -> dict:
start_cfg = await db.execute(
select(SystemConfig).where(SystemConfig.config_key == "system_start_time")
)
start_row = start_cfg.scalar_one_or_none()
uptime = ""
if start_row and start_row.config_value:
try:
start_time = datetime.fromisoformat(start_row.config_value)
delta = now - start_time
hours, rem = divmod(int(delta.total_seconds()), 3600)
mins = rem // 60
uptime = f"{hours}小时{mins}分钟"
except Exception:
uptime = "未知"
scheduler_cfg = await db.execute(
select(SystemConfig).where(SystemConfig.config_key == "scheduler_enabled")
)
scheduler_row = scheduler_cfg.scalar_one_or_none()
return {
"uptime": uptime,
"scheduler_enabled": (scheduler_row.config_value == "true") if scheduler_row else True,
"current_time": now.isoformat(),
}
async def get_token_trend(self, db: AsyncSession, days: int = 30) -> list:
"""Token消耗趋势近N天"""
end_date = date.today()
start_date = end_date - timedelta(days=days - 1)
result = await db.execute(
select(TokenStat).where(
TokenStat.stat_date >= start_date,
TokenStat.stat_date <= end_date,
).order_by(TokenStat.stat_date)
)
rows = result.scalars().all()
stat_map = {r.stat_date.isoformat(): r.total_tokens for r in rows}
trend = []
for i in range(days):
d = (start_date + timedelta(days=i)).isoformat()
trend.append({"date": d, "tokens": stat_map.get(d, 0)})
return trend
async def get_monthly_token_trend(self, db: AsyncSession) -> list:
"""近12个月Token消耗"""
today = date.today()
months = []
for i in range(11, -1, -1):
if today.month - i <= 0:
year = today.year - 1
month = today.month - i + 12
else:
year = today.year
month = today.month - i
months.append((year, month))
trend = []
for year, month in months:
start = date(year, month, 1)
if month == 12:
end = date(year + 1, 1, 1) - timedelta(days=1)
else:
end = date(year, month + 1, 1) - timedelta(days=1)
result = await db.execute(
select(func.sum(TokenStat.total_tokens)).where(
TokenStat.stat_date >= start, TokenStat.stat_date <= end
)
)
total = result.scalar() or 0
trend.append({"month": f"{year}-{month:02d}", "tokens": total})
return trend
async def get_interaction_records(
self, db: AsyncSession,
page: int = 1, page_size: int = 20,
user_id: int = None, interact_type: str = None,
status: int = None, start_date: str = None,
end_date: str = None, keyword: str = None
) -> dict:
query = select(InteractionRecord)
conditions = []
if user_id:
conditions.append(InteractionRecord.user_id == user_id)
if interact_type:
conditions.append(InteractionRecord.interact_type == interact_type)
if status is not None:
conditions.append(InteractionRecord.status == status)
if start_date:
conditions.append(InteractionRecord.executed_at >= start_date)
if end_date:
conditions.append(InteractionRecord.executed_at <= end_date + " 23:59:59")
if keyword:
from sqlalchemy import or_
conditions.append(
or_(InteractionRecord.article_title.like(f"%{keyword}%"),
InteractionRecord.content.like(f"%{keyword}%"),
InteractionRecord.user_nickname.like(f"%{keyword}%"))
)
if conditions:
query = query.where(and_(*conditions))
count_q = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_q)).scalar()
query = query.order_by(InteractionRecord.executed_at.desc()).offset(
(page - 1) * page_size
).limit(page_size)
result = await db.execute(query)
records = result.scalars().all()
INTERACT_LABELS = {
"comment": "评论", "reply": "回复", "like": "点赞",
"collect": "收藏", "forward": "转发"
}
STATUS_LABELS = {0: "执行中", 1: "成功", 2: "失败"}
items = []
for r in records:
items.append({
"id": r.id, "user_id": r.user_id,
"user_nickname": r.user_nickname, "user_account": r.user_account,
"article_id": r.article_id, "article_title": r.article_title,
"interact_type": r.interact_type,
"interact_type_label": INTERACT_LABELS.get(r.interact_type, r.interact_type),
"content": r.content, "token_consumed": r.token_consumed,
"status": r.status, "status_label": STATUS_LABELS.get(r.status, "未知"),
"error_msg": r.error_msg, "retry_count": r.retry_count,
"executed_at": _fmt_dt(r.executed_at),
})
return {"total": total, "page": page, "page_size": page_size, "items": items}
stats_service = StatsService()

View File

@@ -0,0 +1,358 @@
"""虚拟用户业务服务"""
import io
import uuid
from datetime import datetime, timezone
def _fmt_dt(dt):
if dt is None: return None
if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc)
return dt.isoformat()
from typing import List, Optional, Tuple
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, func, and_, or_
from fastapi import HTTPException
from app.models import VirtualUser, UserPersonality
from app.schemas import UserCreateRequest, UserUpdateRequest
from app.utils.crypto import encrypt, decrypt
from app.services.ai_service import ai_service
from app.core.logger import logger
STATUS_LABELS = {0: "未登录", 1: "登录中", 2: "已登录", 3: "登录失效", 4: "封禁"}
ACTIVITY_LABELS = {0: "", 1: "", 2: ""}
ACTIVITY_COMMENT_LIMITS = {0: (3, 5), 1: (8, 15), 2: (20, 30)}
class UserService:
async def get_users(
self, db: AsyncSession,
page: int = 1, page_size: int = 20,
keyword: str = None, status: int = None,
is_enabled: int = None
) -> Tuple[int, List[dict]]:
query = select(VirtualUser)
conditions = []
if keyword:
conditions.append(
or_(VirtualUser.nickname.like(f"%{keyword}%"),
VirtualUser.account.like(f"%{keyword}%"))
)
if status is not None:
conditions.append(VirtualUser.status == status)
if is_enabled is not None:
conditions.append(VirtualUser.is_enabled == is_enabled)
if conditions:
query = query.where(and_(*conditions))
count_result = await db.execute(
select(func.count()).select_from(query.subquery())
)
total = count_result.scalar()
query = query.offset((page - 1) * page_size).limit(page_size).order_by(VirtualUser.created_at.desc())
result = await db.execute(query)
users = result.scalars().all()
items = []
for u in users:
# 获取人格
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == u.id))
personality = p_result.scalar_one_or_none()
items.append(self._format_user(u, personality))
return total, items
async def create_user(self, db: AsyncSession, req: UserCreateRequest) -> dict:
# 检查账号重复
existing = await db.execute(select(VirtualUser).where(VirtualUser.account == req.account))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="账号已存在")
# 昵称选填:为空则自动生成
nickname = req.nickname or f"用户{req.account[-4:]}"
# 检查昵称重复(自动生成的若冲突则加随机后缀)
existing_nick = await db.execute(select(VirtualUser).where(VirtualUser.nickname == nickname))
if existing_nick.scalar_one_or_none():
import random, string
nickname = nickname + "_" + "".join(random.choices(string.digits, k=4))
user = VirtualUser(
nickname=nickname,
account=req.account,
password_enc=encrypt(req.password),
avatar_url=req.avatar_url,
activity_level=req.activity_level,
daily_comment_limit=req.daily_comment_limit,
daily_like_limit=req.daily_like_limit,
remark=req.remark,
status=0,
is_enabled=1,
)
db.add(user)
await db.flush()
# 自动生成AI人格
try:
await self._generate_personality(db, user)
except Exception as e:
logger.warning(f"人格生成失败,跳过: {e}")
await db.commit()
await db.refresh(user)
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user.id))
personality = p_result.scalar_one_or_none()
return self._format_user(user, personality)
async def update_user(self, db: AsyncSession, user_id: int, req: UserUpdateRequest) -> dict:
user = await self._get_or_404(db, user_id)
if req.nickname and req.nickname != user.nickname:
existing = await db.execute(
select(VirtualUser).where(VirtualUser.nickname == req.nickname, VirtualUser.id != user_id)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="昵称已被使用")
user.nickname = req.nickname
if req.password:
user.password_enc = encrypt(req.password)
if req.avatar_url is not None:
user.avatar_url = req.avatar_url
if req.activity_level is not None:
user.activity_level = req.activity_level
if req.daily_comment_limit is not None:
user.daily_comment_limit = req.daily_comment_limit
if req.daily_like_limit is not None:
user.daily_like_limit = req.daily_like_limit
if req.remark is not None:
user.remark = req.remark
if req.is_enabled is not None:
user.is_enabled = req.is_enabled
if req.is_enabled == 0:
user.status = 0 # 禁用后重置状态
await db.commit()
await db.refresh(user)
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user.id))
personality = p_result.scalar_one_or_none()
return self._format_user(user, personality)
async def delete_user(self, db: AsyncSession, user_id: int):
user = await self._get_or_404(db, user_id)
await db.execute(delete(UserPersonality).where(UserPersonality.user_id == user_id))
await db.delete(user)
await db.commit()
async def batch_action(self, db: AsyncSession, user_ids: List[int], action: str):
"""批量操作"""
if action == "enable":
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(is_enabled=1))
elif action == "disable":
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(is_enabled=0, status=0))
elif action == "logout":
await db.execute(update(VirtualUser).where(VirtualUser.id.in_(user_ids)).values(status=0, session_token=None))
elif action == "delete":
await db.execute(delete(UserPersonality).where(UserPersonality.user_id.in_(user_ids)))
await db.execute(delete(VirtualUser).where(VirtualUser.id.in_(user_ids)))
await db.commit()
return {"affected": len(user_ids)}
async def generate_personality(self, db: AsyncSession, user_id: int) -> dict:
"""为用户生成/重新生成AI人格"""
user = await self._get_or_404(db, user_id)
# 删除旧人格
await db.execute(delete(UserPersonality).where(UserPersonality.user_id == user_id))
personality = await self._generate_personality(db, user)
await db.commit()
return self._format_personality(personality)
async def update_personality(self, db: AsyncSession, user_id: int, req) -> dict:
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == user_id))
personality = p_result.scalar_one_or_none()
if not personality:
raise HTTPException(status_code=404, detail="人格不存在")
for field, val in req.model_dump(exclude_none=True).items():
setattr(personality, field, val)
# 重新生成提示词
personality.comment_style_prompt = self._build_style_prompt(personality)
await db.commit()
await db.refresh(personality)
return self._format_personality(personality)
async def import_from_excel(self, db: AsyncSession, file_content: bytes) -> dict:
"""Excel批量导入 - 每行独立事务,互不影响"""
try:
df = pd.read_excel(io.BytesIO(file_content), engine='openpyxl')
except Exception:
df = pd.read_excel(io.BytesIO(file_content))
df.columns = [str(c).strip() for c in df.columns]
required_cols = {"新闻平台账号", "登录密码", "昵称"}
if not required_cols.issubset(set(df.columns)):
raise HTTPException(
status_code=400,
detail=f"缺少必填列: {required_cols - set(df.columns)},当前列: {list(df.columns)}"
)
success_count = 0
error_list = []
for idx, row in df.iterrows():
row_num = idx + 2
row_account = ""
try:
# 账号可能是数字类型(手机号),统一转为字符串
account = str(row.get("新闻平台账号", "") or "").strip().split(".")[0] # 去掉 .0 后缀
password = str(row.get("登录密码", "") or "").strip()
nickname = str(row.get("昵称", "") or "").strip()
row_account = account
if account.lower() in ("nan", "none", ""):
error_list.append({"row": row_num, "error": "账号为空"}); continue
if password.lower() in ("nan", "none", ""):
error_list.append({"row": row_num, "account": account, "error": "密码为空"}); continue
if len(password) < 6:
error_list.append({"row": row_num, "account": account, "error": "密码不足6位"}); continue
# 昵称选填为空时自动用账号末4位生成
if nickname.lower() in ("nan", "none", ""):
nickname = f"用户{account[-4:]}"
existing = await db.execute(select(VirtualUser).where(VirtualUser.account == account))
if existing.scalar_one_or_none():
error_list.append({"row": row_num, "account": account, "error": "账号已存在"}); continue
existing_nick = await db.execute(select(VirtualUser).where(VirtualUser.nickname == nickname))
if existing_nick.scalar_one_or_none():
error_list.append({"row": row_num, "account": account, "error": "昵称已被使用"}); continue
avatar = str(row.get("头像链接", "") or "").strip()
remark = str(row.get("备注", "") or "").strip()
user = VirtualUser(
nickname=nickname, account=account,
password_enc=encrypt(password),
avatar_url=avatar if avatar.lower() not in ("nan","none","") else None,
remark=remark if remark.lower() not in ("nan","none","") else None,
status=0, is_enabled=1, activity_level=1,
)
db.add(user)
await db.flush()
await db.commit()
try:
await db.refresh(user)
await self._generate_personality(db, user)
await db.commit()
except Exception as pe:
logger.warning(f"{row_num}行人格生成跳过: {pe}")
await db.rollback()
success_count += 1
except Exception as e:
await db.rollback()
error_list.append({"row": row_num, "account": row_account, "error": str(e)})
logger.warning(f"导入第{row_num}行失败: {e}")
return {"success": success_count, "failed": len(error_list), "errors": error_list}
async def export_to_excel(self, db: AsyncSession) -> bytes:
"""导出全量用户数据(不含密码)"""
result = await db.execute(select(VirtualUser).order_by(VirtualUser.created_at.desc()))
users = result.scalars().all()
rows = []
for u in users:
p_result = await db.execute(select(UserPersonality).where(UserPersonality.user_id == u.id))
p = p_result.scalar_one_or_none()
rows.append({
"ID": u.id, "昵称": u.nickname, "账号": u.account,
"状态": STATUS_LABELS.get(u.status, "未知"),
"活跃度": ACTIVITY_LABELS.get(u.activity_level, ""),
"性格": p.character_type if p else "", "语言风格": p.language_style if p else "",
"兴趣偏好": ",".join(p.interest_tags or []) if p else "",
"互动倾向": p.interact_tendency if p else "",
"累计互动": u.total_interactions, "今日评论": u.today_comment_count,
"今日点赞": u.today_like_count, "最后登录": u.last_login_at,
"最后互动": u.last_interact_at, "备注": u.remark,
"是否启用": "" if u.is_enabled else "", "创建时间": u.created_at,
})
df = pd.DataFrame(rows)
buf = io.BytesIO()
df.to_excel(buf, index=False, sheet_name="虚拟用户")
buf.seek(0)
return buf.read()
async def get_excel_template(self) -> bytes:
"""获取导入模板(账号+密码必填,其他选填)"""
df = pd.DataFrame(columns=["新闻平台账号", "登录密码", "昵称(选填)", "头像链接(选填)", "备注(选填)"])
df.loc[0] = ["13800138000", "password123", "(留空自动生成)", "", ""]
buf = io.BytesIO()
df.to_excel(buf, index=False, sheet_name="导入模板")
buf.seek(0)
return buf.read()
async def _generate_personality(self, db: AsyncSession, user: VirtualUser) -> UserPersonality:
"""调用AI生成人格"""
result = await ai_service.generate_personality(user.nickname, user.account)
personality = UserPersonality(
user_id=user.id,
character_type=result.get("character_type", "温和"),
language_style=result.get("language_style", "幽默"),
interest_tags=result.get("interest_tags", ["科技"]),
interact_tendency=result.get("interact_tendency", "爱评论"),
word_count_min=result.get("word_count_min", 20),
word_count_max=result.get("word_count_max", 80),
personality_desc=result.get("personality_desc", ""),
)
personality.comment_style_prompt = self._build_style_prompt(personality)
db.add(personality)
await db.flush()
return personality
def _build_style_prompt(self, p: UserPersonality) -> str:
interests = "".join(p.interest_tags or []) if p.interest_tags else "综合"
return (
f"你是一个{p.character_type}性格、{p.language_style}语言风格的新闻读者,"
f"主要对{interests}类内容感兴趣,互动倾向是{p.interact_tendency}"
f"评论字数控制在{p.word_count_min}~{p.word_count_max}字。"
f"个人简介:{p.personality_desc}"
)
def _format_user(self, u: VirtualUser, p: Optional[UserPersonality]) -> dict:
return {
"id": u.id, "nickname": u.nickname, "account": u.account,
"avatar_url": u.avatar_url,
"real_name": getattr(u, "real_name", None),
"sex": getattr(u, "sex", 0),
"platform_uid": getattr(u, "platform_uid", None),
"status": u.status,
"status_label": STATUS_LABELS.get(u.status, "未知"),
"activity_level": u.activity_level,
"activity_label": ACTIVITY_LABELS.get(u.activity_level, ""),
"daily_comment_limit": u.daily_comment_limit,
"daily_like_limit": u.daily_like_limit,
"today_comment_count": u.today_comment_count,
"today_like_count": u.today_like_count,
"total_interactions": u.total_interactions,
"last_login_at": _fmt_dt(u.last_login_at),
"last_interact_at": _fmt_dt(u.last_interact_at),
"remark": u.remark, "is_enabled": u.is_enabled,
"created_at": _fmt_dt(u.created_at),
"personality": self._format_personality(p) if p else None,
}
def _format_personality(self, p: Optional[UserPersonality]) -> Optional[dict]:
if not p:
return None
return {
"id": p.id, "user_id": p.user_id,
"character_type": p.character_type, "language_style": p.language_style,
"interest_tags": p.interest_tags or [], "interact_tendency": p.interact_tendency,
"word_count_min": p.word_count_min, "word_count_max": p.word_count_max,
"personality_desc": p.personality_desc,
"updated_at": _fmt_dt(p.updated_at),
}
async def _get_or_404(self, db: AsyncSession, user_id: int) -> VirtualUser:
result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user
user_service = UserService()

View File

View File

@@ -0,0 +1,49 @@
"""AES加密工具 - 用于密码和API Key加密存储"""
import base64
import hashlib
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from app.core.config import settings
def _get_key() -> bytes:
"""获取32字节AES密钥"""
key = settings.AES_KEY.encode("utf-8")
return hashlib.sha256(key).digest()
def encrypt(plaintext: str) -> str:
"""AES-CBC加密"""
if not plaintext:
return ""
key = _get_key()
cipher = AES.new(key, AES.MODE_CBC)
ct_bytes = cipher.encrypt(pad(plaintext.encode("utf-8"), AES.block_size))
iv = base64.b64encode(cipher.iv).decode("utf-8")
ct = base64.b64encode(ct_bytes).decode("utf-8")
return f"{iv}:{ct}"
def decrypt(ciphertext: str) -> str:
"""AES-CBC解密"""
if not ciphertext or ":" not in ciphertext:
return ""
try:
iv_str, ct_str = ciphertext.split(":", 1)
key = _get_key()
iv = base64.b64decode(iv_str)
ct = base64.b64decode(ct_str)
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(ct), AES.block_size)
return pt.decode("utf-8")
except Exception:
return ""
def mask_password(password: str) -> str:
"""密码脱敏显示"""
if not password:
return ""
if len(password) <= 2:
return "*" * len(password)
return password[0] + "*" * (len(password) - 2) + password[-1]