167 lines
4.3 KiB
Python
167 lines
4.3 KiB
Python
|
|
"""
|
|||
|
|
AI适配器基类
|
|||
|
|
定义统一的AI调用接口
|
|||
|
|
"""
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class ChatMessage:
|
|||
|
|
"""聊天消息"""
|
|||
|
|
role: str # system, user, assistant
|
|||
|
|
content: str # 消息内容
|
|||
|
|
name: Optional[str] = None # 发送者名称(可选)
|
|||
|
|
|
|||
|
|
def to_dict(self) -> Dict[str, Any]:
|
|||
|
|
"""转换为字典"""
|
|||
|
|
d = {"role": self.role, "content": self.content}
|
|||
|
|
if self.name:
|
|||
|
|
d["name"] = self.name
|
|||
|
|
return d
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdapterResponse:
|
|||
|
|
"""适配器响应"""
|
|||
|
|
success: bool # 是否成功
|
|||
|
|
content: str = "" # 响应内容
|
|||
|
|
error: Optional[str] = None # 错误信息
|
|||
|
|
|
|||
|
|
# 统计信息
|
|||
|
|
prompt_tokens: int = 0 # 输入token数
|
|||
|
|
completion_tokens: int = 0 # 输出token数
|
|||
|
|
total_tokens: int = 0 # 总token数
|
|||
|
|
|
|||
|
|
# 元数据
|
|||
|
|
model: str = "" # 使用的模型
|
|||
|
|
finish_reason: str = "" # 结束原因
|
|||
|
|
latency_ms: float = 0.0 # 延迟(毫秒)
|
|||
|
|
|
|||
|
|
# 工具调用结果
|
|||
|
|
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
|||
|
|
|
|||
|
|
def __post_init__(self):
|
|||
|
|
if self.total_tokens == 0:
|
|||
|
|
self.total_tokens = self.prompt_tokens + self.completion_tokens
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BaseAdapter(ABC):
|
|||
|
|
"""
|
|||
|
|
AI适配器基类
|
|||
|
|
所有AI提供商适配器必须继承此类
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
api_key: str,
|
|||
|
|
base_url: str,
|
|||
|
|
model: str,
|
|||
|
|
use_proxy: bool = False,
|
|||
|
|
proxy_config: Optional[Dict[str, Any]] = None,
|
|||
|
|
timeout: int = 60,
|
|||
|
|
**kwargs
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
初始化适配器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_key: API密钥
|
|||
|
|
base_url: API基础URL
|
|||
|
|
model: 模型名称
|
|||
|
|
use_proxy: 是否使用代理
|
|||
|
|
proxy_config: 代理配置
|
|||
|
|
timeout: 超时时间(秒)
|
|||
|
|
**kwargs: 额外参数
|
|||
|
|
"""
|
|||
|
|
self.api_key = api_key
|
|||
|
|
self.base_url = base_url
|
|||
|
|
self.model = model
|
|||
|
|
self.use_proxy = use_proxy
|
|||
|
|
self.proxy_config = proxy_config or {}
|
|||
|
|
self.timeout = timeout
|
|||
|
|
self.extra_params = kwargs
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
async def chat(
|
|||
|
|
self,
|
|||
|
|
messages: List[ChatMessage],
|
|||
|
|
temperature: float = 0.7,
|
|||
|
|
max_tokens: int = 2000,
|
|||
|
|
**kwargs
|
|||
|
|
) -> AdapterResponse:
|
|||
|
|
"""
|
|||
|
|
发送聊天请求
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
messages: 消息列表
|
|||
|
|
temperature: 温度参数
|
|||
|
|
max_tokens: 最大token数
|
|||
|
|
**kwargs: 额外参数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
适配器响应
|
|||
|
|
"""
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
async def chat_stream(
|
|||
|
|
self,
|
|||
|
|
messages: List[ChatMessage],
|
|||
|
|
temperature: float = 0.7,
|
|||
|
|
max_tokens: int = 2000,
|
|||
|
|
**kwargs
|
|||
|
|
) -> AsyncGenerator[str, None]:
|
|||
|
|
"""
|
|||
|
|
发送流式聊天请求
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
messages: 消息列表
|
|||
|
|
temperature: 温度参数
|
|||
|
|
max_tokens: 最大token数
|
|||
|
|
**kwargs: 额外参数
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
响应内容片段
|
|||
|
|
"""
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
async def test_connection(self) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
测试API连接
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
测试结果字典,包含 success, message, latency_ms
|
|||
|
|
"""
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
def _build_messages(
|
|||
|
|
self,
|
|||
|
|
messages: List[ChatMessage]
|
|||
|
|
) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
构建消息列表
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
messages: ChatMessage列表
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
字典格式的消息列表
|
|||
|
|
"""
|
|||
|
|
return [msg.to_dict() for msg in messages]
|
|||
|
|
|
|||
|
|
def _calculate_latency(self, start_time: datetime) -> float:
|
|||
|
|
"""
|
|||
|
|
计算延迟
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
start_time: 开始时间
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
延迟毫秒数
|
|||
|
|
"""
|
|||
|
|
return (datetime.utcnow() - start_time).total_seconds() * 1000
|