- 日志时间改为北京时间(TZ=Asia/Shanghai) - 评论达上限后继续执行点赞/收藏/转发 - 用户信息同步改用 PATCH /v2/users/current - 一键登出全部功能 - 一键登出全部前端按钮 - update.sh 一键更新脚本
371 lines
14 KiB
Python
Executable File
371 lines
14 KiB
Python
Executable File
"""虚拟用户管理接口"""
|
||
from typing import Optional
|
||
from fastapi import APIRouter, Depends, Query, UploadFile, File, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
import io
|
||
|
||
from app.core.database import get_db
|
||
from app.schemas import ApiResponse, UserCreateRequest, UserUpdateRequest, UserBatchRequest, PersonalityUpdateRequest
|
||
from app.services.user_service import user_service
|
||
|
||
# Router holding the virtual-user management endpoints defined below; the
# prefix it is mounted under is configured elsewhere (not visible here).
router = APIRouter()
|
||
|
||
|
||
@router.get("")
async def list_users(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=100),
    keyword: Optional[str] = Query(default=None),
    status: Optional[int] = Query(default=None),
    is_enabled: Optional[int] = Query(default=None),
    db=Depends(get_db)
):
    """List virtual users with pagination and optional filters.

    ``keyword``, ``status`` and ``is_enabled`` are forwarded unchanged to
    ``user_service.get_users``; the response echoes the paging parameters
    alongside the total count and the page of items.
    """
    total, items = await user_service.get_users(
        db, page, page_size, keyword, status, is_enabled
    )
    payload = {
        "total": total,
        "page": page,
        "page_size": page_size,
        "items": items,
    }
    return ApiResponse(data=payload)
|
||
|
||
|
||
@router.post("")
async def create_user(req: UserCreateRequest, db=Depends(get_db)):
    """Create a virtual user from the request body and return the created record."""
    created = await user_service.create_user(db, req)
    return ApiResponse(message="用户创建成功", data=created)
|
||
|
||
|
||
@router.get("/{user_id}")
async def get_user(user_id: int, db=Depends(get_db)):
    """Return one virtual user's details together with its personality row.

    Args:
        user_id: primary key of the ``VirtualUser`` to fetch.

    Returns:
        ApiResponse whose ``data`` is ``user_service._format_user(user, personality)``;
        ``personality`` is ``None`` when the user has no ``UserPersonality`` row.

    Raises:
        HTTPException(404): no user with ``user_id`` exists.
    """
    from sqlalchemy import select
    from app.models import VirtualUser, UserPersonality

    result = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    p_result = await db.execute(
        select(UserPersonality).where(UserPersonality.user_id == user_id)
    )
    personality = p_result.scalar_one_or_none()
    # Reuse the service's (private) formatter so this endpoint serializes
    # users the same way the service does.
    return ApiResponse(data=user_service._format_user(user, personality))
|
||
|
||
|
||
@router.put("/{user_id}")
async def update_user(user_id: int, req: UserUpdateRequest, db=Depends(get_db)):
    """Update a user locally and, if requested, push the profile to the platform.

    The local update always happens first. When ``req.sync_to_platform`` is
    set and the reloaded user exists with ``status == 2``, the profile fields
    are sent through ``news_service.update_user_profile``. A failed push is
    reported with ``code=206`` while the local change is kept; every other
    path returns the plain success response.
    """
    saved = await user_service.update_user(db, user_id, req)

    if not req.sync_to_platform:
        return ApiResponse(data=saved, message="更新成功")

    from sqlalchemy import select
    from app.models import VirtualUser as _VU
    from app.services.news_service import news_service

    lookup = await db.execute(select(_VU).where(_VU.id == user_id))
    target = lookup.scalar_one_or_none()
    # Only users with status 2 are pushed to the platform.
    if not target or target.status != 2:
        return ApiResponse(data=saved, message="更新成功")

    pushed, reason = await news_service.update_user_profile(
        db,
        target,
        nick_name=req.nickname,
        real_name=req.real_name,
        sex=req.sex,
        description=req.description,
        email=req.email,
    )
    if pushed:
        return ApiResponse(data=saved, message="更新成功")
    return ApiResponse(
        data=saved,
        message=f"本地已保存,同步到平台失败: {reason}",
        code=206,
    )
