1.0.0初始化源代码
This commit is contained in:
301
backend/app/services/ai_service.py
Normal file
301
backend/app/services/ai_service.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
AI 大模型对接服务
|
||||
支持 OpenAI、智谱、百度文心、阿里通义等主流大模型
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""AI 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self._client_cache: Dict[str, Any] = {}
|
||||
|
||||
async def generate_comment(
|
||||
self,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
AI 生成评论
|
||||
:param news_content: 新闻内容
|
||||
:param writing_style: 写作风格
|
||||
:param persona_description: 人格描述
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果(包含 content, tokens_used 等)
|
||||
"""
|
||||
prompt = self._build_comment_prompt(
|
||||
news_content,
|
||||
writing_style,
|
||||
persona_description
|
||||
)
|
||||
|
||||
return await self._call_ai_api(prompt, model_config)
|
||||
|
||||
async def generate_reply(
|
||||
self,
|
||||
original_comment: str,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
AI 生成回复
|
||||
:param original_comment: 原评论
|
||||
:param news_content: 新闻内容
|
||||
:param writing_style: 写作风格
|
||||
:param persona_description: 人格描述
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果
|
||||
"""
|
||||
prompt = self._build_reply_prompt(
|
||||
original_comment,
|
||||
news_content,
|
||||
writing_style,
|
||||
persona_description
|
||||
)
|
||||
|
||||
return await self._call_ai_api(prompt, model_config)
|
||||
|
||||
def _build_comment_prompt(
|
||||
self,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None
|
||||
) -> str:
|
||||
"""构建评论提示词"""
|
||||
base_prompt = f"""你是一位虚拟用户,请根据以下要求写一条简短的评论:
|
||||
|
||||
写作风格:{writing_style}
|
||||
"""
|
||||
|
||||
if persona_description:
|
||||
base_prompt += f"\n人格特征:{persona_description}\n"
|
||||
|
||||
base_prompt += f"""
|
||||
新闻内容:
|
||||
{news_content[:1000]} # 限制长度
|
||||
|
||||
请写一条 50-100 字的评论,要符合你的写作风格和人格特征。直接输出评论内容,不要有其他说明。"""
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _build_reply_prompt(
|
||||
self,
|
||||
original_comment: str,
|
||||
news_content: str,
|
||||
writing_style: str,
|
||||
persona_description: Optional[str] = None
|
||||
) -> str:
|
||||
"""构建回复提示词"""
|
||||
base_prompt = f"""你是一位虚拟用户,请根据以下要求回复另一条评论:
|
||||
|
||||
写作风格:{writing_style}
|
||||
"""
|
||||
|
||||
if persona_description:
|
||||
base_prompt += f"\n人格特征:{persona_description}\n"
|
||||
|
||||
base_prompt += f"""
|
||||
新闻内容:
|
||||
{news_content[:500]}
|
||||
|
||||
原评论:
|
||||
{original_comment}
|
||||
|
||||
请写一条 30-80 字的回复,要符合你的写作风格和人格特征。直接输出回复内容,不要有其他说明。"""
|
||||
|
||||
return base_prompt
|
||||
|
||||
async def _call_ai_api(
|
||||
self,
|
||||
prompt: str,
|
||||
model_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
调用 AI API(根据 model_config 中的 provider 选择对应模型)
|
||||
:param prompt: 提示词
|
||||
:param model_config: 模型配置
|
||||
:return: 生成结果
|
||||
"""
|
||||
if not model_config:
|
||||
# 使用默认配置(需要从数据库加载)
|
||||
from app.models.ai_model import AIModelConfig
|
||||
from app.models.base import get_db
|
||||
|
||||
with get_db() as db:
|
||||
default_model = db.query(AIModelConfig).filter(
|
||||
AIModelConfig.is_default == True,
|
||||
AIModelConfig.is_active == True
|
||||
).first()
|
||||
|
||||
if not default_model:
|
||||
logger.error("No default AI model configured")
|
||||
return None
|
||||
|
||||
model_config = {
|
||||
"provider": default_model.provider,
|
||||
"model_name": default_model.model_name,
|
||||
"api_key": default_model.api_key,
|
||||
"api_url": default_model.api_url,
|
||||
"temperature": default_model.temperature,
|
||||
"max_tokens": default_model.max_tokens,
|
||||
}
|
||||
|
||||
provider = model_config.get("provider", "").lower()
|
||||
|
||||
try:
|
||||
if provider == "openai":
|
||||
return await self._call_openai(prompt, model_config)
|
||||
elif provider == "zhipu":
|
||||
return await self._call_zhipu(prompt, model_config)
|
||||
elif provider in ["baidu", "wenxin"]:
|
||||
return await self._call_baidu_wenxin(prompt, model_config)
|
||||
elif provider in ["aliyun", "dashscope"]:
|
||||
return await self._call_aliyun_dashscope(prompt, model_config)
|
||||
else:
|
||||
logger.error(f"Unsupported AI provider: {provider}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"AI API call error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用 OpenAI API"""
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=config["api_key"],
|
||||
base_url=config.get("api_url")
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=config.get("model_name", "gpt-3.5-turbo"),
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0.7),
|
||||
max_tokens=config.get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tokens_used = response.usage.total_tokens if response.usage else 0
|
||||
|
||||
logger.info(f"OpenAI generated content, tokens: {tokens_used}")
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tokens_used": tokens_used,
|
||||
"provider": "openai",
|
||||
"model": config.get("model_name", "gpt-3.5-turbo")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_zhipu(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用智谱 AI API"""
|
||||
try:
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
client = ZhipuAI(api_key=config["api_key"])
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=config.get("model_name", "glm-4"),
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=config.get("temperature", 0.7),
|
||||
max_tokens=config.get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
tokens_used = response.usage.total_tokens if response.usage else 0
|
||||
|
||||
logger.info(f"Zhipu AI generated content, tokens: {tokens_used}")
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tokens_used": tokens_used,
|
||||
"provider": "zhipu",
|
||||
"model": config.get("model_name", "glm-4")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Zhipu AI API error: {e}")
|
||||
return None
|
||||
|
||||
async def _call_baidu_wenxin(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用百度文心一言 API"""
|
||||
# TODO: 实现百度文心一言 API 调用
|
||||
logger.warning("Baidu Wenxin API not implemented yet")
|
||||
return None
|
||||
|
||||
async def _call_aliyun_dashscope(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""调用阿里云通义千问 API"""
|
||||
# TODO: 实现阿里云 DashScope API 调用
|
||||
logger.warning("Aliyun DashScope API not implemented yet")
|
||||
return None
|
||||
|
||||
async def test_model(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
test_prompt: str = "测试评论"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试模型配置
|
||||
:param model_config: 模型配置
|
||||
:param test_prompt: 测试提示词
|
||||
:return: 测试结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
result = await self._call_ai_api(test_prompt, model_config)
|
||||
|
||||
cost_time = time.time() - start_time
|
||||
|
||||
if result:
|
||||
return {
|
||||
"success": True,
|
||||
"content": result.get("content"),
|
||||
"tokens_used": result.get("tokens_used", 0),
|
||||
"cost_time": round(cost_time, 2),
|
||||
"error_message": None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"content": None,
|
||||
"tokens_used": 0,
|
||||
"cost_time": round(cost_time, 2),
|
||||
"error_message": "Failed to generate content"
|
||||
}
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
ai_service = AIService()
|
||||
Reference in New Issue
Block a user