302 lines
9.6 KiB
Python
302 lines
9.6 KiB
Python
"""
|
||
AI 大模型对接服务
|
||
支持 OpenAI、智谱、百度文心、阿里通义等主流大模型
|
||
"""
|
||
import logging
|
||
from typing import Optional, Dict, Any, List
|
||
from datetime import datetime
|
||
import json
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AIService:
    """Build comment/reply prompts and dispatch them to an LLM provider.

    The provider is selected by the ``provider`` key of a model configuration
    dict. Recognised keys of that dict are ``provider``, ``model_name``,
    ``api_key``, ``api_url``, ``temperature`` and ``max_tokens``.

    Every generation method returns ``None`` on failure instead of raising,
    except that errors raised while loading the default configuration from the
    database propagate to the caller.
    """

    # Maximum number of news characters embedded in each prompt type.
    _COMMENT_NEWS_LIMIT = 1000
    _REPLY_NEWS_LIMIT = 500

    def __init__(self) -> None:
        # NOTE(review): no method in this class reads or writes this cache;
        # it is kept so the instance attribute remains available.
        self._client_cache: Dict[str, Any] = {}

    async def generate_comment(
        self,
        news_content: str,
        writing_style: str,
        persona_description: Optional[str] = None,
        model_config: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Generate a comment on a news item.

        :param news_content: news text the comment is about
        :param writing_style: writing style the comment should follow
        :param persona_description: optional persona description
        :param model_config: model configuration; when empty, the default
            configuration is loaded from the database
        :return: result dict (``content``, ``tokens_used``, ``provider``,
            ``model``) or ``None`` on failure
        """
        prompt = self._build_comment_prompt(
            news_content,
            writing_style,
            persona_description,
        )
        return await self._call_ai_api(prompt, model_config)

    async def generate_reply(
        self,
        original_comment: str,
        news_content: str,
        writing_style: str,
        persona_description: Optional[str] = None,
        model_config: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Generate a reply to an existing comment.

        :param original_comment: comment being replied to
        :param news_content: news text the comment thread is about
        :param writing_style: writing style the reply should follow
        :param persona_description: optional persona description
        :param model_config: model configuration; when empty, the default
            configuration is loaded from the database
        :return: result dict (``content``, ``tokens_used``, ``provider``,
            ``model``) or ``None`` on failure
        """
        prompt = self._build_reply_prompt(
            original_comment,
            news_content,
            writing_style,
            persona_description,
        )
        return await self._call_ai_api(prompt, model_config)

    def _build_comment_prompt(
        self,
        news_content: str,
        writing_style: str,
        persona_description: Optional[str] = None
    ) -> str:
        """Build the prompt asking the model for a short comment.

        Only the first 1000 characters of ``news_content`` are embedded. The
        persona section is omitted when ``persona_description`` is empty.
        """
        prompt = (
            "你是一位虚拟用户,请根据以下要求写一条简短的评论:\n"
            "\n"
            f"写作风格:{writing_style}\n"
        )

        if persona_description:
            prompt += f"\n人格特征:{persona_description}\n"

        # The length limit is applied by the slice only; the explanatory
        # comment must stay out of the text sent to the model.
        prompt += (
            "\n"
            "新闻内容:\n"
            f"{news_content[:self._COMMENT_NEWS_LIMIT]}\n"
            "\n"
            "请写一条 50-100 字的评论,要符合你的写作风格和人格特征。直接输出评论内容,不要有其他说明。"
        )
        return prompt

    def _build_reply_prompt(
        self,
        original_comment: str,
        news_content: str,
        writing_style: str,
        persona_description: Optional[str] = None
    ) -> str:
        """Build the prompt asking the model to reply to another comment.

        Only the first 500 characters of ``news_content`` are embedded;
        ``original_comment`` is embedded in full. The persona section is
        omitted when ``persona_description`` is empty.
        """
        prompt = (
            "你是一位虚拟用户,请根据以下要求回复另一条评论:\n"
            "\n"
            f"写作风格:{writing_style}\n"
        )

        if persona_description:
            prompt += f"\n人格特征:{persona_description}\n"

        prompt += (
            "\n"
            "新闻内容:\n"
            f"{news_content[:self._REPLY_NEWS_LIMIT]}\n"
            "\n"
            "原评论:\n"
            f"{original_comment}\n"
            "\n"
            "请写一条 30-80 字的回复,要符合你的写作风格和人格特征。直接输出回复内容,不要有其他说明。"
        )
        return prompt

    @staticmethod
    def _config_value(config: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``config[key]``, or ``default`` when it is missing or ``None``.

        ``dict.get(key, default)`` alone does not fall back when the key is
        present with a ``None`` value, which is the case for configurations
        built from the database record (every key is always set). Falsy but
        meaningful values such as ``0`` are preserved.
        """
        value = config.get(key)
        return default if value is None else value

    @staticmethod
    def _parse_completion(
        response: Any,
        provider: str,
        model_name: str
    ) -> Dict[str, Any]:
        """Convert a chat-completion response object into the result dict.

        Reads ``response.choices[0].message.content`` and
        ``response.usage.total_tokens``; ``tokens_used`` is 0 when
        ``response.usage`` is falsy.
        """
        usage = response.usage
        return {
            "content": response.choices[0].message.content,
            "tokens_used": usage.total_tokens if usage else 0,
            "provider": provider,
            "model": model_name,
        }

    def _load_default_model_config(self) -> Optional[Dict[str, Any]]:
        """Load the active default model configuration from the database.

        :return: configuration dict, or ``None`` when no record is flagged as
            both default and active
        """
        # Imported lazily, as in the original code, so that importing this
        # module does not require the model layer.
        from app.models.ai_model import AIModelConfig
        from app.models.base import get_db

        with get_db() as db:
            # `== True` is kept deliberately: this appears to be an ORM column
            # comparison that builds a query expression (presumably
            # SQLAlchemy - confirm), so `is True` would not be equivalent.
            default_model = db.query(AIModelConfig).filter(
                AIModelConfig.is_default == True,  # noqa: E712
                AIModelConfig.is_active == True  # noqa: E712
            ).first()

            if not default_model:
                return None

            # Copy the fields while the database context is still open.
            return {
                "provider": default_model.provider,
                "model_name": default_model.model_name,
                "api_key": default_model.api_key,
                "api_url": default_model.api_url,
                "temperature": default_model.temperature,
                "max_tokens": default_model.max_tokens,
            }

    async def _call_ai_api(
        self,
        prompt: str,
        model_config: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Call the provider selected by ``model_config["provider"]``.

        :param prompt: prompt text
        :param model_config: model configuration; when empty, the default
            configuration is loaded from the database
        :return: result dict, or ``None`` when no default model exists, the
            provider is unsupported, or the provider call fails
        """
        if not model_config:
            model_config = self._load_default_model_config()
            if model_config is None:
                logger.error("No default AI model configured")
                return None

        # A missing or None provider must fall through to the "unsupported"
        # branch instead of raising AttributeError on `.lower()`.
        provider = str(model_config.get("provider") or "").strip().lower()

        handlers = {
            "openai": self._call_openai,
            "zhipu": self._call_zhipu,
            "baidu": self._call_baidu_wenxin,
            "wenxin": self._call_baidu_wenxin,
            "aliyun": self._call_aliyun_dashscope,
            "dashscope": self._call_aliyun_dashscope,
        }
        handler = handlers.get(provider)
        if handler is None:
            logger.error("Unsupported AI provider: %s", provider)
            return None

        try:
            return await handler(prompt, model_config)
        except Exception as e:
            # Boundary handler: callers rely on None rather than exceptions.
            logger.exception("AI API call error: %s", e)
            return None

    async def _call_openai(
        self,
        prompt: str,
        config: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Call the OpenAI chat-completions API.

        :param prompt: prompt text, sent as a single user message
        :param config: model configuration; ``api_key`` is required
        :return: result dict, or ``None`` on any error (including a missing
            ``openai`` package or ``api_key``)
        """
        try:
            from openai import AsyncOpenAI

            model_name = self._config_value(config, "model_name", "gpt-3.5-turbo")

            # The context manager closes the client's HTTP resources once the
            # request finishes. An empty api_url is treated as "not set".
            async with AsyncOpenAI(
                api_key=config["api_key"],
                base_url=config.get("api_url") or None,
            ) as client:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    temperature=self._config_value(config, "temperature", 0.7),
                    max_tokens=self._config_value(config, "max_tokens", 1000),
                )

            result = self._parse_completion(response, "openai", model_name)
            logger.info(
                "OpenAI generated content, tokens: %s", result["tokens_used"]
            )
            return result
        except Exception as e:
            logger.exception("OpenAI API error: %s", e)
            return None

    async def _call_zhipu(
        self,
        prompt: str,
        config: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Call the Zhipu AI chat-completions API.

        :param prompt: prompt text, sent as a single user message
        :param config: model configuration; ``api_key`` is required
        :return: result dict, or ``None`` on any error (including a missing
            ``zhipuai`` package or ``api_key``)
        """
        try:
            import asyncio
            from functools import partial

            from zhipuai import ZhipuAI

            model_name = self._config_value(config, "model_name", "glm-4")
            client = ZhipuAI(api_key=config["api_key"])

            request = partial(
                client.chat.completions.create,
                model=model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                temperature=self._config_value(config, "temperature", 0.7),
                max_tokens=self._config_value(config, "max_tokens", 1000),
            )
            # This client call is synchronous; run it in the default executor
            # so the event loop is not blocked for the whole HTTP round trip.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, request)

            result = self._parse_completion(response, "zhipu", model_name)
            logger.info(
                "Zhipu AI generated content, tokens: %s", result["tokens_used"]
            )
            return result
        except Exception as e:
            logger.exception("Zhipu AI API error: %s", e)
            return None

    async def _call_baidu_wenxin(
        self,
        prompt: str,
        config: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Placeholder for the Baidu Wenxin API; always returns ``None``."""
        # TODO: implement the Baidu Wenxin API call
        logger.warning("Baidu Wenxin API not implemented yet")
        return None

    async def _call_aliyun_dashscope(
        self,
        prompt: str,
        config: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Placeholder for the Aliyun DashScope API; always returns ``None``."""
        # TODO: implement the Aliyun DashScope API call
        logger.warning("Aliyun DashScope API not implemented yet")
        return None

    async def test_model(
        self,
        model_config: Dict[str, Any],
        test_prompt: str = "测试评论"
    ) -> Dict[str, Any]:
        """
        Send a test prompt to check that a model configuration works.

        :param model_config: model configuration to test
        :param test_prompt: prompt text to send
        :return: dict with ``success``, ``content``, ``tokens_used``,
            ``cost_time`` (seconds, rounded to 2 decimals) and
            ``error_message``
        """
        import time

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        started = time.perf_counter()
        result = await self._call_ai_api(test_prompt, model_config)
        cost_time = round(time.perf_counter() - started, 2)

        if not result:
            return {
                "success": False,
                "content": None,
                "tokens_used": 0,
                "cost_time": cost_time,
                "error_message": "Failed to generate content"
            }

        return {
            "success": True,
            "content": result.get("content"),
            "tokens_used": result.get("tokens_used", 0),
            "cost_time": cost_time,
            "error_message": None
        }
|
||
|
||
|
||
# Create the global service instance; modules importing `ai_service` share this one object.
|
||
ai_service = AIService()
|