1.0.0初始化源代码
This commit is contained in:
3
backend/app/api/__init__.py
Normal file
3
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API 路由模块初始化
|
||||
"""
|
||||
136
backend/app/api/ai_model.py
Normal file
136
backend/app/api/ai_model.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
AI 模型配置 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from app.models.base import get_db
|
||||
from app.models.ai_model import AIModelConfig
|
||||
from app.schemas.ai_model import (
|
||||
AIModelConfigCreate,
|
||||
AIModelConfigUpdate,
|
||||
AIModelConfigResponse,
|
||||
AIModelTestRequest,
|
||||
AIModelTestResponse
|
||||
)
|
||||
from app.services.ai_service import ai_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[AIModelConfigResponse])
|
||||
def get_ai_models(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取所有 AI 模型配置"""
|
||||
models = db.query(AIModelConfig).all()
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=AIModelConfigResponse)
|
||||
def get_ai_model(
|
||||
model_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取 AI 模型详情"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
return model
|
||||
|
||||
|
||||
@router.post("", response_model=AIModelConfigResponse)
|
||||
def create_ai_model(
|
||||
model_data: AIModelConfigCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建 AI 模型配置"""
|
||||
# 检查是否已存在
|
||||
existing = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.model_name == model_data.model_name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Model already exists")
|
||||
|
||||
model = AIModelConfig(**model_data.model_dump())
|
||||
|
||||
# 如果是第一个模型,设为默认
|
||||
if not db.query(AIModelConfig).filter(AIModelConfig.is_default == True).first():
|
||||
model.is_default = True
|
||||
|
||||
db.add(model)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=AIModelConfigResponse)
|
||||
def update_ai_model(
|
||||
model_id: int,
|
||||
model_data: AIModelConfigUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新 AI 模型配置"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
update_data = model_data.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果设置默认模型,先取消其他模型的默认状态
|
||||
if update_data.get("is_default"):
|
||||
db.query(AIModelConfig).update({"is_default": False})
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(model, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@router.delete("/{model_id}")
|
||||
def delete_ai_model(
|
||||
model_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除 AI 模型配置"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Model deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/test", response_model=AIModelTestResponse)
|
||||
async def test_ai_model(
|
||||
request: AIModelTestRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试 AI 模型"""
|
||||
model = db.query(AIModelConfig).filter(AIModelConfig.id == request.model_id).first()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
model_config = {
|
||||
"provider": model.provider,
|
||||
"model_name": model.model_name,
|
||||
"api_key": model.api_key,
|
||||
"api_url": model.api_url,
|
||||
"temperature": model.temperature,
|
||||
"max_tokens": model.max_tokens,
|
||||
}
|
||||
|
||||
result = await ai_service.test_model(
|
||||
model_config=model_config,
|
||||
test_prompt=request.test_prompt
|
||||
)
|
||||
|
||||
return result
|
||||
150
backend/app/api/dashboard.py
Normal file
150
backend/app/api/dashboard.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
控制台 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from datetime import date, timedelta
|
||||
|
||||
from app.models.base import get_db
|
||||
from app.schemas.dashboard import DashboardStats, CoreStats, DailyUsageItem, MonthlyUsageItem
|
||||
from app.services.token_service import TokenService, get_token_service
|
||||
from app.services.virtual_user_service import VirtualUserService, get_virtual_user_service
|
||||
from app.models.virtual_user import VirtualUser, UserStatus
|
||||
from app.models.interaction import InteractionRecord, InteractionType, InteractionStatus
|
||||
from app.models.token_usage import TokenUsage
|
||||
from sqlalchemy import func, and_
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=DashboardStats)
|
||||
def get_dashboard_stats(
|
||||
db: Session = Depends(get_db),
|
||||
token_service: TokenService = Depends(get_token_service),
|
||||
user_service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""获取控制台统计数据"""
|
||||
|
||||
# 核心指标统计
|
||||
today = date.today()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
# 用户统计
|
||||
total_users = db.query(VirtualUser).count()
|
||||
active_users = db.query(VirtualUser).filter(VirtualUser.status == UserStatus.ACTIVE).count()
|
||||
disabled_users = total_users - active_users
|
||||
|
||||
# 今日互动统计
|
||||
today_interactions = db.query(InteractionRecord).filter(
|
||||
and_(
|
||||
func.date(InteractionRecord.execution_time) == today,
|
||||
InteractionRecord.status == InteractionStatus.SUCCESS
|
||||
)
|
||||
).all()
|
||||
|
||||
today_comments = sum(1 for i in today_interactions if i.interaction_type == InteractionType.COMMENT)
|
||||
today_replies = sum(1 for i in today_interactions if i.interaction_type == InteractionType.REPLY)
|
||||
today_likes = sum(1 for i in today_interactions if i.interaction_type == InteractionType.LIKE)
|
||||
today_favorites = sum(1 for i in today_interactions if i.interaction_type == InteractionType.FAVORITE)
|
||||
today_shares = sum(1 for i in today_interactions if i.interaction_type == InteractionType.SHARE)
|
||||
|
||||
# 昨日互动统计
|
||||
yesterday_interactions = db.query(InteractionRecord).filter(
|
||||
and_(
|
||||
func.date(InteractionRecord.execution_time) == yesterday,
|
||||
InteractionRecord.status == InteractionStatus.SUCCESS
|
||||
)
|
||||
).all()
|
||||
|
||||
yesterday_comments = sum(1 for i in yesterday_interactions if i.interaction_type == InteractionType.COMMENT)
|
||||
yesterday_replies = sum(1 for i in yesterday_interactions if i.interaction_type == InteractionType.REPLY)
|
||||
|
||||
# Token 统计
|
||||
today_tokens = token_service.get_today_usage()
|
||||
month_tokens = token_service.get_month_usage()
|
||||
remaining_tokens = token_service.get_remaining_tokens()
|
||||
|
||||
core_stats = CoreStats(
|
||||
total_users=total_users,
|
||||
active_users=active_users,
|
||||
disabled_users=disabled_users,
|
||||
today_comments=today_comments,
|
||||
today_replies=today_replies,
|
||||
today_likes=today_likes,
|
||||
today_favorites=today_favorites,
|
||||
today_shares=today_shares,
|
||||
yesterday_comments=yesterday_comments,
|
||||
yesterday_replies=yesterday_replies,
|
||||
month_tokens=month_tokens,
|
||||
today_tokens=today_tokens,
|
||||
remaining_tokens=remaining_tokens
|
||||
)
|
||||
|
||||
# 每日 Token 使用(近 30 天)
|
||||
daily_usages = token_service.get_daily_usages(days=30)
|
||||
daily_items = [DailyUsageItem(date=u["date"], tokens=u["tokens"], comments=0, replies=0) for u in daily_usages]
|
||||
|
||||
# 每月 Token 使用(近 12 个月)
|
||||
monthly_usages = token_service.get_monthly_usages(months=12)
|
||||
monthly_items = [MonthlyUsageItem(month=u["month"], tokens=u["tokens"]) for u in monthly_usages]
|
||||
|
||||
# 最近互动记录
|
||||
recent_interactions = db.query(InteractionRecord).order_by(
|
||||
InteractionRecord.execution_time.desc()
|
||||
).limit(10).all()
|
||||
|
||||
return DashboardStats(
|
||||
core_stats=core_stats,
|
||||
daily_token_usages=daily_items,
|
||||
monthly_token_usages=monthly_items,
|
||||
recent_interactions=[
|
||||
{
|
||||
"id": r.id,
|
||||
"virtual_user_id": r.virtual_user_id,
|
||||
"interaction_type": r.interaction_type.value,
|
||||
"status": r.status.value,
|
||||
"execution_time": r.execution_time
|
||||
}
|
||||
for r in recent_interactions
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/token/stats")
|
||||
def get_token_stats(
|
||||
db: Session = Depends(get_db),
|
||||
token_service: TokenService = Depends(get_token_service)
|
||||
):
|
||||
"""获取 Token 统计"""
|
||||
today_used = token_service.get_today_usage()
|
||||
today_limit = 10000 # TODO: 从系统配置读取
|
||||
|
||||
return {
|
||||
"today_used": today_used,
|
||||
"today_limit": today_limit,
|
||||
"today_remaining": max(0, today_limit - today_used),
|
||||
"usage_percentage": round((today_used / today_limit) * 100, 2) if today_limit > 0 else 0
|
||||
}
|
||||
|
||||
|
||||
@router.get("/token/daily", response_model=List[DailyUsageItem])
|
||||
def get_daily_token_usage(
|
||||
days: int = Query(30, ge=1, le=90, description="天数"),
|
||||
db: Session = Depends(get_db),
|
||||
token_service: TokenService = Depends(get_token_service)
|
||||
):
|
||||
"""获取每日 Token 使用"""
|
||||
usages = token_service.get_daily_usages(days=days)
|
||||
return [DailyUsageItem(date=u["date"], tokens=u["tokens"], comments=0, replies=0) for u in usages]
|
||||
|
||||
|
||||
@router.get("/token/monthly", response_model=List[MonthlyUsageItem])
|
||||
def get_monthly_token_usage(
|
||||
months: int = Query(12, ge=1, le=24, description="月数"),
|
||||
db: Session = Depends(get_db),
|
||||
token_service: TokenService = Depends(get_token_service)
|
||||
):
|
||||
"""获取每月 Token 使用"""
|
||||
usages = token_service.get_monthly_usages(months=months)
|
||||
return [MonthlyUsageItem(month=u["month"], tokens=u["tokens"]) for u in usages]
|
||||
72
backend/app/api/interaction.py
Normal file
72
backend/app/api/interaction.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
互动管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.models.base import get_db
|
||||
from app.models.interaction import InteractionType
|
||||
from app.schemas.interaction import (
|
||||
InteractionRecordResponse,
|
||||
InteractionRecordListResponse,
|
||||
InteractionExecuteRequest
|
||||
)
|
||||
from app.services.interaction_service import InteractionService, get_interaction_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=InteractionRecordListResponse)
|
||||
def get_interaction_records(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
virtual_user_id: Optional[int] = Query(None, description="虚拟用户 ID"),
|
||||
interaction_type: Optional[InteractionType] = Query(None, description="互动类型"),
|
||||
db: Session = Depends(get_db),
|
||||
service: InteractionService = Depends(get_interaction_service)
|
||||
):
|
||||
"""获取互动记录列表"""
|
||||
# TODO: 实现筛选和分页查询
|
||||
return {"total": 0, "items": []}
|
||||
|
||||
|
||||
@router.get("/{record_id}", response_model=InteractionRecordResponse)
|
||||
def get_interaction_record(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
service: InteractionService = Depends(get_interaction_service)
|
||||
):
|
||||
"""获取互动记录详情"""
|
||||
# TODO: 实现详情查询
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
|
||||
|
||||
@router.post("/execute", response_model=InteractionRecordResponse)
|
||||
async def execute_interaction(
|
||||
request: InteractionExecuteRequest,
|
||||
db: Session = Depends(get_db),
|
||||
service: InteractionService = Depends(get_interaction_service)
|
||||
):
|
||||
"""执行互动"""
|
||||
record = await service.execute_interaction(
|
||||
virtual_user_id=request.virtual_user_id,
|
||||
interaction_type=request.interaction_type,
|
||||
news_id=request.news_id
|
||||
)
|
||||
|
||||
if not record:
|
||||
raise HTTPException(status_code=400, detail="Failed to execute interaction")
|
||||
|
||||
return record
|
||||
|
||||
|
||||
@router.post("/retry/{record_id}")
|
||||
async def retry_interaction(
|
||||
record_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
service: InteractionService = Depends(get_interaction_service)
|
||||
):
|
||||
"""重试失败的互动"""
|
||||
# TODO: 实现重试逻辑
|
||||
raise HTTPException(status_code=404, detail="Record not found")
|
||||
19
backend/app/api/router.py
Normal file
19
backend/app/api/router.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
API 路由
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .virtual_user import router as virtual_user_router
|
||||
from .interaction import router as interaction_router
|
||||
from .ai_model import router as ai_model_router
|
||||
from .system_config import router as system_config_router
|
||||
from .dashboard import router as dashboard_router
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册各模块路由
|
||||
api_router.include_router(virtual_user_router, prefix="/virtual-users", tags=["虚拟用户管理"])
|
||||
api_router.include_router(interaction_router, prefix="/interactions", tags=["互动管理"])
|
||||
api_router.include_router(ai_model_router, prefix="/ai-models", tags=["AI 模型配置"])
|
||||
api_router.include_router(system_config_router, prefix="/system", tags=["系统设置"])
|
||||
api_router.include_router(dashboard_router, prefix="/dashboard", tags=["控制台"])
|
||||
126
backend/app/api/system_config.py
Normal file
126
backend/app/api/system_config.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
系统配置 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from app.models.base import get_db
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.schemas.system_config import (
|
||||
SystemConfigResponse,
|
||||
SystemConfigUpdate,
|
||||
ScheduleConfig,
|
||||
LimitConfig,
|
||||
ProbabilityConfig
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[SystemConfigResponse])
|
||||
def get_system_configs(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取所有系统配置"""
|
||||
configs = db.query(SystemConfig).all()
|
||||
return configs
|
||||
|
||||
|
||||
@router.get("/schedule", response_model=ScheduleConfig)
|
||||
def get_schedule_config(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取调度配置"""
|
||||
from app.services.scheduler_service import scheduler_service
|
||||
|
||||
return ScheduleConfig(
|
||||
task_start_hour=settings.TASK_START_HOUR,
|
||||
task_end_hour=settings.TASK_END_HOUR,
|
||||
task_interval_min=settings.TASK_INTERVAL_MIN,
|
||||
task_interval_max=settings.TASK_INTERVAL_MAX,
|
||||
is_task_running=scheduler_service.is_running
|
||||
)
|
||||
|
||||
|
||||
@router.get("/limits", response_model=LimitConfig)
|
||||
def get_limit_config(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取限额配置"""
|
||||
return LimitConfig(
|
||||
max_tokens_per_day=settings.MAX_TOKENS_PER_DAY,
|
||||
max_comments_per_user_per_day=settings.MAX_COMMENTS_PER_USER_PER_DAY,
|
||||
max_replies_per_user_per_day=settings.MAX_REPLIES_PER_USER_PER_DAY
|
||||
)
|
||||
|
||||
|
||||
@router.get("/probabilities", response_model=ProbabilityConfig)
|
||||
def get_probability_config(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取概率配置"""
|
||||
return ProbabilityConfig(
|
||||
like_probability=settings.LIKE_PROBABILITY,
|
||||
favorite_probability=settings.FAVORITE_PROBABILITY,
|
||||
share_probability=settings.SHARE_PROBABILITY
|
||||
)
|
||||
|
||||
|
||||
@router.put("/schedule")
|
||||
def update_schedule_config(
|
||||
config: ScheduleConfig,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新调度配置"""
|
||||
# TODO: 更新系统配置表并重新加载
|
||||
return {"message": "Schedule config updated"}
|
||||
|
||||
|
||||
@router.put("/limits")
|
||||
def update_limit_config(
|
||||
config: LimitConfig,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新限额配置"""
|
||||
# TODO: 更新系统配置表
|
||||
return {"message": "Limit config updated"}
|
||||
|
||||
|
||||
@router.post("/scheduler/start")
|
||||
def start_scheduler(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""启动定时任务"""
|
||||
from app.services.scheduler_service import scheduler_service
|
||||
|
||||
scheduler_service.start()
|
||||
scheduler_service.add_interaction_task()
|
||||
|
||||
return {"message": "Scheduler started", "running": scheduler_service.is_running}
|
||||
|
||||
|
||||
@router.post("/scheduler/stop")
|
||||
def stop_scheduler(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""停止定时任务"""
|
||||
from app.services.scheduler_service import scheduler_service
|
||||
|
||||
scheduler_service.stop()
|
||||
|
||||
return {"message": "Scheduler stopped", "running": scheduler_service.is_running}
|
||||
|
||||
|
||||
@router.get("/scheduler/status")
|
||||
def get_scheduler_status(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取定时任务状态"""
|
||||
from app.services.scheduler_service import scheduler_service
|
||||
|
||||
return {
|
||||
"is_running": scheduler_service.is_running,
|
||||
"jobs": [job.id for job in scheduler_service.scheduler.get_jobs()]
|
||||
}
|
||||
162
backend/app/api/virtual_user.py
Normal file
162
backend/app/api/virtual_user.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
虚拟用户管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
from app.models.base import get_db
|
||||
from app.schemas.virtual_user import (
|
||||
VirtualUserCreate,
|
||||
VirtualUserUpdate,
|
||||
VirtualUserResponse,
|
||||
VirtualUserListResponse,
|
||||
VirtualUserGenerateRequest,
|
||||
ActivityLevel,
|
||||
UserStatus
|
||||
)
|
||||
from app.services.virtual_user_service import VirtualUserService, get_virtual_user_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=VirtualUserListResponse)
|
||||
def get_virtual_users(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
status: Optional[UserStatus] = Query(None, description="状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""获取虚拟用户列表"""
|
||||
result = service.get_users(page=page, page_size=page_size, status=status, search=search)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=VirtualUserResponse)
|
||||
def get_virtual_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""获取虚拟用户详情"""
|
||||
user = service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.post("", response_model=VirtualUserResponse)
|
||||
def create_virtual_user(
|
||||
user_data: VirtualUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""创建虚拟用户"""
|
||||
user = service.create_user(
|
||||
username=user_data.username,
|
||||
password=user_data.password,
|
||||
nickname=user_data.nickname,
|
||||
writing_style=user_data.writing_style,
|
||||
activity_level=user_data.activity_level,
|
||||
avatar_url=user_data.avatar_url,
|
||||
persona_description=user_data.persona_description
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="Failed to create user (username may exist)")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/generate", response_model=VirtualUserListResponse)
|
||||
def generate_virtual_users(
|
||||
request: VirtualUserGenerateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""批量生成虚拟用户"""
|
||||
users = service.generate_users(
|
||||
count=request.count,
|
||||
writing_styles=request.writing_styles,
|
||||
activity_levels=request.activity_levels,
|
||||
generate_persona=request.generate_persona
|
||||
)
|
||||
|
||||
return {"total": len(users), "items": users}
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=VirtualUserResponse)
|
||||
def update_virtual_user(
|
||||
user_id: int,
|
||||
user_data: VirtualUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""更新虚拟用户"""
|
||||
update_data = user_data.model_dump(exclude_unset=True)
|
||||
user = service.update_user(user_id, **update_data)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
def delete_virtual_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""删除虚拟用户"""
|
||||
success = service.delete_user(user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return {"message": "User deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/import", response_model=dict)
|
||||
def import_virtual_users(
|
||||
file: UploadFile = File(...),
|
||||
generate_persona: bool = Query(True, description="是否生成 AI 人格描述"),
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""从 Excel 导入虚拟用户"""
|
||||
try:
|
||||
# 读取 Excel 文件
|
||||
contents = file.file.read()
|
||||
df = pd.read_excel(io.BytesIO(contents))
|
||||
|
||||
# 转换为字典列表
|
||||
users_data = df.to_dict('records')
|
||||
|
||||
# 导入用户
|
||||
result = service.import_users_from_excel(
|
||||
users_data=users_data,
|
||||
generate_persona=generate_persona
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Import completed",
|
||||
"success_count": result["success_count"],
|
||||
"failed_count": result["failed_count"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats")
|
||||
def get_user_stats(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
service: VirtualUserService = Depends(get_virtual_user_service)
|
||||
):
|
||||
"""获取用户统计信息"""
|
||||
stats = service.get_user_stats(user_id)
|
||||
return stats
|
||||
Reference in New Issue
Block a user