"""LLM client base classes.

Defines the unified LLM interface shared by all provider clients.
"""

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator


class Provider(str, Enum):
    """LLM provider."""

    ANTHROPIC = "anthropic"  # Claude
    OPENAI = "openai"  # GPT
    DEEPSEEK = "deepseek"  # DeepSeek
    ZHIPU = "zhipu"  # Zhipu GLM
    MINIMAX = "minimax"  # MiniMax
    MOONSHOT = "moonshot"  # Kimi
    GEMINI = "gemini"  # Google Gemini

    @property
    def is_overseas(self) -> bool:
        """Whether this is an overseas service (requires a proxy)."""
        return self in _OVERSEAS_PROVIDERS

    @property
    def display_name(self) -> str:
        """Human-readable name; falls back to the raw enum value if unmapped."""
        return _DISPLAY_NAMES.get(self, self.value)


# Module-level lookup tables, built once instead of on every property access.
# (They cannot live inside the Enum body: there they would become members.)
_OVERSEAS_PROVIDERS = frozenset({Provider.ANTHROPIC, Provider.OPENAI, Provider.GEMINI})

_DISPLAY_NAMES = {
    Provider.ANTHROPIC: "Claude (Anthropic)",
    Provider.OPENAI: "GPT (OpenAI)",
    Provider.DEEPSEEK: "DeepSeek",
    Provider.ZHIPU: "GLM (智谱)",
    Provider.MINIMAX: "MiniMax",
    Provider.MOONSHOT: "Kimi (Moonshot)",
    Provider.GEMINI: "Gemini (Google)",
}


@dataclass
class ToolCall:
    """A tool invocation requested by the model."""

    id: str
    name: str
    arguments: dict[str, Any] = field(default_factory=dict)


@dataclass
class Message:
    """A single chat message."""

    role: str  # e.g. "user", "assistant", "system"; tool results also set tool_call_id
    content: str
    name: str | None = None
    tool_calls: list[ToolCall] | None = None  # tool calls made in an assistant turn
    tool_call_id: str | None = None  # used by tool-result messages


@dataclass
class ToolDefinition:
    """Definition of a tool exposed to the model."""

    name: str
    description: str
    parameters: dict[str, Any] = field(default_factory=dict)  # parameter schema


@dataclass
class LLMResponse:
    """A complete (non-streaming) LLM response."""

    content: str
    model: str
    provider: Provider
    tool_calls: list[ToolCall] | None = None
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    raw_response: Any = None  # provider-specific payload, kept for debugging


@dataclass
class StreamChunk:
    """One chunk of a streaming response."""

    content: str
    is_final: bool = False
    tool_calls: list[ToolCall] | None = None
    usage: dict[str, int] | None = None


class BaseLLMClient(ABC):
    """Base class for LLM clients.

    Subclasses set the ``provider`` and ``default_model`` class attributes and
    implement :meth:`chat` and :meth:`chat_stream`.
    """

    provider: Provider
    default_model: str

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        """Initialize the client.

        The base class only stores these settings; applying them is up to
        the subclass.

        Args:
            api_key: API key.
            base_url: Custom API base URL.
            proxy: Proxy URL (used for overseas services).
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retries.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.proxy = proxy
        self.timeout = timeout
        self.max_retries = max_retries

    @abstractmethod
    async def chat(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat request.

        Args:
            messages: Conversation messages.
            model: Model name.
            system: System prompt.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            tools: Tool definitions.
            **kwargs: Additional provider-specific parameters.

        Returns:
            The LLM response.
        """
        ...

    @abstractmethod
    def chat_stream(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a chat response.

        Implementations are expected to be async generators (``async def``
        with ``yield``). The abstract method is deliberately a plain ``def``
        returning ``AsyncIterator``: an ``async def`` without ``yield`` would
        declare a coroutine that *returns* an iterator, which an async
        generator override does not match.

        Args:
            messages: Conversation messages.
            model: Model name.
            system: System prompt.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            tools: Tool definitions.
            **kwargs: Additional provider-specific parameters.

        Yields:
            Streaming response chunks.
        """
        ...

    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
        """Convert messages to API dicts (subclasses may override).

        Optional ``name`` / ``tool_call_id`` keys are only emitted when set.
        Assistant ``tool_calls`` are emitted in the same OpenAI-style function
        shape that :meth:`_convert_tools` uses, with ``arguments`` JSON-encoded.
        """
        result: list[dict[str, Any]] = []
        for msg in messages:
            item: dict[str, Any] = {
                "role": msg.role,
                "content": msg.content,
            }
            if msg.name:
                item["name"] = msg.name
            if msg.tool_calls:
                # Without this, an assistant turn that requested tools loses its
                # calls, and the following tool-result messages (tool_call_id)
                # point at calls the API never saw.
                item["tool_calls"] = [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.name,
                            # OpenAI-style APIs carry arguments as a JSON string.
                            "arguments": json.dumps(call.arguments, ensure_ascii=False),
                        },
                    }
                    for call in msg.tool_calls
                ]
            if msg.tool_call_id:
                item["tool_call_id"] = msg.tool_call_id
            result.append(item)
        return result

    def _convert_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
        """Convert tool definitions to OpenAI-style function tools (subclasses may override)."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    async def close(self) -> None:
        """Close the client. No-op in the base class."""
        pass

    def get_available_models(self) -> list[str]:
        """Return the available models; the base class lists only ``default_model``."""
        return [self.default_model]