# File: AIChatRoom/backend/services/ai_provider_service.py
# (365 lines, 10 KiB, Python)

"""
AI接口提供商服务
管理AI接口的配置和调用
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.ai_provider import AIProvider
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
from utils.encryption import encrypt_api_key, decrypt_api_key
from utils.rate_limiter import rate_limiter
class AIProviderService:
    """
    Service class for AI provider configurations.

    Provides create/read/update/delete operations on ``AIProvider``
    documents and dispatches chat and connection-test calls to the adapter
    that matches each provider's ``provider_type``.

    Every method is a classmethod; the adapter cache below is shared by the
    whole process.
    """

    # Adapter instances keyed by provider_id. Entries are dropped whenever the
    # corresponding provider is updated or deleted through this service.
    _adapter_cache: Dict[str, BaseAdapter] = {}

    # Fallback limits used when a provider's rate_limit mapping lacks a key.
    _DEFAULT_REQUESTS_PER_MINUTE = 60
    _DEFAULT_TOKENS_PER_MINUTE = 100000

    # Fields that update_provider() refuses to overwrite from **kwargs:
    # the document identity and the creation timestamp.
    _PROTECTED_FIELDS = frozenset({"id", "created_at"})

    @classmethod
    def _register_rate_limit(
        cls,
        provider_id: str,
        rate_limit: Optional[Dict[str, int]]
    ) -> None:
        """Register ``provider_id`` with the rate limiter, applying defaults.

        ``rate_limit`` may be ``None`` or miss either key; the class-level
        defaults are used for whatever is absent.
        """
        limits = rate_limit or {}
        rate_limiter.register(
            provider_id,
            limits.get("requests_per_minute", cls._DEFAULT_REQUESTS_PER_MINUTE),
            limits.get("tokens_per_minute", cls._DEFAULT_TOKENS_PER_MINUTE)
        )

    @staticmethod
    def _message_content(message: Dict[str, Any]) -> Any:
        """Return the message's ``content``, mapping a missing/None value to ""."""
        content = message.get("content")
        return "" if content is None else content

    @classmethod
    async def create_provider(
        cls,
        provider_type: str,
        name: str,
        model: str,
        api_key: str = "",
        base_url: str = "",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        rate_limit: Optional[Dict[str, int]] = None,
        timeout: int = 60,
        extra_params: Optional[Dict[str, Any]] = None
    ) -> AIProvider:
        """
        Create and persist a new AI provider configuration.

        Args:
            provider_type: Provider type; must be resolvable by ``get_adapter``.
            name: User-chosen display name.
            model: Model name.
            api_key: Plain-text API key; stored encrypted (empty stays empty).
            base_url: API base URL.
            use_proxy: Whether requests should go through a proxy.
            proxy_config: Proxy settings; defaults to an empty dict.
            rate_limit: Mapping with ``requests_per_minute`` /
                ``tokens_per_minute``; defaults to 60 / 100000.
            timeout: Timeout value handed to the adapter.
            extra_params: Extra keyword arguments for the adapter constructor.

        Returns:
            The inserted ``AIProvider`` document.

        Raises:
            ValueError: If ``provider_type`` is not supported. The adapter
                lookup error is chained as ``__cause__``.
        """
        # Validate the provider type before touching storage.
        try:
            get_adapter(provider_type)
        except ValueError as e:
            raise ValueError(f"不支持的提供商类型: {provider_type}") from e

        # Unique id: provider type plus 8 hex characters of a random UUID.
        provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"

        # Never store the key in plain text.
        encrypted_key = encrypt_api_key(api_key) if api_key else ""

        # One timestamp so created_at == updated_at on a fresh record.
        now = datetime.utcnow()

        provider = AIProvider(
            provider_id=provider_id,
            provider_type=provider_type,
            name=name,
            api_key=encrypted_key,
            base_url=base_url,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config or {},
            rate_limit=rate_limit or {
                "requests_per_minute": cls._DEFAULT_REQUESTS_PER_MINUTE,
                "tokens_per_minute": cls._DEFAULT_TOKENS_PER_MINUTE
            },
            timeout=timeout,
            extra_params=extra_params or {},
            enabled=True,
            created_at=now,
            updated_at=now
        )
        await provider.insert()

        cls._register_rate_limit(provider_id, provider.rate_limit)

        logger.info(f"创建AI接口配置: {provider_id} ({name})")
        return provider

    @classmethod
    async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
        """
        Fetch a single provider configuration.

        Args:
            provider_id: Provider id to look up.

        Returns:
            The matching ``AIProvider`` document, or ``None``.
        """
        return await AIProvider.find_one(AIProvider.provider_id == provider_id)

    @classmethod
    async def get_all_providers(
        cls,
        enabled_only: bool = False
    ) -> List[AIProvider]:
        """
        List provider configurations.

        Args:
            enabled_only: When true, return only providers with ``enabled`` set.

        Returns:
            List of ``AIProvider`` documents.
        """
        if enabled_only:
            # `== True` builds a query expression on the model field; it must
            # not be rewritten as a truthiness test or `is True`.
            return await AIProvider.find(AIProvider.enabled == True).to_list()  # noqa: E712
        return await AIProvider.find_all().to_list()

    @classmethod
    async def update_provider(
        cls,
        provider_id: str,
        **kwargs
    ) -> Optional[AIProvider]:
        """
        Update fields of an existing provider configuration.

        Args:
            provider_id: Provider id to update.
            **kwargs: Field names and new values. A non-empty ``api_key`` is
                encrypted before being stored. Names the document does not
                have are ignored, as are ``id`` and ``created_at``.

        Returns:
            The updated ``AIProvider``, or ``None`` if it does not exist.
        """
        provider = await cls.get_provider(provider_id)
        if not provider:
            return None

        # A new API key must be encrypted before it is stored.
        if "api_key" in kwargs and kwargs["api_key"]:
            kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])

        kwargs["updated_at"] = datetime.utcnow()
        for key, value in kwargs.items():
            if key in cls._PROTECTED_FIELDS:
                # Overwriting the document identity or creation time would
                # corrupt the record, so such keys are skipped.
                logger.warning(f"忽略不可修改的字段: {key}")
                continue
            if hasattr(provider, key):
                setattr(provider, key, value)
        await provider.save()

        # The cached adapter was built from the old settings.
        cls._adapter_cache.pop(provider_id, None)

        # Re-register so the limiter picks up the new limits. A rate_limit of
        # None falls back to the defaults instead of raising.
        if "rate_limit" in kwargs:
            rate_limiter.unregister(provider_id)
            cls._register_rate_limit(provider_id, provider.rate_limit)

        logger.info(f"更新AI接口配置: {provider_id}")
        return provider

    @classmethod
    async def delete_provider(cls, provider_id: str) -> bool:
        """
        Delete a provider configuration.

        Args:
            provider_id: Provider id to delete.

        Returns:
            ``True`` if a provider was deleted, ``False`` if none was found.
        """
        provider = await cls.get_provider(provider_id)
        if not provider:
            return False

        await provider.delete()

        # Drop the cached adapter and the limiter registration as well.
        cls._adapter_cache.pop(provider_id, None)
        rate_limiter.unregister(provider_id)

        logger.info(f"删除AI接口配置: {provider_id}")
        return True

    @classmethod
    async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
        """
        Return the adapter instance for a provider, building it on first use.

        Args:
            provider_id: Provider id.

        Returns:
            The adapter instance, or ``None`` when the provider does not
            exist or is disabled.
        """
        cached = cls._adapter_cache.get(provider_id)
        if cached is not None:
            return cached

        provider = await cls.get_provider(provider_id)
        if not provider or not provider.enabled:
            return None

        # The stored key is encrypted; the adapter needs the plain text.
        api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""

        # NOTE: inside a method body the bare name `get_adapter` resolves to
        # the module-level import from `adapters`, not to this classmethod.
        adapter_class = get_adapter(provider.provider_type)
        adapter = adapter_class(
            api_key=api_key,
            base_url=provider.base_url,
            model=provider.model,
            use_proxy=provider.use_proxy,
            proxy_config=provider.proxy_config,
            timeout=provider.timeout,
            **provider.extra_params
        )

        cls._adapter_cache[provider_id] = adapter
        return adapter

    @classmethod
    async def chat(
        cls,
        provider_id: str,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> AdapterResponse:
        """
        Send a conversation to the given provider.

        Args:
            provider_id: Provider id.
            messages: Message list, e.g. ``[{"role": "user", "content": "..."}]``.
                A missing or ``None`` ``content`` is treated as "".
            temperature: Sampling temperature forwarded to the adapter.
            max_tokens: Maximum token count forwarded to the adapter.
            **kwargs: Extra arguments forwarded to ``adapter.chat``.

        Returns:
            The adapter's response, or a failed ``AdapterResponse`` when the
            provider is unavailable or the rate limiter refuses the request.
        """
        adapter = await cls.get_adapter(provider_id)
        if not adapter:
            return AdapterResponse(
                success=False,
                error=f"AI接口不存在或未启用: {provider_id}"
            )

        # Rough token estimate for the limiter: total content length // 4.
        total_length = sum(len(cls._message_content(m)) for m in messages)
        estimated_tokens = total_length // 4
        if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
            return AdapterResponse(
                success=False,
                error="请求频率超限,请稍后重试"
            )

        # Convert plain dicts into the adapter's message type.
        chat_messages = [
            ChatMessage(
                role=m.get("role", "user"),
                content=cls._message_content(m),
                name=m.get("name")
            )
            for m in messages
        ]

        return await adapter.chat(
            messages=chat_messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

    @classmethod
    async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
        """
        Test the connection of a stored provider.

        Args:
            provider_id: Provider id.

        Returns:
            The adapter's ``test_connection()`` result, or a failure dict with
            ``success`` / ``message`` keys when the provider is unavailable.
        """
        adapter = await cls.get_adapter(provider_id)
        if not adapter:
            return {
                "success": False,
                "message": f"AI接口不存在或未启用: {provider_id}"
            }
        return await adapter.test_connection()

    @classmethod
    async def test_provider_config(
        cls,
        provider_type: str,
        api_key: str,
        base_url: str = "",
        model: str = "",
        use_proxy: bool = False,
        proxy_config: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Test a provider configuration without saving it.

        Nothing is persisted or cached: a throwaway adapter is built from the
        arguments and its ``test_connection()`` result is returned.

        Args:
            provider_type: Provider type.
            api_key: Plain-text API key.
            base_url: API base URL.
            model: Model name.
            use_proxy: Whether to use a proxy.
            proxy_config: Proxy settings.
            timeout: Timeout value handed to the adapter.
            **kwargs: Extra keyword arguments for the adapter constructor.

        Returns:
            The test result, or a failure dict when ``provider_type`` is not
            supported.
        """
        try:
            adapter_class = get_adapter(provider_type)
        except ValueError:
            return {
                "success": False,
                "message": f"不支持的提供商类型: {provider_type}"
            }

        adapter = adapter_class(
            api_key=api_key,
            base_url=base_url,
            model=model,
            use_proxy=use_proxy,
            proxy_config=proxy_config,
            timeout=timeout,
            **kwargs
        )
        return await adapter.test_connection()