Files
huihuiSquare/backend/app/api/ai_model.py
yuqianqian10204095yu cebc0a288f 1.0.0初始化源代码
2026-03-23 15:40:36 +08:00

137 lines
3.6 KiB
Python

"""
AI 模型配置 API
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from app.models.base import get_db
from app.models.ai_model import AIModelConfig
from app.schemas.ai_model import (
AIModelConfigCreate,
AIModelConfigUpdate,
AIModelConfigResponse,
AIModelTestRequest,
AIModelTestResponse
)
from app.services.ai_service import ai_service
# Router that groups the AI model configuration endpoints defined below;
# any URL prefix is applied where this router is included (not shown here).
router = APIRouter()
@router.get("", response_model=List[AIModelConfigResponse])
def get_ai_models(
    db: Session = Depends(get_db)
):
    """List every stored AI model configuration.

    No filtering or explicit ordering is applied; each row is serialized
    through ``AIModelConfigResponse``.
    """
    return db.query(AIModelConfig).all()
@router.get("/{model_id}", response_model=AIModelConfigResponse)
def get_ai_model(
    model_id: int,
    db: Session = Depends(get_db)
):
    """Fetch a single AI model configuration by its ``id``.

    Raises:
        HTTPException: 404 when no configuration has ``id == model_id``.
    """
    found = (
        db.query(AIModelConfig)
        .filter(AIModelConfig.id == model_id)
        .first()
    )
    if found:
        return found
    raise HTTPException(status_code=404, detail="Model not found")
@router.post("", response_model=AIModelConfigResponse)
def create_ai_model(
    model_data: AIModelConfigCreate,
    db: Session = Depends(get_db)
):
    """Create a new AI model configuration.

    ``model_name`` must not already be used by another configuration.
    Default handling keeps at most one default configuration:

    * if the payload itself sets ``is_default``, every existing default is
      cleared before the new row is added;
    * otherwise, if no default exists yet, the new configuration becomes
      the default.

    Raises:
        HTTPException: 400 when a configuration with the same ``model_name``
            already exists.
    """
    # Reject duplicates by model_name.
    existing = db.query(AIModelConfig).filter(
        AIModelConfig.model_name == model_data.model_name
    ).first()
    if existing:
        raise HTTPException(status_code=400, detail="Model already exists")

    model = AIModelConfig(**model_data.model_dump())

    if model.is_default:
        # Caller explicitly asked for this one to be the default: demote the
        # current default(s) so two defaults never coexist. The new row has not
        # been added to the session yet, so this bulk UPDATE cannot touch it.
        db.query(AIModelConfig).filter(
            AIModelConfig.is_default == True  # noqa: E712 - SQL expression
        ).update({"is_default": False})
    elif not db.query(AIModelConfig).filter(
        AIModelConfig.is_default == True  # noqa: E712 - SQL expression
    ).first():
        # No default exists yet (e.g. this is the first model): promote this one.
        model.is_default = True

    db.add(model)
    db.commit()
    db.refresh(model)
    return model
@router.put("/{model_id}", response_model=AIModelConfigResponse)
def update_ai_model(
    model_id: int,
    model_data: AIModelConfigUpdate,
    db: Session = Depends(get_db)
):
    """Partially update an AI model configuration.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``).

    Raises:
        HTTPException: 404 when the configuration does not exist.
        HTTPException: 400 when renaming to a ``model_name`` that another
            configuration already uses.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    update_data = model_data.model_dump(exclude_unset=True)

    # Keep model_name unique on rename as well; without this check an update
    # could silently create two configurations with the same name.
    new_name = update_data.get("model_name")
    if new_name is not None and new_name != model.model_name:
        clash = db.query(AIModelConfig).filter(
            AIModelConfig.model_name == new_name
        ).first()
        if clash:
            raise HTTPException(status_code=400, detail="Model already exists")

    # Becoming the default: clear the flag on all rows first; the setattr loop
    # below then re-sets it on this row, so only one default remains.
    if update_data.get("is_default"):
        db.query(AIModelConfig).update({"is_default": False})

    for key, value in update_data.items():
        setattr(model, key, value)

    db.commit()
    db.refresh(model)
    return model
@router.delete("/{model_id}")
def delete_ai_model(
    model_id: int,
    db: Session = Depends(get_db)
):
    """Delete an AI model configuration.

    NOTE(review): deleting the current default leaves no default; nothing
    here promotes another configuration.

    Returns:
        A confirmation payload of the form ``{"message": ...}``.

    Raises:
        HTTPException: 404 when the configuration does not exist.
    """
    target = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if target:
        db.delete(target)
        db.commit()
        return {"message": "Model deleted successfully"}
    raise HTTPException(status_code=404, detail="Model not found")
@router.post("/test", response_model=AIModelTestResponse)
async def test_ai_model(
    request: AIModelTestRequest,
    db: Session = Depends(get_db)
):
    """Run a test prompt against a stored AI model configuration.

    The configuration is looked up by ``request.model_id``. Its connection
    settings and ``request.test_prompt`` are passed to
    ``ai_service.test_model``, and that call's result is returned.

    Raises:
        HTTPException: 404 when the configuration does not exist.
    """
    # Local import keeps this change self-contained; fastapi is already used here.
    from fastapi.concurrency import run_in_threadpool

    # ``db`` is a synchronous sqlalchemy.orm.Session. Querying it directly inside
    # this coroutine would block the event loop, so the lookup runs in the
    # threadpool instead. A 404 raised there propagates through the await.
    model_config = await run_in_threadpool(
        _load_model_config, db, request.model_id
    )
    return await ai_service.test_model(
        model_config=model_config,
        test_prompt=request.test_prompt,
    )


def _load_model_config(db: Session, model_id: int) -> dict:
    """Return the stored settings of model ``model_id`` as a dict, or raise a 404."""
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")
    return {
        "provider": model.provider,
        "model_name": model.model_name,
        "api_key": model.api_key,
        "api_url": model.api_url,
        "temperature": model.temperature,
        "max_tokens": model.max_tokens,
    }