Files
AIChatRoom/backend/adapters/base_adapter.py

167 lines
4.3 KiB
Python
Raw Permalink Normal View History

"""
AI适配器基类
定义统一的AI调用接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, AsyncGenerator
from datetime import datetime
@dataclass
class ChatMessage:
"""聊天消息"""
role: str # system, user, assistant
content: str # 消息内容
name: Optional[str] = None # 发送者名称(可选)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
d = {"role": self.role, "content": self.content}
if self.name:
d["name"] = self.name
return d
@dataclass
class AdapterResponse:
"""适配器响应"""
success: bool # 是否成功
content: str = "" # 响应内容
error: Optional[str] = None # 错误信息
# 统计信息
prompt_tokens: int = 0 # 输入token数
completion_tokens: int = 0 # 输出token数
total_tokens: int = 0 # 总token数
# 元数据
model: str = "" # 使用的模型
finish_reason: str = "" # 结束原因
latency_ms: float = 0.0 # 延迟(毫秒)
# 工具调用结果
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
def __post_init__(self):
if self.total_tokens == 0:
self.total_tokens = self.prompt_tokens + self.completion_tokens
class BaseAdapter(ABC):
"""
AI适配器基类
所有AI提供商适配器必须继承此类
"""
def __init__(
self,
api_key: str,
base_url: str,
model: str,
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 60,
**kwargs
):
"""
初始化适配器
Args:
api_key: API密钥
base_url: API基础URL
model: 模型名称
use_proxy: 是否使用代理
proxy_config: 代理配置
timeout: 超时时间()
**kwargs: 额外参数
"""
self.api_key = api_key
self.base_url = base_url
self.model = model
self.use_proxy = use_proxy
self.proxy_config = proxy_config or {}
self.timeout = timeout
self.extra_params = kwargs
@abstractmethod
async def chat(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""
发送聊天请求
Args:
messages: 消息列表
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Returns:
适配器响应
"""
pass
@abstractmethod
async def chat_stream(
self,
messages: List[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AsyncGenerator[str, None]:
"""
发送流式聊天请求
Args:
messages: 消息列表
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Yields:
响应内容片段
"""
pass
@abstractmethod
async def test_connection(self) -> Dict[str, Any]:
"""
测试API连接
Returns:
测试结果字典包含 success, message, latency_ms
"""
pass
def _build_messages(
self,
messages: List[ChatMessage]
) -> List[Dict[str, Any]]:
"""
构建消息列表
Args:
messages: ChatMessage列表
Returns:
字典格式的消息列表
"""
return [msg.to_dict() for msg in messages]
def _calculate_latency(self, start_time: datetime) -> float:
"""
计算延迟
Args:
start_time: 开始时间
Returns:
延迟毫秒数
"""
return (datetime.utcnow() - start_time).total_seconds() * 1000