feat: add project rules, environment configuration examples, and development docs
19  src/minenasai/llm/clients/__init__.py  Normal file
@@ -0,0 +1,19 @@
"""LLM client implementations"""

from minenasai.llm.clients.anthropic import AnthropicClient
from minenasai.llm.clients.openai_compat import OpenAICompatClient
from minenasai.llm.clients.deepseek import DeepSeekClient
from minenasai.llm.clients.zhipu import ZhipuClient
from minenasai.llm.clients.minimax import MiniMaxClient
from minenasai.llm.clients.moonshot import MoonshotClient
from minenasai.llm.clients.gemini import GeminiClient

__all__ = [
    "AnthropicClient",
    "OpenAICompatClient",
    "DeepSeekClient",
    "ZhipuClient",
    "MiniMaxClient",
    "MoonshotClient",
    "GeminiClient",
]
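For orientation, a minimal usage sketch of the exports above. It is not part of this commit and assumes a Message(role, content) model in minenasai.llm.base, which this diff does not show; the client constructor and chat() signature match the files below.

# Hypothetical usage sketch -- not part of this commit.
# Assumes Message(role, content) exists in minenasai.llm.base, as the clients below use it.
import asyncio

from minenasai.llm.base import Message
from minenasai.llm.clients import DeepSeekClient


async def main() -> None:
    client = DeepSeekClient(api_key="sk-...")  # placeholder key
    try:
        response = await client.chat(
            messages=[Message(role="user", content="Hello")],
            system="You are a helpful assistant.",
        )
        print(response.content)
    finally:
        await client.close()  # releases the underlying httpx.AsyncClient


asyncio.run(main())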
232  src/minenasai/llm/clients/anthropic.py  Normal file
@@ -0,0 +1,232 @@
"""Anthropic Claude client"""

from __future__ import annotations

import json
from typing import Any, AsyncIterator

import httpx

from minenasai.core import get_logger
from minenasai.llm.base import (
    BaseLLMClient,
    LLMResponse,
    Message,
    Provider,
    StreamChunk,
    ToolCall,
    ToolDefinition,
)

logger = get_logger(__name__)


class AnthropicClient(BaseLLMClient):
    """Anthropic Claude client"""

    provider = Provider.ANTHROPIC
    default_model = "claude-sonnet-4-20250514"

    MODELS = [
        "claude-sonnet-4-20250514",
        "claude-opus-4-20250514",
        "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229",
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(api_key, base_url, proxy, timeout, max_retries)
        self.base_url = base_url or "https://api.anthropic.com"
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get (or lazily create) the HTTP client"""
        if self._client is None:
            transport = None
            if self.proxy:
                transport = httpx.AsyncHTTPTransport(proxy=self.proxy)

            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                transport=transport,
                headers={
                    "x-api-key": self.api_key,
                    "anthropic-version": "2023-06-01",
                    "content-type": "application/json",
                },
            )
        return self._client

    async def chat(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat request"""
        client = await self._get_client()
        model = model or self.default_model

        payload: dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": self._convert_messages_anthropic(messages),
        }

        if system:
            payload["system"] = system

        if tools:
            payload["tools"] = self._convert_tools_anthropic(tools)

        try:
            response = await client.post("/v1/messages", json=payload)
            response.raise_for_status()
            data = response.json()

            # Parse the response
            content_parts = []
            tool_calls = []

            for block in data.get("content", []):
                if block["type"] == "text":
                    content_parts.append(block["text"])
                elif block["type"] == "tool_use":
                    tool_calls.append(
                        ToolCall(
                            id=block["id"],
                            name=block["name"],
                            arguments=block.get("input", {}),
                        )
                    )

            return LLMResponse(
                content="\n".join(content_parts),
                model=model,
                provider=self.provider,
                tool_calls=tool_calls if tool_calls else None,
                finish_reason=data.get("stop_reason", "stop"),
                usage={
                    "input_tokens": data.get("usage", {}).get("input_tokens", 0),
                    "output_tokens": data.get("usage", {}).get("output_tokens", 0),
                },
                raw_response=data,
            )

        except httpx.HTTPStatusError as e:
            logger.error("Anthropic API error", status=e.response.status_code, body=e.response.text)
            raise
        except Exception as e:
            logger.error("Anthropic request failed", error=str(e))
            raise

    async def chat_stream(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Streaming chat"""
        client = await self._get_client()
        model = model or self.default_model

        payload: dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": self._convert_messages_anthropic(messages),
            "stream": True,
        }

        if system:
            payload["system"] = system

        if tools:
            payload["tools"] = self._convert_tools_anthropic(tools)

        try:
            async with client.stream("POST", "/v1/messages", json=payload) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data = json.loads(line[6:])
                    event_type = data.get("type")

                    if event_type == "content_block_delta":
                        delta = data.get("delta", {})
                        if delta.get("type") == "text_delta":
                            yield StreamChunk(content=delta.get("text", ""))

                    elif event_type == "message_stop":
                        yield StreamChunk(content="", is_final=True)

        except Exception as e:
            logger.error("Anthropic streaming request failed", error=str(e))
            raise

    def _convert_messages_anthropic(self, messages: list[Message]) -> list[dict[str, Any]]:
        """Convert messages to the Anthropic format"""
        result = []
        for msg in messages:
            if msg.role == "system":
                continue  # the system prompt is passed separately

            item: dict[str, Any] = {"role": msg.role}

            # Handle tool-call results
            if msg.tool_call_id:
                item["role"] = "user"
                item["content"] = [
                    {
                        "type": "tool_result",
                        "tool_use_id": msg.tool_call_id,
                        "content": msg.content,
                    }
                ]
            else:
                item["content"] = msg.content

            result.append(item)

        return result

    def _convert_tools_anthropic(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
        """Convert tool definitions to the Anthropic format"""
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.parameters,
            }
            for tool in tools
        ]

    async def close(self) -> None:
        """Close the client"""
        if self._client:
            await self._client.aclose()
            self._client = None

    def get_available_models(self) -> list[str]:
        return self.MODELS
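A minimal streaming sketch against AnthropicClient.chat_stream above; not part of this commit, and it assumes the same hypothetical Message(role, content) model. Each StreamChunk carries a content delta, with is_final marking the end of the stream.

# Hypothetical streaming sketch -- not part of this commit.
# Assumes Message(role, content) in minenasai.llm.base.
import asyncio

from minenasai.llm.base import Message
from minenasai.llm.clients import AnthropicClient


async def stream_demo() -> None:
    client = AnthropicClient(api_key="sk-ant-...")  # placeholder key
    try:
        async for chunk in client.chat_stream(
            messages=[Message(role="user", content="Write a haiku about the sea.")],
        ):
            if chunk.is_final:
                break
            print(chunk.content, end="", flush=True)
    finally:
        await client.close()


asyncio.run(stream_demo())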
45  src/minenasai/llm/clients/deepseek.py  Normal file
@@ -0,0 +1,45 @@
"""DeepSeek client

DeepSeek exposes an OpenAI-compatible API.
"""

from __future__ import annotations

from minenasai.llm.base import Provider
from minenasai.llm.clients.openai_compat import OpenAICompatClient


class DeepSeekClient(OpenAICompatClient):
    """DeepSeek client

    The DeepSeek API is compatible with the OpenAI interface format.
    Official docs: https://platform.deepseek.com/api-docs
    """

    provider = Provider.DEEPSEEK
    default_model = "deepseek-chat"

    MODELS = [
        "deepseek-chat",      # DeepSeek-V3
        "deepseek-reasoner",  # DeepSeek-R1
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(
            api_key=api_key,
            base_url=base_url or "https://api.deepseek.com",
            proxy=proxy,  # China-hosted service, usually no proxy needed
            timeout=timeout,
            max_retries=max_retries,
            provider=Provider.DEEPSEEK,
        )

    def get_available_models(self) -> list[str]:
        return self.MODELS
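The file above also illustrates the pattern the other OpenAI-compatible providers in this commit follow: subclass OpenAICompatClient, pin the provider enum and base_url, and list the models. A minimal sketch of adding a further provider the same way; everything provider-specific below is made up for illustration.

# Illustrative only -- "api.example.com" is a placeholder and this provider does not exist;
# a real client would add its own Provider enum member instead of reusing Provider.OPENAI.
from __future__ import annotations

from minenasai.llm.base import Provider
from minenasai.llm.clients.openai_compat import OpenAICompatClient


class ExampleClient(OpenAICompatClient):
    """Example provider following the DeepSeek/Zhipu/Moonshot pattern."""

    provider = Provider.OPENAI
    default_model = "example-chat"

    MODELS = ["example-chat"]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(
            api_key=api_key,
            base_url=base_url or "https://api.example.com/v1",  # provider-specific endpoint
            proxy=proxy,
            timeout=timeout,
            max_retries=max_retries,
            provider=Provider.OPENAI,
        )

    def get_available_models(self) -> list[str]:
        return self.MODELS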
246  src/minenasai/llm/clients/gemini.py  Normal file
@@ -0,0 +1,246 @@
"""Google Gemini client

The Gemini API uses its own request/response format.
"""

from __future__ import annotations

import json
from typing import Any, AsyncIterator

import httpx

from minenasai.core import get_logger
from minenasai.llm.base import (
    BaseLLMClient,
    LLMResponse,
    Message,
    Provider,
    StreamChunk,
    ToolCall,
    ToolDefinition,
)

logger = get_logger(__name__)


class GeminiClient(BaseLLMClient):
    """Google Gemini client

    Official docs: https://ai.google.dev/docs
    """

    provider = Provider.GEMINI
    default_model = "gemini-2.0-flash"

    MODELS = [
        "gemini-2.0-flash",
        "gemini-2.0-flash-thinking",
        "gemini-1.5-pro",
        "gemini-1.5-flash",
        "gemini-1.5-flash-8b",
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(api_key, base_url, proxy, timeout, max_retries)
        self.base_url = base_url or "https://generativelanguage.googleapis.com/v1beta"
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get (or lazily create) the HTTP client"""
        if self._client is None:
            transport = None
            if self.proxy:
                transport = httpx.AsyncHTTPTransport(proxy=self.proxy)

            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                transport=transport,
                headers={"Content-Type": "application/json"},
            )
        return self._client

    async def chat(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat request"""
        client = await self._get_client()
        model = model or self.default_model

        # Build the request in Gemini format
        contents = self._convert_messages_gemini(messages)

        payload: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": max_tokens,
                "temperature": temperature,
            },
        }

        if system:
            payload["systemInstruction"] = {"parts": [{"text": system}]}

        if tools:
            payload["tools"] = self._convert_tools_gemini(tools)

        url = f"/models/{model}:generateContent?key={self.api_key}"

        try:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()

            # Parse the response
            candidates = data.get("candidates", [])
            if not candidates:
                return LLMResponse(
                    content="",
                    model=model,
                    provider=self.provider,
                    finish_reason="error",
                )

            candidate = candidates[0]
            content_parts = candidate.get("content", {}).get("parts", [])

            text_parts = []
            tool_calls = []

            for part in content_parts:
                if "text" in part:
                    text_parts.append(part["text"])
                elif "functionCall" in part:
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=fc.get("name", ""),  # Gemini does not return a separate call ID
                            name=fc["name"],
                            arguments=fc.get("args", {}),
                        )
                    )

            usage_meta = data.get("usageMetadata", {})

            return LLMResponse(
                content="\n".join(text_parts),
                model=model,
                provider=self.provider,
                tool_calls=tool_calls if tool_calls else None,
                finish_reason=candidate.get("finishReason", "STOP"),
                usage={
                    "input_tokens": usage_meta.get("promptTokenCount", 0),
                    "output_tokens": usage_meta.get("candidatesTokenCount", 0),
                },
                raw_response=data,
            )

        except httpx.HTTPStatusError as e:
            logger.error("Gemini API error", status=e.response.status_code, body=e.response.text)
            raise
        except Exception as e:
            logger.error("Gemini request failed", error=str(e))
            raise

    async def chat_stream(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Streaming chat"""
        client = await self._get_client()
        model = model or self.default_model

        contents = self._convert_messages_gemini(messages)

        payload: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": max_tokens,
                "temperature": temperature,
            },
        }

        if system:
            payload["systemInstruction"] = {"parts": [{"text": system}]}

        url = f"/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"

        try:
            async with client.stream("POST", url, json=payload) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data = json.loads(line[6:])
                    candidates = data.get("candidates", [])

                    if candidates:
                        parts = candidates[0].get("content", {}).get("parts", [])
                        for part in parts:
                            if "text" in part:
                                yield StreamChunk(content=part["text"])

                yield StreamChunk(content="", is_final=True)

        except Exception as e:
            logger.error("Gemini streaming request failed", error=str(e))
            raise

    def _convert_messages_gemini(self, messages: list[Message]) -> list[dict[str, Any]]:
        """Convert messages to the Gemini format"""
        result = []
        for msg in messages:
            if msg.role == "system":
                continue  # the system prompt is sent separately as systemInstruction

            role = "user" if msg.role == "user" else "model"
            result.append({
                "role": role,
                "parts": [{"text": msg.content}],
            })

        return result

    def _convert_tools_gemini(self, tools: list[ToolDefinition]) -> list[dict[str, Any]]:
        """Convert tool definitions to the Gemini format"""
        function_declarations = []
        for tool in tools:
            function_declarations.append({
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            })

        return [{"functionDeclarations": function_declarations}]

    async def close(self) -> None:
        """Close the client"""
        if self._client:
            await self._client.aclose()
            self._client = None

    def get_available_models(self) -> list[str]:
        return self.MODELS
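For clarity on _convert_messages_gemini above: system messages are skipped (they travel as systemInstruction), user stays user, and any other role is mapped to model. A small illustration, assuming a Message(role, content) model outside this diff:

# Illustration of the role mapping above -- assumes Message(role, content).
from minenasai.llm.base import Message

messages = [
    Message(role="system", content="Be terse."),   # skipped; sent as systemInstruction
    Message(role="user", content="Hi"),            # -> {"role": "user", "parts": [{"text": "Hi"}]}
    Message(role="assistant", content="Hello!"),   # -> {"role": "model", "parts": [{"text": "Hello!"}]}
]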
78  src/minenasai/llm/clients/minimax.py  Normal file
@@ -0,0 +1,78 @@
"""MiniMax client

MiniMax exposes an OpenAI-compatible API.
"""

from __future__ import annotations

from typing import Any

import httpx

from minenasai.llm.base import Provider
from minenasai.llm.clients.openai_compat import OpenAICompatClient


class MiniMaxClient(OpenAICompatClient):
    """MiniMax client

    The MiniMax API is compatible with the OpenAI interface format.
    Official docs: https://platform.minimaxi.com/document
    """

    provider = Provider.MINIMAX
    default_model = "MiniMax-Text-01"

    MODELS = [
        "MiniMax-Text-01",  # latest flagship
        "abab6.5s-chat",    # ABAB 6.5s
        "abab6.5g-chat",    # ABAB 6.5g
        "abab6.5t-chat",    # ABAB 6.5t (long context)
        "abab5.5-chat",     # ABAB 5.5
    ]

    def __init__(
        self,
        api_key: str,
        group_id: str | None = None,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(
            api_key=api_key,
            base_url=base_url or "https://api.minimax.chat/v1",
            proxy=proxy,
            timeout=timeout,
            max_retries=max_retries,
            provider=Provider.MINIMAX,
        )
        self.group_id = group_id

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the HTTP client (MiniMax uses different auth headers)"""
        if self._client is None:
            transport = None
            if self.proxy:
                transport = httpx.AsyncHTTPTransport(proxy=self.proxy)

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            # MiniMax may require a GroupId header
            if self.group_id:
                headers["GroupId"] = self.group_id

            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                transport=transport,
                headers=headers,
            )
        return self._client

    def get_available_models(self) -> list[str]:
        return self.MODELS
46  src/minenasai/llm/clients/moonshot.py  Normal file
@@ -0,0 +1,46 @@
"""Moonshot Kimi client

Moonshot exposes an OpenAI-compatible API.
"""

from __future__ import annotations

from minenasai.llm.base import Provider
from minenasai.llm.clients.openai_compat import OpenAICompatClient


class MoonshotClient(OpenAICompatClient):
    """Moonshot Kimi client

    The Moonshot API is compatible with the OpenAI interface format.
    Official docs: https://platform.moonshot.cn/docs
    """

    provider = Provider.MOONSHOT
    default_model = "moonshot-v1-8k"

    MODELS = [
        "moonshot-v1-8k",    # 8K context
        "moonshot-v1-32k",   # 32K context
        "moonshot-v1-128k",  # 128K context
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(
            api_key=api_key,
            base_url=base_url or "https://api.moonshot.cn/v1",
            proxy=proxy,
            timeout=timeout,
            max_retries=max_retries,
            provider=Provider.MOONSHOT,
        )

    def get_available_models(self) -> list[str]:
        return self.MODELS
209  src/minenasai/llm/clients/openai_compat.py  Normal file
@@ -0,0 +1,209 @@
"""OpenAI-compatible client

Supports OpenAI and any service that follows the OpenAI API format.
"""

from __future__ import annotations

import json
from typing import Any, AsyncIterator

import httpx

from minenasai.core import get_logger
from minenasai.llm.base import (
    BaseLLMClient,
    LLMResponse,
    Message,
    Provider,
    StreamChunk,
    ToolCall,
    ToolDefinition,
)

logger = get_logger(__name__)


class OpenAICompatClient(BaseLLMClient):
    """OpenAI-compatible client

    Supports:
    - the official OpenAI API
    - Azure OpenAI
    - any service compatible with the OpenAI interface
    """

    provider = Provider.OPENAI
    default_model = "gpt-4o"

    MODELS = [
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "o1",
        "o1-mini",
        "o3-mini",
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
        provider: Provider = Provider.OPENAI,
    ) -> None:
        super().__init__(api_key, base_url, proxy, timeout, max_retries)
        self.base_url = base_url or "https://api.openai.com/v1"
        self.provider = provider
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get (or lazily create) the HTTP client"""
        if self._client is None:
            transport = None
            if self.proxy:
                transport = httpx.AsyncHTTPTransport(proxy=self.proxy)

            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=self.timeout,
                transport=transport,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    async def chat(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat request"""
        client = await self._get_client()
        model = model or self.default_model

        # Build the message list
        api_messages = []
        if system:
            api_messages.append({"role": "system", "content": system})
        api_messages.extend(self._convert_messages(messages))

        payload: dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        try:
            response = await client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()

            choice = data["choices"][0]
            message = choice["message"]

            # Parse tool calls
            tool_calls = None
            if message.get("tool_calls"):
                tool_calls = [
                    ToolCall(
                        id=tc["id"],
                        name=tc["function"]["name"],
                        arguments=json.loads(tc["function"]["arguments"]),
                    )
                    for tc in message["tool_calls"]
                ]

            return LLMResponse(
                content=message.get("content", ""),
                model=model,
                provider=self.provider,
                tool_calls=tool_calls,
                finish_reason=choice.get("finish_reason", "stop"),
                usage=data.get("usage", {}),
                raw_response=data,
            )

        except httpx.HTTPStatusError as e:
            logger.error("OpenAI API error", status=e.response.status_code, body=e.response.text)
            raise
        except Exception as e:
            logger.error("OpenAI request failed", error=str(e))
            raise

    async def chat_stream(
        self,
        messages: list[Message],
        model: str | None = None,
        system: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        tools: list[ToolDefinition] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamChunk]:
        """Streaming chat"""
        client = await self._get_client()
        model = model or self.default_model

        api_messages = []
        if system:
            api_messages.append({"role": "system", "content": system})
        api_messages.extend(self._convert_messages(messages))

        payload: dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        try:
            async with client.stream("POST", "/chat/completions", json=payload) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data_str = line[6:]
                    if data_str == "[DONE]":
                        yield StreamChunk(content="", is_final=True)
                        break

                    data = json.loads(data_str)
                    delta = data["choices"][0].get("delta", {})

                    if "content" in delta:
                        yield StreamChunk(content=delta["content"])

        except Exception as e:
            logger.error("OpenAI streaming request failed", error=str(e))
            raise

    async def close(self) -> None:
        """Close the client"""
        if self._client:
            await self._client.aclose()
            self._client = None

    def get_available_models(self) -> list[str]:
        return self.MODELS
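A sketch of a tool-calling round trip with the client above; not part of this commit. ToolDefinition(name, description, parameters) is inferred from how the converters in this commit read those fields, and Message(role, content) is assumed, since both live in minenasai.llm.base outside this diff.

# Hypothetical tool-calling sketch -- not part of this commit.
# ToolDefinition(name, description, parameters) and Message(role, content) are assumed shapes.
import asyncio

from minenasai.llm.base import Message, ToolDefinition
from minenasai.llm.clients import OpenAICompatClient

weather_tool = ToolDefinition(
    name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)


async def tool_demo() -> None:
    client = OpenAICompatClient(api_key="sk-...")  # placeholder key
    try:
        response = await client.chat(
            messages=[Message(role="user", content="What's the weather in Osaka?")],
            tools=[weather_tool],
        )
        if response.tool_calls:
            for call in response.tool_calls:
                print(call.name, call.arguments)  # e.g. get_weather {'city': 'Osaka'}
        else:
            print(response.content)
    finally:
        await client.close()


asyncio.run(tool_demo())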
51  src/minenasai/llm/clients/zhipu.py  Normal file
@@ -0,0 +1,51 @@
"""Zhipu GLM client

Zhipu AI exposes an OpenAI-compatible API.
"""

from __future__ import annotations

from minenasai.llm.base import Provider
from minenasai.llm.clients.openai_compat import OpenAICompatClient


class ZhipuClient(OpenAICompatClient):
    """Zhipu GLM client

    The Zhipu API is compatible with the OpenAI interface format.
    Official docs: https://open.bigmodel.cn/dev/api
    """

    provider = Provider.ZHIPU
    default_model = "glm-4-plus"

    MODELS = [
        "glm-4-plus",   # latest flagship
        "glm-4-0520",   # GLM-4
        "glm-4-air",    # cost-effective
        "glm-4-airx",   # extra-fast variant
        "glm-4-long",   # long context
        "glm-4-flash",  # free tier
        "glm-4v-plus",  # multimodal
        "codegeex-4",   # code model
    ]

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        proxy: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        super().__init__(
            api_key=api_key,
            base_url=base_url or "https://open.bigmodel.cn/api/paas/v4",
            proxy=proxy,
            timeout=timeout,
            max_retries=max_retries,
            provider=Provider.ZHIPU,
        )

    def get_available_models(self) -> list[str]:
        return self.MODELS