137 lines
3.6 KiB
Python
137 lines
3.6 KiB
Python
"""
|
|
AI 模型配置 API
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from typing import List
|
|
|
|
from app.models.base import get_db
|
|
from app.models.ai_model import AIModelConfig
|
|
from app.schemas.ai_model import (
|
|
AIModelConfigCreate,
|
|
AIModelConfigUpdate,
|
|
AIModelConfigResponse,
|
|
AIModelTestRequest,
|
|
AIModelTestResponse
|
|
)
|
|
from app.services.ai_service import ai_service
|
|
|
|
# Router collecting the AI model configuration endpoints defined below; the URL
# prefix it is mounted under is configured where it is included (not visible here).
router = APIRouter()
|
|
|
|
|
|
@router.get("", response_model=List[AIModelConfigResponse])
def get_ai_models(db: Session = Depends(get_db)):
    """Return every stored AI model configuration."""
    # No filtering, ordering or pagination: the whole table is returned as-is.
    return db.query(AIModelConfig).all()
|
|
|
|
|
|
@router.get("/{model_id}", response_model=AIModelConfigResponse)
def get_ai_model(model_id: int, db: Session = Depends(get_db)):
    """Return a single AI model configuration by its id.

    Responds with HTTP 404 ("Model not found") when no row has ``model_id``.
    """
    found = (
        db.query(AIModelConfig)
        .filter(AIModelConfig.id == model_id)
        .first()
    )
    if found:
        return found
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@router.post("", response_model=AIModelConfigResponse)
def create_ai_model(
    model_data: AIModelConfigCreate,
    db: Session = Depends(get_db),
):
    """Create a new AI model configuration.

    Responds with HTTP 400 ("Model already exists") when another configuration
    already uses the same ``model_name``. If no default model exists yet, the
    new one becomes the default; if the caller explicitly marks it as default,
    every other configuration loses its default flag.
    """
    # Reject duplicates by model name.
    # NOTE(review): this check-then-insert is not atomic; only a DB unique
    # constraint on model_name can rule out concurrent duplicates — confirm.
    existing = db.query(AIModelConfig).filter(
        AIModelConfig.model_name == model_data.model_name
    ).first()
    if existing:
        raise HTTPException(status_code=400, detail="Model already exists")

    model = AIModelConfig(**model_data.model_dump())

    if model.is_default:
        # Explicitly requested as default: clear the flag on all existing rows
        # first so at most one default remains (same rule update_ai_model uses).
        # The new instance is not in the session yet, so it is unaffected.
        db.query(AIModelConfig).update({"is_default": False})
    elif not db.query(AIModelConfig).filter(
        AIModelConfig.is_default == True  # noqa: E712 - SQL expression, not a Python bool test
    ).first():
        # No default exists yet (e.g. this is the first model): promote this one.
        model.is_default = True

    db.add(model)
    db.commit()
    db.refresh(model)
    return model
|
|
|
|
|
|
@router.put("/{model_id}", response_model=AIModelConfigResponse)
def update_ai_model(
    model_id: int,
    model_data: AIModelConfigUpdate,
    db: Session = Depends(get_db),
):
    """Partially update an AI model configuration.

    Only fields explicitly sent by the client are applied. Responds with
    HTTP 404 ("Model not found") when ``model_id`` does not exist, and with
    HTTP 400 ("Model already exists") when the new ``model_name`` is already
    used by another configuration.
    """
    model = db.query(AIModelConfig).filter(AIModelConfig.id == model_id).first()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    # exclude_unset: fields the client omitted are left untouched.
    update_data = model_data.model_dump(exclude_unset=True)

    # Keep model names unique, matching the duplicate check in create_ai_model.
    new_name = update_data.get("model_name")
    if "model_name" in update_data and new_name != model.model_name:
        clash = db.query(AIModelConfig).filter(
            AIModelConfig.model_name == new_name,
            AIModelConfig.id != model_id,
        ).first()
        if clash:
            raise HTTPException(status_code=400, detail="Model already exists")

    # Promoting this model to default: clear the flag on every *other* row.
    # This row is excluded so its final value comes solely from the setattr
    # loop below, independent of how Query.update synchronizes the session.
    if update_data.get("is_default"):
        db.query(AIModelConfig).filter(AIModelConfig.id != model_id).update(
            {"is_default": False}
        )

    for key, value in update_data.items():
        setattr(model, key, value)

    db.commit()
    db.refresh(model)
    return model
|
|
|
|
|
|
@router.delete("/{model_id}")
def delete_ai_model(model_id: int, db: Session = Depends(get_db)):
    """Delete an AI model configuration by id (HTTP 404 if it does not exist)."""
    target = (
        db.query(AIModelConfig)
        .filter(AIModelConfig.id == model_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Model not found")

    # NOTE(review): deleting the current default leaves no default model;
    # create_ai_model only promotes a model when one is next created.
    db.delete(target)
    db.commit()
    return {"message": "Model deleted successfully"}
|
|
|
|
|
|
@router.post("/test", response_model=AIModelTestResponse)
async def test_ai_model(
    request: AIModelTestRequest,
    db: Session = Depends(get_db),
):
    """Send a test prompt to a stored AI model configuration.

    Responds with HTTP 404 ("Model not found") when ``request.model_id`` does
    not exist; otherwise returns the result of ``ai_service.test_model``.
    """
    # NOTE(review): this synchronous Session query runs on the event-loop
    # thread inside an async endpoint and blocks the loop while it executes;
    # acceptable for a single-row lookup, but worth knowing.
    target = (
        db.query(AIModelConfig)
        .filter(AIModelConfig.id == request.model_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Model not found")

    # ORM attributes forwarded to the AI service, copied under the same names
    # (dict keys keep this order).
    forwarded_fields = (
        "provider",
        "model_name",
        "api_key",
        "api_url",
        "temperature",
        "max_tokens",
    )
    model_config = {name: getattr(target, name) for name in forwarded_fields}

    return await ai_service.test_model(
        model_config=model_config,
        test_prompt=request.test_prompt,
    )
|