""" AI适配器基类 定义统一的AI调用接口 """ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List, Dict, Any, Optional, AsyncGenerator from datetime import datetime @dataclass class ChatMessage: """聊天消息""" role: str # system, user, assistant content: str # 消息内容 name: Optional[str] = None # 发送者名称(可选) def to_dict(self) -> Dict[str, Any]: """转换为字典""" d = {"role": self.role, "content": self.content} if self.name: d["name"] = self.name return d @dataclass class AdapterResponse: """适配器响应""" success: bool # 是否成功 content: str = "" # 响应内容 error: Optional[str] = None # 错误信息 # 统计信息 prompt_tokens: int = 0 # 输入token数 completion_tokens: int = 0 # 输出token数 total_tokens: int = 0 # 总token数 # 元数据 model: str = "" # 使用的模型 finish_reason: str = "" # 结束原因 latency_ms: float = 0.0 # 延迟(毫秒) # 工具调用结果 tool_calls: List[Dict[str, Any]] = field(default_factory=list) def __post_init__(self): if self.total_tokens == 0: self.total_tokens = self.prompt_tokens + self.completion_tokens class BaseAdapter(ABC): """ AI适配器基类 所有AI提供商适配器必须继承此类 """ def __init__( self, api_key: str, base_url: str, model: str, use_proxy: bool = False, proxy_config: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs ): """ 初始化适配器 Args: api_key: API密钥 base_url: API基础URL model: 模型名称 use_proxy: 是否使用代理 proxy_config: 代理配置 timeout: 超时时间(秒) **kwargs: 额外参数 """ self.api_key = api_key self.base_url = base_url self.model = model self.use_proxy = use_proxy self.proxy_config = proxy_config or {} self.timeout = timeout self.extra_params = kwargs @abstractmethod async def chat( self, messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2000, **kwargs ) -> AdapterResponse: """ 发送聊天请求 Args: messages: 消息列表 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 额外参数 Returns: 适配器响应 """ pass @abstractmethod async def chat_stream( self, messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2000, **kwargs ) -> AsyncGenerator[str, None]: """ 发送流式聊天请求 Args: messages: 消息列表 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 额外参数 Yields: 响应内容片段 """ pass @abstractmethod async def test_connection(self) -> Dict[str, Any]: """ 测试API连接 Returns: 测试结果字典,包含 success, message, latency_ms """ pass def _build_messages( self, messages: List[ChatMessage] ) -> List[Dict[str, Any]]: """ 构建消息列表 Args: messages: ChatMessage列表 Returns: 字典格式的消息列表 """ return [msg.to_dict() for msg in messages] def _calculate_latency(self, start_time: datetime) -> float: """ 计算延迟 Args: start_time: 开始时间 Returns: 延迟毫秒数 """ return (datetime.utcnow() - start_time).total_seconds() * 1000