|
||
|
||
|
||
@router.delete("/{user_id}")
async def delete_user(user_id: int, db=Depends(get_db)):
    """Delete the user identified by ``user_id`` via the user service."""
    done_message = "删除成功"
    await user_service.delete_user(db, user_id)
    return ApiResponse(message=done_message)
|
||
|
||
|
||
@router.post("/batch/action")
async def batch_action(req: UserBatchRequest, db=Depends(get_db)):
    """Apply ``req.action`` to every id in ``req.user_ids`` and return the service result."""
    outcome = await user_service.batch_action(db, req.user_ids, req.action)
    return ApiResponse(message="批量操作成功", data=outcome)
|
||
|
||
|
||
@router.post("/{user_id}/login")
async def manual_login(user_id: int, db=Depends(get_db)):
    """Trigger a platform login for a single user.

    Raises:
        HTTPException(404): the user does not exist.
        HTTPException(400): ``news_service.login`` returned a falsy result.
    """
    from app.services.news_service import news_service
    from sqlalchemy import select
    from app.models import VirtualUser

    row = await db.execute(select(VirtualUser).where(VirtualUser.id == user_id))
    account = row.scalar_one_or_none()
    if not account:
        raise HTTPException(status_code=404, detail="用户不存在")

    if not await news_service.login(db, account):
        raise HTTPException(status_code=400, detail="登录失败,请检查账号密码")
    return ApiResponse(message="登录成功")
|
||
|
||
|
||
@router.post("/{user_id}/logout")
async def manual_logout(user_id: int, db=Depends(get_db)):
    """Log one user out through ``news_service.logout``; always reports success."""
    from app.services.news_service import news_service as _news

    await _news.logout(db, user_id)
    return ApiResponse(message="已登出")
|
||
|
||
|
||
@router.post("/{user_id}/personality/generate")
async def generate_personality(user_id: int, db=Depends(get_db)):
    """Regenerate the AI personality for ``user_id`` and return it."""
    generated = await user_service.generate_personality(db, user_id)
    return ApiResponse(message="人格生成成功", data=generated)
|
||
|
||
|
||
@router.put("/{user_id}/personality")
async def update_personality(user_id: int, req: PersonalityUpdateRequest, db=Depends(get_db)):
    """Apply manual edits from ``req`` to the user's personality and return it."""
    edited = await user_service.update_personality(db, user_id, req)
    return ApiResponse(message="人格更新成功", data=edited)
|
||
|
||
|
||
@router.get("/excel/template")
async def download_template():
    """Stream the Excel import template produced by the user service as an attachment."""
    template_bytes = await user_service.get_excel_template()
    disposition = {"Content-Disposition": "attachment; filename=virtual_users_template.xlsx"}
    return StreamingResponse(
        io.BytesIO(template_bytes),
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers=disposition,
    )
|
||
|
||
|
||
@router.post("/excel/import")
async def import_excel(file: UploadFile = File(...), db=Depends(get_db)):
    """Bulk-import virtual users from an uploaded Excel workbook.

    Only filenames ending in ``.xlsx`` / ``.xls`` are accepted, compared
    case-insensitively. The raw bytes go to ``user_service.import_from_excel``,
    whose ``success`` / ``failed`` counters are echoed in the message.

    Raises:
        HTTPException(400): the upload has no filename or a non-Excel extension.
    """
    # UploadFile.filename may be None (multipart part without a filename);
    # treat it as empty so it is rejected with 400 instead of raising an
    # AttributeError. Lower-casing accepts e.g. "USERS.XLSX".
    filename = (file.filename or "").lower()
    if not filename.endswith((".xlsx", ".xls")):
        raise HTTPException(status_code=400, detail="仅支持Excel文件(.xlsx/.xls)")
    content = await file.read()
    result = await user_service.import_from_excel(db, content)
    return ApiResponse(data=result, message=f"导入完成:成功{result['success']}条,失败{result['failed']}条")
