- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
365 lines
10 KiB
Python
365 lines
10 KiB
Python
"""
|
||
AI接口提供商服务
|
||
管理AI接口的配置和调用
|
||
"""
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import List, Dict, Any, Optional
|
||
from loguru import logger
|
||
|
||
from models.ai_provider import AIProvider
|
||
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
|
||
from utils.encryption import encrypt_api_key, decrypt_api_key
|
||
from utils.rate_limiter import rate_limiter
|
||
|
||
|
||
class AIProviderService:
    """
    Service for AI provider configurations.

    Handles CRUD operations on ``AIProvider`` documents and dispatches chat
    requests through per-provider adapter instances.
    """

    # Adapter instances keyed by provider_id; entries are dropped on update/delete.
    _adapter_cache: Dict[str, BaseAdapter] = {}

    # Fields update_provider must never overwrite. Changing provider_id in
    # particular would orphan the adapter-cache entry and rate-limiter
    # registration that are keyed on the old id.
    _IMMUTABLE_FIELDS = frozenset({"id", "provider_id", "created_at"})

    @classmethod
    def _register_rate_limit(cls, provider: AIProvider) -> None:
        """Register the provider's limits with the shared rate limiter (defaults: 60 RPM, 100000 TPM)."""
        limits = provider.rate_limit or {}
        rate_limiter.register(
            provider.provider_id,
            limits.get("requests_per_minute", 60),
            limits.get("tokens_per_minute", 100000),
        )

    @classmethod
    async def create_provider(
        cls,
        provider_type: str,
        name: str,
        model: str,
        api_key: str = "",
        base_url: str = "",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        rate_limit: Optional[Dict[str, int]] = None,
        timeout: int = 60,
        extra_params: Optional[Dict[str, Any]] = None
    ) -> AIProvider:
        """
        Create, persist and register a new AI provider configuration.

        Args:
            provider_type: Provider type; must be accepted by the adapter registry.
            name: User-facing display name.
            model: Model name.
            api_key: Plain-text API key; stored encrypted. Empty means no key.
            base_url: API base URL.
            use_proxy: Whether requests go through a proxy.
            proxy_config: Proxy settings (defaults to an empty dict).
            rate_limit: Rate-limit settings (defaults to 60 requests and
                100000 tokens per minute).
            timeout: Request timeout.
            extra_params: Extra keyword arguments forwarded to the adapter.

        Returns:
            The inserted AIProvider document.

        Raises:
            ValueError: If ``provider_type`` is not a supported provider type.
        """
        # Bare ``get_adapter`` is the module-level adapter lookup imported from
        # ``adapters`` (not this class's classmethod); it signals an unknown
        # type with ValueError, as the except clause expects.
        try:
            get_adapter(provider_type)
        except ValueError as e:
            raise ValueError(f"不支持的提供商类型: {provider_type}") from e

        provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
        encrypted_key = encrypt_api_key(api_key) if api_key else ""
        # A single timestamp so created_at and updated_at match on creation.
        now = datetime.utcnow()

        provider = AIProvider(
            provider_id=provider_id,
            provider_type=provider_type,
            name=name,
            api_key=encrypted_key,
            base_url=base_url,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config or {},
            rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
            timeout=timeout,
            extra_params=extra_params or {},
            enabled=True,
            created_at=now,
            updated_at=now
        )

        await provider.insert()
        cls._register_rate_limit(provider)

        logger.info(f"创建AI接口配置: {provider_id} ({name})")
        return provider

    @classmethod
    async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
        """
        Look up a single provider configuration.

        Args:
            provider_id: Provider ID.

        Returns:
            The matching AIProvider document, or None if none matches.
        """
        return await AIProvider.find_one(AIProvider.provider_id == provider_id)

    @classmethod
    async def get_all_providers(
        cls,
        enabled_only: bool = False
    ) -> List[AIProvider]:
        """
        List provider configurations.

        Args:
            enabled_only: If True, only return providers with ``enabled`` set.

        Returns:
            List of AIProvider documents.
        """
        if enabled_only:
            # ``== True`` builds a query expression for the ODM's find();
            # ``is True`` would evaluate to a plain bool instead.
            return await AIProvider.find(AIProvider.enabled == True).to_list()  # noqa: E712
        return await AIProvider.find_all().to_list()

    @classmethod
    async def update_provider(
        cls,
        provider_id: str,
        **kwargs
    ) -> Optional[AIProvider]:
        """
        Update fields of an existing provider configuration.

        A non-empty ``api_key`` is encrypted before storage. ``updated_at`` is
        always refreshed. Keys that are not attributes of the document, and
        the immutable identity fields (``id``, ``provider_id``,
        ``created_at``), are ignored. The cached adapter is invalidated, and
        rate limits are re-registered when ``rate_limit`` is supplied.

        Args:
            provider_id: Provider ID.
            **kwargs: Fields to update.

        Returns:
            The updated AIProvider, or None if the provider does not exist.
        """
        provider = await cls.get_provider(provider_id)
        if not provider:
            return None

        # Only a non-empty replacement key is encrypted; other values are stored as given.
        if kwargs.get("api_key"):
            kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])

        kwargs["updated_at"] = datetime.utcnow()

        for key, value in kwargs.items():
            if key in cls._IMMUTABLE_FIELDS:
                logger.warning(f"Ignoring immutable field in update_provider: {key}")
                continue
            if hasattr(provider, key):
                setattr(provider, key, value)

        await provider.save()

        # Force the next get_adapter() call to rebuild with the new settings.
        cls._adapter_cache.pop(provider_id, None)

        if "rate_limit" in kwargs:
            rate_limiter.unregister(provider_id)
            cls._register_rate_limit(provider)

        logger.info(f"更新AI接口配置: {provider_id}")
        return provider

    @classmethod
    async def delete_provider(cls, provider_id: str) -> bool:
        """
        Delete a provider configuration and release its cached resources.

        Args:
            provider_id: Provider ID.

        Returns:
            True if the provider existed and was deleted, otherwise False.
        """
        provider = await cls.get_provider(provider_id)
        if not provider:
            return False

        await provider.delete()

        # Drop the cached adapter and the rate-limiter registration.
        cls._adapter_cache.pop(provider_id, None)
        rate_limiter.unregister(provider_id)

        logger.info(f"删除AI接口配置: {provider_id}")
        return True

    @classmethod
    async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
        """
        Return an adapter instance for the provider, building and caching it if needed.

        Args:
            provider_id: Provider ID.

        Returns:
            The adapter, or None if the provider does not exist or is disabled.
            A cached adapter is returned without re-reading the document.
        """
        cached = cls._adapter_cache.get(provider_id)
        if cached is not None:
            return cached

        provider = await cls.get_provider(provider_id)
        if not provider or not provider.enabled:
            return None

        api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""

        # Module-level adapter lookup (imported from ``adapters``), not this method.
        adapter_class = get_adapter(provider.provider_type)
        adapter = adapter_class(
            api_key=api_key,
            base_url=provider.base_url,
            model=provider.model,
            use_proxy=provider.use_proxy,
            proxy_config=provider.proxy_config,
            timeout=provider.timeout,
            # Tolerate documents whose extra_params was stored as None.
            **(provider.extra_params or {})
        )

        cls._adapter_cache[provider_id] = adapter
        return adapter

    @classmethod
    async def chat(
        cls,
        provider_id: str,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """
        Send a chat request through the provider's adapter.

        Args:
            provider_id: Provider ID.
            messages: Message dicts, e.g. [{"role": "user", "content": "..."}].
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            **kwargs: Extra arguments forwarded to the adapter.

        Returns:
            The adapter's response. A failed AdapterResponse is returned if the
            provider is missing or disabled, or if the rate limiter denies the
            request.
        """
        adapter = await cls.get_adapter(provider_id)
        if not adapter:
            return AdapterResponse(
                success=False,
                error=f"AI接口不存在或未启用: {provider_id}"
            )

        # Rough token estimate (~4 characters per token) for the rate limiter.
        # ``or ""`` guards against messages whose content is explicitly None.
        estimated_tokens = sum(len(m.get("content") or "") for m in messages) // 4
        if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
            return AdapterResponse(
                success=False,
                error="请求频率超限,请稍后重试"
            )

        chat_messages = [
            ChatMessage(
                role=m.get("role", "user"),
                content=m.get("content") or "",
                name=m.get("name")
            )
            for m in messages
        ]

        return await adapter.chat(
            messages=chat_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

    @classmethod
    async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
        """
        Test connectivity of a stored provider configuration.

        Args:
            provider_id: Provider ID.

        Returns:
            The adapter's test result. If the provider is missing or disabled,
            a dict with ``success`` False and a ``message``.
        """
        adapter = await cls.get_adapter(provider_id)
        if not adapter:
            return {
                "success": False,
                "message": f"AI接口不存在或未启用: {provider_id}"
            }

        return await adapter.test_connection()

    @classmethod
    async def test_provider_config(
        cls,
        provider_type: str,
        api_key: str,
        base_url: str = "",
        model: str = "",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Test an unsaved provider configuration.

        Args:
            provider_type: Provider type.
            api_key: Plain-text API key.
            base_url: API base URL.
            model: Model name.
            use_proxy: Whether requests go through a proxy.
            proxy_config: Proxy settings (an empty dict is passed when None,
                matching how stored providers are built).
            timeout: Request timeout.
            **kwargs: Extra arguments forwarded to the adapter constructor.

        Returns:
            The adapter's test result. If the type is unsupported, a dict with
            ``success`` False and a ``message``.
        """
        # Module-level adapter lookup (imported from ``adapters``).
        try:
            adapter_class = get_adapter(provider_type)
        except ValueError:
            return {
                "success": False,
                "message": f"不支持的提供商类型: {provider_type}"
            }

        adapter = adapter_class(
            api_key=api_key,
            base_url=base_url,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config or {},
            timeout=timeout,
            **kwargs
        )

        return await adapter.test_connection()