feat: 添加项目规则、环境配置示例及开发文档
This commit is contained in:
215
src/minenasai/llm/base.py
Normal file
215
src/minenasai/llm/base.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""LLM 客户端基类
|
||||
|
||||
定义统一的 LLM 接口
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
"""LLM 提供商"""
|
||||
|
||||
ANTHROPIC = "anthropic" # Claude
|
||||
OPENAI = "openai" # GPT
|
||||
DEEPSEEK = "deepseek" # DeepSeek
|
||||
ZHIPU = "zhipu" # 智谱 GLM
|
||||
MINIMAX = "minimax" # MiniMax
|
||||
MOONSHOT = "moonshot" # Kimi
|
||||
GEMINI = "gemini" # Google Gemini
|
||||
|
||||
@property
|
||||
def is_overseas(self) -> bool:
|
||||
"""是否为境外服务(需要代理)"""
|
||||
return self in {Provider.ANTHROPIC, Provider.OPENAI, Provider.GEMINI}
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""显示名称"""
|
||||
names = {
|
||||
Provider.ANTHROPIC: "Claude (Anthropic)",
|
||||
Provider.OPENAI: "GPT (OpenAI)",
|
||||
Provider.DEEPSEEK: "DeepSeek",
|
||||
Provider.ZHIPU: "GLM (智谱)",
|
||||
Provider.MINIMAX: "MiniMax",
|
||||
Provider.MOONSHOT: "Kimi (Moonshot)",
|
||||
Provider.GEMINI: "Gemini (Google)",
|
||||
}
|
||||
return names.get(self, self.value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""消息"""
|
||||
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
name: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_call_id: str | None = None # 用于 tool 结果消息
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""工具调用"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM 响应"""
|
||||
|
||||
content: str
|
||||
model: str
|
||||
provider: Provider
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamChunk:
|
||||
"""流式响应块"""
|
||||
|
||||
content: str
|
||||
is_final: bool = False
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
usage: dict[str, int] | None = None
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM 客户端基类"""
|
||||
|
||||
provider: Provider
|
||||
default_model: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
proxy: str | None = None,
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
"""初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: 自定义 API 地址
|
||||
proxy: 代理地址(境外服务使用)
|
||||
timeout: 请求超时(秒)
|
||||
max_retries: 最大重试次数
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.proxy = proxy
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[Message],
|
||||
model: str | None = None,
|
||||
system: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""发送聊天请求
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
system: 系统提示词
|
||||
max_tokens: 最大输出 token
|
||||
temperature: 温度参数
|
||||
tools: 工具定义列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM 响应
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[Message],
|
||||
model: str | None = None,
|
||||
system: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamChunk]:
|
||||
"""流式聊天
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
system: 系统提示词
|
||||
max_tokens: 最大输出 token
|
||||
temperature: 温度参数
|
||||
tools: 工具定义列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
流式响应块
|
||||
"""
|
||||
pass
|
||||
|
||||
def _convert_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""转换消息格式为 API 格式(子类可覆盖)"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
item: dict[str, Any] = {
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
}
|
||||
if msg.name:
|
||||
item["name"] = msg.name
|
||||
if msg.tool_call_id:
|
||||
item["tool_call_id"] = msg.tool_call_id
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
def _convert_tools(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
|
||||
"""转换工具定义格式(子类可覆盖)"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭客户端"""
|
||||
pass
|
||||
|
||||
def get_available_models(self) -> list[str]:
|
||||
"""获取可用模型列表"""
|
||||
return [self.default_model]
|
||||
Reference in New Issue
Block a user