|
||
|
||
|
||
@router.get("/excel/export")
async def export_excel(db=Depends(get_db)):
    """Stream all user data, exported by the user service, as an Excel attachment."""
    workbook_bytes = await user_service.export_to_excel(db)
    disposition = {"Content-Disposition": "attachment; filename=virtual_users_export.xlsx"}
    return StreamingResponse(
        io.BytesIO(workbook_bytes),
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers=disposition,
    )
|
||
|
||
|
||
@router.post("/deduplicate")
async def deduplicate_users(db=Depends(get_db)):
    """Delete users that share an ``account`` with another row, keeping the lowest id.

    Personality rows belonging to the removed users are deleted first, so no
    orphan ``user_personalities`` rows remain. A foreign key from that table,
    if one exists, therefore cannot block the user delete.

    Returns:
        ApiResponse with ``data={"deleted": n}``, where ``n`` is the number of
        ``virtual_users`` rows removed.
    """
    from sqlalchemy import text

    # Ids to keep: the smallest id per account. Wrapping the aggregate in a
    # derived table forces it to be materialized. Without that, MySQL rejects
    # a DELETE whose subquery reads the target table (error 1093). The form
    # is also valid on PostgreSQL and SQLite.
    keep_ids_sql = (
        "SELECT keep_id FROM ("
        "SELECT MIN(id) AS keep_id FROM virtual_users GROUP BY account"
        ") AS keep_rows"
    )

    await db.execute(
        text(
            "DELETE FROM user_personalities "
            "WHERE user_id IN ("
            "SELECT id FROM virtual_users WHERE id NOT IN (" + keep_ids_sql + ")"
            ")"
        )
    )
    result = await db.execute(
        text("DELETE FROM virtual_users WHERE id NOT IN (" + keep_ids_sql + ")")
    )
    await db.commit()
    deleted = result.rowcount
    return ApiResponse(data={"deleted": deleted}, message=f"已清理 {deleted} 条重复数据")
|
||
|
||
|
||
@router.post("/clear-all")
async def clear_all_users(db=Depends(get_db)):
    """Delete every personality and virtual-user row (use with care).

    Before the rows are deleted, each user's cached login session is removed
    via ``delete_session`` on a best-effort basis. Otherwise a stale session
    could outlive its user and be picked up if the id is ever reused.
    """
    import logging
    from sqlalchemy import text
    from app.core.redis_client import delete_session

    logger = logging.getLogger(__name__)

    id_rows = await db.execute(text("SELECT id FROM virtual_users"))
    for uid in id_rows.scalars().all():
        try:
            await delete_session(uid)
        except Exception as e:
            # Best-effort: a Redis failure must not prevent the DB wipe.
            logger.warning("clear_all: delete_session(%s) failed: %s", uid, e)

    await db.execute(text("DELETE FROM user_personalities"))
    await db.execute(text("DELETE FROM virtual_users"))
    await db.commit()
    return ApiResponse(message="已清空所有用户数据")
