216 lines
5.4 KiB
Python
216 lines
5.4 KiB
Python
|
|
"""LLM 客户端基类
|
|||
|
|
|
|||
|
|
定义统一的 LLM 接口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Provider(str, Enum):
    """Supported LLM providers."""

    ANTHROPIC = "anthropic"  # Claude
    OPENAI = "openai"  # GPT
    DEEPSEEK = "deepseek"  # DeepSeek
    ZHIPU = "zhipu"  # Zhipu GLM
    MINIMAX = "minimax"  # MiniMax
    MOONSHOT = "moonshot"  # Kimi
    GEMINI = "gemini"  # Google Gemini

    @property
    def is_overseas(self) -> bool:
        """True when the service is hosted overseas (a proxy may be needed)."""
        overseas = (Provider.ANTHROPIC, Provider.OPENAI, Provider.GEMINI)
        return self in overseas

    @property
    def display_name(self) -> str:
        """Human-readable label; falls back to the raw enum value."""
        labels: dict[Provider, str] = {
            Provider.ANTHROPIC: "Claude (Anthropic)",
            Provider.OPENAI: "GPT (OpenAI)",
            Provider.DEEPSEEK: "DeepSeek",
            Provider.ZHIPU: "GLM (智谱)",
            Provider.MINIMAX: "MiniMax",
            Provider.MOONSHOT: "Kimi (Moonshot)",
            Provider.GEMINI: "Gemini (Google)",
        }
        return labels.get(self, self.value)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class Message:
    """A single conversation message."""

    role: str  # "user", "assistant", "system"
    content: str  # message text
    name: str | None = None  # optional sender/participant name
    tool_calls: list[ToolCall] | None = None  # associated tool calls, if any
    tool_call_id: str | None = None  # set on messages carrying a tool result
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolCall:
    """A tool invocation requested by the model."""

    id: str  # provider-assigned call id
    name: str  # name of the tool being invoked
    arguments: dict[str, Any] = field(default_factory=dict)  # parsed call arguments
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolDefinition:
    """Definition of a tool exposed to the model."""

    name: str  # tool name
    description: str  # what the tool does, shown to the model
    parameters: dict[str, Any] = field(default_factory=dict)  # JSON-schema-style parameter spec
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class LLMResponse:
    """Complete (non-streaming) response from an LLM."""

    content: str  # generated text
    model: str  # model that produced the response
    provider: Provider  # provider the response came from
    tool_calls: list[ToolCall] | None = None  # tool calls requested by the model, if any
    finish_reason: str = "stop"  # why generation ended
    usage: dict[str, int] = field(default_factory=dict)  # token usage counters
    raw_response: Any = None  # provider-specific raw payload, kept for debugging
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class StreamChunk:
    """One incremental chunk of a streaming response."""

    content: str  # text delta for this chunk
    is_final: bool = False  # True on the terminating chunk
    tool_calls: list[ToolCall] | None = None  # tool calls, if delivered in this chunk
    usage: dict[str, int] | None = None  # token usage, typically only on the final chunk
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BaseLLMClient(ABC):
    """Abstract base class for LLM provider clients.

    Subclasses implement :meth:`chat` and :meth:`chat_stream` for a concrete
    provider and may override the conversion helpers when the provider's wire
    format differs from the OpenAI-style default used here.
    """

    provider: Provider  # which provider this client talks to (set by subclass)
    default_model: str  # model used when the caller does not specify one

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        """Initialize the client.

        Args:
            api_key: API key.
            base_url: Custom API endpoint; None uses the provider default.
            proxy: Proxy address (used by overseas services).
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retries.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.proxy = proxy
        self.timeout = timeout
        self.max_retries = max_retries

    @abstractmethod
    async def chat(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat request and return the complete response.

        Args:
            messages: Conversation messages.
            model: Model name; None selects the client's default model.
            system: System prompt.
            max_tokens: Maximum output tokens.
            temperature: Sampling temperature.
            tools: Tool definitions to expose to the model.
            **kwargs: Extra provider-specific parameters.

        Returns:
            The LLM response.
        """

    @abstractmethod
    async def chat_stream(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a chat response chunk by chunk.

        NOTE(review): implementations are expected to be async generators so
        that callers can ``async for`` over the return value — confirm against
        concrete subclasses.

        Args:
            messages: Conversation messages.
            model: Model name; None selects the client's default model.
            system: System prompt.
            max_tokens: Maximum output tokens.
            temperature: Sampling temperature.
            tools: Tool definitions to expose to the model.
            **kwargs: Extra provider-specific parameters.

        Yields:
            Streamed response chunks.
        """

    def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
        """Convert messages to OpenAI-style API dicts (subclasses may override).

        Optional fields are emitted only when set. Assistant ``tool_calls``
        are serialized with JSON-encoded argument strings, matching the
        OpenAI-style "function" format that ``_convert_tools`` targets.
        """
        result: list[dict[str, Any]] = []
        for msg in messages:
            item: dict[str, Any] = {
                "role": msg.role,
                "content": msg.content,
            }
            if msg.name:
                item["name"] = msg.name
            if msg.tool_call_id:
                item["tool_call_id"] = msg.tool_call_id
            # Bug fix: tool_calls were previously dropped entirely, so an
            # assistant message that requested tools could not be sent back
            # to the API when replaying the conversation.
            if msg.tool_calls:
                item["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments, ensure_ascii=False),
                        },
                    }
                    for tc in msg.tool_calls
                ]
            result.append(item)
        return result

    def _convert_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
        """Convert tool definitions to OpenAI-style dicts (subclasses may override)."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    async def close(self) -> None:
        """Release underlying resources (no-op by default)."""

    def get_available_models(self) -> list[str]:
        """Return the models this client can use; defaults to the default model only."""
        return [self.default_model]
|