163 lines
4.9 KiB
Python
163 lines
4.9 KiB
Python
"""
|
|
虚拟用户管理 API
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
import pandas as pd
|
|
import io
|
|
|
|
from app.models.base import get_db
|
|
from app.schemas.virtual_user import (
|
|
VirtualUserCreate,
|
|
VirtualUserUpdate,
|
|
VirtualUserResponse,
|
|
VirtualUserListResponse,
|
|
VirtualUserGenerateRequest,
|
|
ActivityLevel,
|
|
UserStatus
|
|
)
|
|
from app.services.virtual_user_service import VirtualUserService, get_virtual_user_service
|
|
|
|
# Router that collects the virtual-user management endpoints defined below.
# NOTE(review): the URL prefix (e.g. "/virtual-users") is presumably applied
# where this router is included into the app — not visible in this file.
router = APIRouter()
|
|
|
|
|
|
@router.get("", response_model=VirtualUserListResponse)
def get_virtual_users(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    status: Optional[UserStatus] = Query(None, description="状态筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """List virtual users one page at a time.

    Args:
        page: 1-based page number (validated >= 1).
        page_size: Users per page (validated to 1..100).
        status: Optional status filter; ``None`` means no filtering.
        search: Optional search keyword forwarded to the service.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that performs the lookup.

    Returns:
        The service's result, serialized as ``VirtualUserListResponse``.
    """
    # Optional filters are grouped so the pagination arguments stand out.
    filters = {"status": status, "search": search}
    return service.get_users(page=page, page_size=page_size, **filters)
|
|
|
|
|
|
@router.get("/{user_id}", response_model=VirtualUserResponse)
def get_virtual_user(
    user_id: int,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Fetch a single virtual user by id.

    Args:
        user_id: Identifier taken from the URL path.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that performs the lookup.

    Returns:
        The user object, serialized as ``VirtualUserResponse``.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = service.get_user_by_id(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
@router.post("", response_model=VirtualUserResponse)
def create_virtual_user(
    user_data: VirtualUserCreate,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Create one virtual user from the request body.

    Args:
        user_data: Validated creation payload.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that persists the user.

    Returns:
        The created user, serialized as ``VirtualUserResponse``.

    Raises:
        HTTPException: 400 when the service returns a falsy result
            (the message suggests a duplicate username).
    """
    # Payload attributes forwarded to the service, each as a keyword of the same name.
    forwarded = (
        "username",
        "password",
        "nickname",
        "writing_style",
        "activity_level",
        "avatar_url",
        "persona_description",
    )
    created = service.create_user(
        **{attr: getattr(user_data, attr) for attr in forwarded}
    )
    if created:
        return created
    raise HTTPException(
        status_code=400,
        detail="Failed to create user (username may exist)",
    )
|
|
|
|
|
|
@router.post("/generate", response_model=VirtualUserListResponse)
def generate_virtual_users(
    request: VirtualUserGenerateRequest,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Generate a batch of virtual users.

    Args:
        request: Generation options (count, candidate writing styles and
            activity levels, and whether to generate a persona).
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that creates the users.

    Returns:
        A dict with ``total`` (number of users returned by the service) and
        ``items`` (the users themselves).
    """
    options = dict(
        count=request.count,
        writing_styles=request.writing_styles,
        activity_levels=request.activity_levels,
        generate_persona=request.generate_persona,
    )
    generated = service.generate_users(**options)
    return dict(total=len(generated), items=generated)
|
|
|
|
|
|
@router.put("/{user_id}", response_model=VirtualUserResponse)
def update_virtual_user(
    user_id: int,
    user_data: VirtualUserUpdate,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Partially update a virtual user.

    Only the fields the client explicitly sent are forwarded
    (``exclude_unset=True``). Omitted fields are therefore left untouched
    rather than reset to schema defaults.

    Args:
        user_id: Identifier taken from the URL path.
        user_data: Update payload.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that applies the update.

    Returns:
        The updated user, serialized as ``VirtualUserResponse``.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    changes = user_data.model_dump(exclude_unset=True)
    updated = service.update_user(user_id, **changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
@router.delete("/{user_id}")
def delete_virtual_user(
    user_id: int,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Delete a virtual user.

    Args:
        user_id: Identifier taken from the URL path.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that performs the deletion.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 when the service reports failure (falsy result).
    """
    if service.delete_user(user_id):
        return {"message": "User deleted successfully"}
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
def _excel_to_records(contents: bytes) -> List[dict]:
    """Parse the raw bytes of an Excel workbook into one dict per row.

    ``pd.read_excel`` reads the first sheet by default. Empty cells come back
    from pandas as float NaN, which is truthy and unequal to itself, so
    ordinary "is this field missing?" checks downstream would misfire. They
    are converted to ``None`` here.
    """
    df = pd.read_excel(io.BytesIO(contents))
    # Cast to object first: assigning None into a float column would be
    # coerced straight back to NaN.
    df = df.astype(object).where(df.notna(), None)
    return df.to_dict("records")


@router.post("/import", response_model=dict)
def import_virtual_users(
    file: UploadFile = File(...),
    generate_persona: bool = Query(True, description="是否生成 AI 人格描述"),
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Import virtual users from an uploaded Excel file.

    Args:
        file: Uploaded workbook. Each row of the first sheet becomes one
            record dict passed to the service.
        generate_persona: Whether the service should generate an AI persona
            description for each imported user.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that performs the import.

    Returns:
        A dict with a message plus ``success_count`` and ``failed_count``
        taken from the service result.

    Raises:
        HTTPException: Re-raised unchanged if the service raises one itself.
            Any other failure (unreadable workbook, service error, malformed
            result) is reported as 400 with the original error text.
    """
    try:
        users_data = _excel_to_records(file.file.read())

        result = service.import_users_from_excel(
            users_data=users_data,
            generate_persona=generate_persona
        )

        return {
            "message": "Import completed",
            "success_count": result["success_count"],
            "failed_count": result["failed_count"]
        }

    except HTTPException:
        # Don't rewrap deliberate HTTP errors (e.g. a 404/409 from downstream)
        # as a generic 400.
        raise
    except Exception as e:
        # Chain the cause so server-side tracebacks show the real failure.
        raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}") from e
|
|
|
|
|
|
@router.get("/{user_id}/stats")
def get_user_stats(
    user_id: int,
    db: Session = Depends(get_db),
    service: VirtualUserService = Depends(get_virtual_user_service)
):
    """Return statistics for one virtual user.

    Args:
        user_id: Identifier taken from the URL path.
        db: Database session dependency; not used directly by this handler.
        service: Virtual-user service that computes the stats.

    Returns:
        The service's stats object, unchanged.

    Raises:
        HTTPException: 404 when the service returns ``None``. This matches
            the not-found handling of the other ``/{user_id}`` endpoints
            instead of answering ``200 null``.
    """
    stats = service.get_user_stats(user_id)
    # NOTE(review): assumes the service signals an unknown user with None.
    # `is None` (not falsiness) keeps an empty-but-valid stats dict a 200.
    if stats is None:
        raise HTTPException(status_code=404, detail="User not found")
    return stats
|