|
||
|
||
|
||
@router.post("/login-all")
async def batch_login_all(db=Depends(get_db)):
    """Log in every enabled user whose status is 0 or 3.

    Per the original docstring, these statuses mean "not logged in" and
    "login expired". Logins run in batches of ``batch_size`` concurrent
    attempts with a 1 s pause between batches. Each attempt uses its own DB
    session, so one failing login cannot poison another's transaction.

    Returns:
        ApiResponse with ``data={"success", "failed", "total"}``, or
        ``data={"count": 0}`` when nobody needs logging in.
    """
    import asyncio
    import logging
    from sqlalchemy import select
    from app.services.news_service import news_service
    from app.models import VirtualUser as _VU
    from app.core.database import AsyncSessionLocal

    # This module defines no logger at file level; without this the except
    # branch below raised NameError.
    logger = logging.getLogger(__name__)
    batch_size = 5

    # Only ids are needed here; each login reloads its user in its own session.
    id_result = await db.execute(
        select(_VU.id).where(_VU.is_enabled == 1, _VU.status.in_([0, 3]))
    )
    user_ids = id_result.scalars().all()
    if not user_ids:
        return ApiResponse(message="没有需要登录的用户", data={"count": 0})

    total = len(user_ids)
    success = failed = 0

    async def login_one(uid: int) -> bool:
        """Log in one user in an isolated session; any error counts as failure."""
        async with AsyncSessionLocal() as s:
            try:
                ur = await s.execute(select(_VU).where(_VU.id == uid))
                u = ur.scalar_one_or_none()
                if u:
                    # Normalize to bool so truthy results count as success,
                    # matching manual_login's `if success:` check.
                    return bool(await news_service.login(s, u))
            except Exception as e:
                logger.warning(f"login_one {uid} 异常: {e}")
        return False

    for i in range(0, total, batch_size):
        batch_ids = user_ids[i:i + batch_size]
        results = await asyncio.gather(*[login_one(uid) for uid in batch_ids], return_exceptions=True)
        for r in results:
            # Exceptions returned by gather are truthy objects, hence `is True`.
            if r is True:
                success += 1
            else:
                failed += 1
        if i + batch_size < total:
            await asyncio.sleep(1)  # spread batches out to avoid bursts

    return ApiResponse(
        message=f"登录完成:成功 {success} 个,失败 {failed} 个",
        data={"success": success, "failed": failed, "total": total}
    )
|
||
|
||
|
||
@router.post("/sync-all-profiles")
async def sync_all_profiles(db=Depends(get_db)):
    """Copy profile fields cached in each logged-in user's session into the DB.

    For every enabled user with ``status == 2``, the login session is read
    with ``get_session``. Whichever of ``platform_uid``, ``nickname``,
    ``real_name``, ``sex`` and ``avatar`` are present and truthy are written
    to the matching ``VirtualUser`` columns (``avatar`` -> ``avatar_url``,
    ``sex`` cast to int). No request is made to the target platform. Users
    without a session, or whose update raises, are counted as failed/skipped.
    """
    import asyncio
    import logging
    from sqlalchemy import select, update
    from app.models import VirtualUser as _VU
    from app.core.database import AsyncSessionLocal
    from app.core.redis_client import get_session

    # No module-level logger exists in this file; define one locally.
    logger = logging.getLogger(__name__)

    id_result = await db.execute(select(_VU.id).where(_VU.status == 2, _VU.is_enabled == 1))
    user_ids = id_result.scalars().all()
    if not user_ids:
        return ApiResponse(message="没有已登录的用户", data={"synced": 0})

    async def sync_one(uid: int) -> bool:
        """Write one user's session-cached profile fields; True on success."""
        async with AsyncSessionLocal() as s:
            try:
                sess = await get_session(uid)
                if not sess:
                    return False
                vals = {}
                platform_uid = sess.get("platform_uid", "")
                if platform_uid:
                    vals["platform_uid"] = platform_uid
                if sess.get("nickname"):
                    vals["nickname"] = sess["nickname"]
                if sess.get("real_name"):
                    vals["real_name"] = sess["real_name"]
                if sess.get("sex"):
                    vals["sex"] = int(sess["sex"])
                if sess.get("avatar"):
                    vals["avatar_url"] = sess["avatar"]
                if vals:
                    await s.execute(update(_VU).where(_VU.id == uid).values(**vals))
                    await s.commit()
                return True
            except Exception as e:
                logger.warning(f"sync_one {uid} 失败: {e}")
                return False

    results = await asyncio.gather(*[sync_one(uid) for uid in user_ids], return_exceptions=True)
    # Exceptions returned by gather are truthy objects, hence `is True`.
    synced = sum(1 for r in results if r is True)
    failed = len(results) - synced

    return ApiResponse(
        message=f"同步完成:成功 {synced} 个,失败/跳过 {failed} 个",
        data={"synced": synced, "failed": failed, "total": len(user_ids)}
    )
