Files
MineNasAI/src/minenasai/llm/base.py

220 lines
5.5 KiB
Python

"""LLM 客户端基类
定义统一的 LLM 接口
"""
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any
class Provider(StrEnum):
"""LLM 提供商
支持的 AI API 提供商枚举。
"""
ANTHROPIC = "anthropic" # Claude
OPENAI = "openai" # GPT
DEEPSEEK = "deepseek" # DeepSeek
ZHIPU = "zhipu" # 智谱 GLM
MINIMAX = "minimax" # MiniMax
MOONSHOT = "moonshot" # Kimi
GEMINI = "gemini" # Google Gemini
@property
def is_overseas(self) -> bool:
"""是否为境外服务(需要代理)"""
return self in {Provider.ANTHROPIC, Provider.OPENAI, Provider.GEMINI}
@property
def display_name(self) -> str:
"""显示名称"""
names = {
Provider.ANTHROPIC: "Claude (Anthropic)",
Provider.OPENAI: "GPT (OpenAI)",
Provider.DEEPSEEK: "DeepSeek",
Provider.ZHIPU: "GLM (智谱)",
Provider.MINIMAX: "MiniMax",
Provider.MOONSHOT: "Kimi (Moonshot)",
Provider.GEMINI: "Gemini (Google)",
}
return names.get(self, self.value)
@dataclass
class Message:
"""消息"""
role: str # "user", "assistant", "system"
content: str
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None # 用于 tool 结果消息
@dataclass
class ToolCall:
"""工具调用"""
id: str
name: str
arguments: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolDefinition:
"""工具定义"""
name: str
description: str
parameters: dict[str, Any] = field(default_factory=dict)
@dataclass
class LLMResponse:
"""LLM 响应"""
content: str
model: str
provider: Provider
tool_calls: list[ToolCall] | None = None
finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict)
raw_response: Any = None
@dataclass
class StreamChunk:
"""流式响应块"""
content: str
is_final: bool = False
tool_calls: list[ToolCall] | None = None
usage: dict[str, int] | None = None
class BaseLLMClient(ABC):
"""LLM 客户端基类"""
provider: Provider
default_model: str
def __init__(
self,
api_key: str,
base_url: str | None = None,
proxy: str | None = None,
timeout: float = 60.0,
max_retries: int = 3,
) -> None:
"""初始化客户端
Args:
api_key: API 密钥
base_url: 自定义 API 地址
proxy: 代理地址(境外服务使用)
timeout: 请求超时(秒)
max_retries: 最大重试次数
"""
self.api_key = api_key
self.base_url = base_url
self.proxy = proxy
self.timeout = timeout
self.max_retries = max_retries
@abstractmethod
async def chat(
self,
messages: list[Message],
model: str | None = None,
system: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
) -> LLMResponse:
"""发送聊天请求
Args:
messages: 消息列表
model: 模型名称
system: 系统提示词
max_tokens: 最大输出 token
temperature: 温度参数
tools: 工具定义列表
**kwargs: 其他参数
Returns:
LLM 响应
"""
pass
@abstractmethod
async def chat_stream(
self,
messages: list[Message],
model: str | None = None,
system: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
tools: list[ToolDefinition] | None = None,
**kwargs: Any,
) -> AsyncIterator[StreamChunk]:
"""流式聊天
Args:
messages: 消息列表
model: 模型名称
system: 系统提示词
max_tokens: 最大输出 token
temperature: 温度参数
tools: 工具定义列表
**kwargs: 其他参数
Yields:
流式响应块
"""
pass
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""转换消息格式为 API 格式(子类可覆盖)"""
result = []
for msg in messages:
item: dict[str, Any] = {
"role": msg.role,
"content": msg.content,
}
if msg.name:
item["name"] = msg.name
if msg.tool_call_id:
item["tool_call_id"] = msg.tool_call_id
result.append(item)
return result
def _convert_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
"""转换工具定义格式(子类可覆盖)"""
return [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
for tool in tools
]
async def close(self) -> None:
"""关闭客户端"""
pass
def get_available_models(self) -> list[str]:
"""获取可用模型列表"""
return [self.default_model]