""" AI接口提供商服务 管理AI接口的配置和调用 """ import uuid from datetime import datetime from typing import List, Dict, Any, Optional from loguru import logger from models.ai_provider import AIProvider from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse from utils.encryption import encrypt_api_key, decrypt_api_key from utils.rate_limiter import rate_limiter class AIProviderService: """ AI接口提供商服务类 负责AI接口的CRUD操作和调用 """ # 缓存适配器实例 _adapter_cache: Dict[str, BaseAdapter] = {} @classmethod async def create_provider( cls, provider_type: str, name: str, model: str, api_key: str = "", base_url: str = "", use_proxy: bool = False, proxy_config: Optional[Dict[str, Any]] = None, rate_limit: Optional[Dict[str, int]] = None, timeout: int = 60, extra_params: Optional[Dict[str, Any]] = None ) -> AIProvider: """ 创建新的AI接口配置 Args: provider_type: 提供商类型 name: 自定义名称 model: 模型名称 api_key: API密钥 base_url: API基础URL use_proxy: 是否使用代理 proxy_config: 代理配置 rate_limit: 速率限制配置 timeout: 超时时间 extra_params: 额外参数 Returns: 创建的AIProvider文档 """ # 验证提供商类型 try: get_adapter(provider_type) except ValueError as e: raise ValueError(f"不支持的提供商类型: {provider_type}") # 生成唯一ID provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}" # 加密API密钥 encrypted_key = encrypt_api_key(api_key) if api_key else "" # 创建文档 provider = AIProvider( provider_id=provider_id, provider_type=provider_type, name=name, api_key=encrypted_key, base_url=base_url, model=model, use_proxy=use_proxy, proxy_config=proxy_config or {}, rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000}, timeout=timeout, extra_params=extra_params or {}, enabled=True, created_at=datetime.utcnow(), updated_at=datetime.utcnow() ) await provider.insert() # 注册速率限制 rate_limiter.register( provider_id, provider.rate_limit.get("requests_per_minute", 60), provider.rate_limit.get("tokens_per_minute", 100000) ) logger.info(f"创建AI接口配置: {provider_id} ({name})") return provider @classmethod async def get_provider(cls, provider_id: str) -> Optional[AIProvider]: """ 获取指定AI接口配置 Args: provider_id: 接口ID Returns: AIProvider文档或None """ return await AIProvider.find_one(AIProvider.provider_id == provider_id) @classmethod async def get_all_providers( cls, enabled_only: bool = False ) -> List[AIProvider]: """ 获取所有AI接口配置 Args: enabled_only: 是否只返回启用的接口 Returns: AIProvider列表 """ if enabled_only: return await AIProvider.find(AIProvider.enabled == True).to_list() return await AIProvider.find_all().to_list() @classmethod async def update_provider( cls, provider_id: str, **kwargs ) -> Optional[AIProvider]: """ 更新AI接口配置 Args: provider_id: 接口ID **kwargs: 要更新的字段 Returns: 更新后的AIProvider或None """ provider = await cls.get_provider(provider_id) if not provider: return None # 如果更新了API密钥,需要加密 if "api_key" in kwargs and kwargs["api_key"]: kwargs["api_key"] = encrypt_api_key(kwargs["api_key"]) # 更新字段 kwargs["updated_at"] = datetime.utcnow() for key, value in kwargs.items(): if hasattr(provider, key): setattr(provider, key, value) await provider.save() # 清除适配器缓存 cls._adapter_cache.pop(provider_id, None) # 更新速率限制 if "rate_limit" in kwargs: rate_limiter.unregister(provider_id) rate_limiter.register( provider_id, provider.rate_limit.get("requests_per_minute", 60), provider.rate_limit.get("tokens_per_minute", 100000) ) logger.info(f"更新AI接口配置: {provider_id}") return provider @classmethod async def delete_provider(cls, provider_id: str) -> bool: """ 删除AI接口配置 Args: provider_id: 接口ID Returns: 是否删除成功 """ provider = await cls.get_provider(provider_id) if not provider: return False await provider.delete() # 清除缓存和速率限制 cls._adapter_cache.pop(provider_id, None) rate_limiter.unregister(provider_id) logger.info(f"删除AI接口配置: {provider_id}") return True @classmethod async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]: """ 获取AI接口的适配器实例 Args: provider_id: 接口ID Returns: 适配器实例或None """ # 检查缓存 if provider_id in cls._adapter_cache: return cls._adapter_cache[provider_id] provider = await cls.get_provider(provider_id) if not provider or not provider.enabled: return None # 解密API密钥 api_key = decrypt_api_key(provider.api_key) if provider.api_key else "" # 创建适配器 adapter_class = get_adapter(provider.provider_type) adapter = adapter_class( api_key=api_key, base_url=provider.base_url, model=provider.model, use_proxy=provider.use_proxy, proxy_config=provider.proxy_config, timeout=provider.timeout, **provider.extra_params ) # 缓存适配器 cls._adapter_cache[provider_id] = adapter return adapter @classmethod async def chat( cls, provider_id: str, messages: List[Dict[str, str]], temperature: float = 0.7, max_tokens: int = 2000, **kwargs ) -> AdapterResponse: """ 调用AI接口进行对话 Args: provider_id: 接口ID messages: 消息列表 [{"role": "user", "content": "..."}] temperature: 温度参数 max_tokens: 最大token数 **kwargs: 额外参数 Returns: 适配器响应 """ adapter = await cls.get_adapter(provider_id) if not adapter: return AdapterResponse( success=False, error=f"AI接口不存在或未启用: {provider_id}" ) # 检查速率限制 estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4 if not await rate_limiter.acquire_wait(provider_id, estimated_tokens): return AdapterResponse( success=False, error="请求频率超限,请稍后重试" ) # 转换消息格式 chat_messages = [ ChatMessage( role=m.get("role", "user"), content=m.get("content", ""), name=m.get("name") ) for m in messages ] # 调用适配器 response = await adapter.chat( messages=chat_messages, temperature=temperature, max_tokens=max_tokens, **kwargs ) return response @classmethod async def test_provider(cls, provider_id: str) -> Dict[str, Any]: """ 测试AI接口连接 Args: provider_id: 接口ID Returns: 测试结果 """ adapter = await cls.get_adapter(provider_id) if not adapter: return { "success": False, "message": f"AI接口不存在或未启用: {provider_id}" } return await adapter.test_connection() @classmethod async def test_provider_config( cls, provider_type: str, api_key: str, base_url: str = "", model: str = "", use_proxy: bool = False, proxy_config: Optional[Dict[str, Any]] = None, timeout: int = 30, **kwargs ) -> Dict[str, Any]: """ 测试AI接口配置(不保存) Args: provider_type: 提供商类型 api_key: API密钥 base_url: API基础URL model: 模型名称 use_proxy: 是否使用代理 proxy_config: 代理配置 timeout: 超时时间 **kwargs: 额外参数 Returns: 测试结果 """ try: adapter_class = get_adapter(provider_type) except ValueError: return { "success": False, "message": f"不支持的提供商类型: {provider_type}" } adapter = adapter_class( api_key=api_key, base_url=base_url, model=model, use_proxy=use_proxy, proxy_config=proxy_config, timeout=timeout, **kwargs ) return await adapter.test_connection()