|
||
|
||
@router.post("/{user_id}/upload-avatar")
async def upload_avatar(
    user_id: int,
    file: UploadFile = File(...),
    sync_to_platform: bool = Query(default=True),
    db=Depends(get_db)
):
    """Upload a user's avatar and optionally push it to the target platform.

    Responses:
      * ``code=404`` if the user does not exist.
      * ``code=400`` if the file is larger than 5 MB.
      * ``code=500`` if the platform upload fails. Nothing is saved locally
        in that case.
      * ``code=206`` if the avatar was uploaded and saved locally, but the
        follow-up profile update on the platform failed.
      * Otherwise a success response with ``data={"avatar_url": ...}``.

    When ``sync_to_platform`` is true and ``user.status == 2``, the bytes go to
    ``news_service.upload_avatar`` and the returned URL is stored. Otherwise
    the image is stored locally as a base64 ``data:`` URL.
    """
    from sqlalchemy import select, update
    from app.models import VirtualUser as _VU
    from app.services.news_service import news_service

    ur = await db.execute(select(_VU).where(_VU.id == user_id))
    user = ur.scalar_one_or_none()
    if not user:
        return ApiResponse(code=404, message="用户不存在")

    file_bytes = await file.read()
    if len(file_bytes) > 5 * 1024 * 1024:
        return ApiResponse(code=400, message="头像文件不能超过5MB")

    # Decide once, before the commit below. Re-reading user.status afterwards
    # would need a reload if the session expires instances on commit.
    push_to_platform = sync_to_platform and user.status == 2

    if push_to_platform:
        ok, result = await news_service.upload_avatar(db, user, file_bytes, file.filename)
        if not ok:
            return ApiResponse(code=500, message=f"头像上传到平台失败: {result}")
        avatar_url = result
    else:
        import base64
        # content_type can be None; fall back to a generic MIME type rather
        # than emitting "data:None;base64,...".
        content_type = file.content_type or "application/octet-stream"
        avatar_url = f"data:{content_type};base64,{base64.b64encode(file_bytes).decode()}"

    await db.execute(update(_VU).where(_VU.id == user_id).values(avatar_url=avatar_url))
    await db.commit()

    if push_to_platform and avatar_url:
        # update_user_profile returns (ok, err), as used in update_user. Report
        # a failure the same way (code=206) instead of discarding it.
        profile_ok, profile_err = await news_service.update_user_profile(db, user, avatar=avatar_url)
        if not profile_ok:
            return ApiResponse(
                data={"avatar_url": avatar_url},
                message=f"头像已保存,同步到平台资料失败: {profile_err}",
                code=206,
            )

    return ApiResponse(data={"avatar_url": avatar_url}, message="头像更新成功")
|
||
@router.post("/logout-all")
async def batch_logout_all(db=Depends(get_db)):
    """Log out every user currently marked as logged in (``status == 2``).

    Each such user's Redis session is deleted via ``delete_session``. Only
    users whose session was actually removed are reset to ``status = 0``, so
    the DB keeps matching Redis when a deletion fails. Failures are logged.
    Previously, sessions were removed only for enabled users, while the
    status reset applied to all of them.

    Returns:
        ApiResponse with ``data={"count": n}``, where ``n`` is the number of
        users logged out.
    """
    import logging
    from sqlalchemy import select, update
    from app.models import VirtualUser as _VU
    from app.core.redis_client import delete_session

    logger = logging.getLogger(__name__)

    result_r = await db.execute(select(_VU.id).where(_VU.status == 2))
    user_ids = result_r.scalars().all()
    if not user_ids:
        return ApiResponse(message="没有已登录的用户", data={"count": 0})

    logged_out = []
    for uid in user_ids:
        try:
            await delete_session(uid)
        except Exception as e:
            logger.warning("logout_all: delete_session(%s) failed: %s", uid, e)
            continue
        logged_out.append(uid)

    if logged_out:
        # Keep the status == 2 guard so a concurrent state change is not overwritten.
        await db.execute(
            update(_VU).where(_VU.id.in_(logged_out), _VU.status == 2).values(status=0)
        )
        await db.commit()

    count = len(logged_out)
    return ApiResponse(message=f"已登出 {count} 个用户", data={"count": count})
|