- 虚拟用户管理(昵称/头像/性别/简介/邮箱同步到目标平台) - AI互动调度(点赞/收藏/评论/转发) - 日志时间改为北京时间 - 评论达上限后继续执行点赞收藏转发 - 一键登出全部功能 - 浅色主题UI
88 lines
3.2 KiB
Python
88 lines
3.2 KiB
Python
"""AI模型配置接口"""
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy import select, update
|
|
|
|
from app.core.database import get_db
|
|
from app.schemas import ApiResponse, AIModelCreateRequest, AIModelUpdateRequest, AIModelTestRequest
|
|
from app.models import AIModelConfig
|
|
from app.utils.crypto import encrypt, decrypt
|
|
from app.services.ai_service import ai_service
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
# GET "" - list every stored AI model configuration, newest first.
@router.get("")
async def list_models(db=Depends(get_db)):
    # Order by creation time descending so the most recently added row leads.
    newest_first = select(AIModelConfig).order_by(AIModelConfig.created_at.desc())
    rows = (await db.execute(newest_first)).scalars().all()
    # Each row is serialized through _format_model before being returned.
    return ApiResponse(data=[_format_model(row) for row in rows])
|
|
|
|
|
|
# POST "" - create a new AI model configuration from the request payload.
@router.post("")
async def create_model(req: AIModelCreateRequest, db=Depends(get_db)):
    # When the new row is flagged as default, clear the flag on every
    # existing row first (the UPDATE has no WHERE clause).
    if req.is_default:
        await db.execute(update(AIModelConfig).values(is_default=0))

    # The API key is stored encrypted; a missing or empty key is stored as None.
    encrypted_key = encrypt(req.api_key) if req.api_key else None

    new_config = AIModelConfig(
        model_name=req.model_name,
        provider=req.provider,
        api_base_url=req.api_base_url,
        api_key_enc=encrypted_key,
        model_version=req.model_version,
        temperature=req.temperature,
        max_tokens=req.max_tokens,
        timeout_seconds=req.timeout_seconds,
        is_default=req.is_default,
        is_enabled=1,  # new configurations are always created enabled
    )
    db.add(new_config)
    await db.commit()
    # Reload the row after the commit before serializing it.
    await db.refresh(new_config)

    payload = _format_model(new_config)
    return ApiResponse(data=payload, message="模型添加成功")
|
|
|
|
|
|
# PUT "/{model_id}" - apply the non-None fields of the request to one row.
@router.put("/{model_id}")
async def update_model(model_id: int, req: AIModelUpdateRequest, db=Depends(get_db)):
    lookup = select(AIModelConfig).where(AIModelConfig.id == model_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="模型不存在")

    # Promoting this row to default clears the flag on all other rows.
    if req.is_default:
        demote_others = (
            update(AIModelConfig)
            .where(AIModelConfig.id != model_id)
            .values(is_default=0)
        )
        await db.execute(demote_others)

    # Fields left as None in the request are skipped entirely.
    changes = req.model_dump(exclude_none=True)
    for name, value in changes.items():
        if name != "api_key":
            setattr(target, name, value)
            continue
        # The key is persisted encrypted under api_key_enc; a falsy value
        # (e.g. an empty string) clears the stored key.
        target.api_key_enc = encrypt(value) if value else None

    await db.commit()
    await db.refresh(target)

    payload = _format_model(target)
    return ApiResponse(data=payload, message="更新成功")
|
|
|
|
|
|
# DELETE "/{model_id}" - remove one AI model configuration.
@router.delete("/{model_id}")
async def delete_model(model_id: int, db=Depends(get_db)):
    lookup = select(AIModelConfig).where(AIModelConfig.id == model_id)
    found = await db.execute(lookup)
    doomed = found.scalar_one_or_none()
    # Unknown ids are reported as 404 rather than silently succeeding.
    if not doomed:
        raise HTTPException(status_code=404, detail="模型不存在")

    await db.delete(doomed)
    await db.commit()
    return ApiResponse(message="删除成功")
|
|
|
|
|
|
# POST "/test" - run a test prompt against a configured model.
@router.post("/test")
async def test_model(req: AIModelTestRequest, db=Depends(get_db)):
    # All of the work is delegated to ai_service; its result is returned
    # unchanged as the response data.
    pending = ai_service.test_model(db, req.model_id, req.test_prompt)
    return ApiResponse(data=await pending)
|
|
|
|
|
|
def _format_model(m: "AIModelConfig") -> dict:
    """Serialize an AI model configuration row into a plain dict.

    The encrypted API key itself is never included; callers only learn
    whether one is stored, via the boolean ``has_api_key``.

    Args:
        m: The model configuration row to serialize.

    Returns:
        A dict of the row's public fields. ``created_at`` is an ISO 8601
        string, or ``None`` when the row has no creation timestamp.
    """
    # Guard against a missing timestamp: calling .isoformat() on None would
    # raise AttributeError and fail the whole response.
    created_at = m.created_at
    return {
        "id": m.id,
        "model_name": m.model_name,
        "provider": m.provider,
        "api_base_url": m.api_base_url,
        "has_api_key": bool(m.api_key_enc),
        "model_version": m.model_version,
        "temperature": m.temperature,
        "max_tokens": m.max_tokens,
        "timeout_seconds": m.timeout_seconds,
        "is_default": m.is_default,
        "is_enabled": m.is_enabled,
        "created_at": created_at.isoformat() if created_at is not None else None,
    }
|