"""Tests for the LLM module."""
from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from minenasai.llm.base import (
    LLMResponse,
    Message,
    Provider,
    StreamChunk,
    ToolCall,
    ToolDefinition,
)
class TestProvider:
    """Tests for the ``Provider`` enum."""

    def test_is_overseas(self):
        """Overseas providers report ``is_overseas`` True; domestic ones report False."""
        overseas = (Provider.ANTHROPIC, Provider.OPENAI, Provider.GEMINI)
        domestic = (
            Provider.DEEPSEEK,
            Provider.ZHIPU,
            Provider.MINIMAX,
            Provider.MOONSHOT,
        )

        for provider in overseas:
            assert provider.is_overseas is True
        for provider in domestic:
            assert provider.is_overseas is False

    def test_display_name(self):
        """Each provider's display name contains its expected brand fragment."""
        expected_fragments = (
            (Provider.ANTHROPIC, "Claude"),
            (Provider.OPENAI, "GPT"),
            (Provider.DEEPSEEK, "DeepSeek"),
            (Provider.ZHIPU, "GLM"),
            (Provider.MOONSHOT, "Kimi"),
        )
        for provider, fragment in expected_fragments:
            assert fragment in provider.display_name
class TestMessage:
    """Tests for the ``Message`` model."""

    def test_basic_message(self):
        """A plain message keeps its role and content and carries no tool calls."""
        message = Message(role="user", content="Hello")

        assert message.role == "user"
        assert message.content == "Hello"
        assert message.tool_calls is None

    def test_message_with_tool_call(self):
        """An assistant message can carry a list of tool calls."""
        call = ToolCall(id="tc_123", name="read_file", arguments={"path": "/test.txt"})
        message = Message(
            role="assistant",
            content="Let me read that file.",
            tool_calls=[call],
        )

        calls = message.tool_calls
        assert len(calls) == 1
        assert calls[0].name == "read_file"
class TestToolDefinition:
    """Tests for the ``ToolDefinition`` model."""

    def test_tool_definition(self):
        """The tool name and its schema-style parameter dict are stored as given."""
        schema = {
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        }
        tool = ToolDefinition(
            name="read_file",
            description="Read a file",
            parameters=schema,
        )

        assert tool.name == "read_file"
        assert "path" in tool.parameters["properties"]
class TestClientImports:
    """Smoke tests for the concrete client classes in ``minenasai.llm.clients``."""

    def test_import_all_clients(self):
        """Every client class is importable and declares the ``Provider`` it serves."""
        from minenasai.llm.clients import (
            AnthropicClient,
            DeepSeekClient,
            GeminiClient,
            MiniMaxClient,
            MoonshotClient,
            OpenAICompatClient,
            ZhipuClient,
        )

        expected = (
            (AnthropicClient, Provider.ANTHROPIC),
            (OpenAICompatClient, Provider.OPENAI),
            (DeepSeekClient, Provider.DEEPSEEK),
            (ZhipuClient, Provider.ZHIPU),
            (MiniMaxClient, Provider.MINIMAX),
            (MoonshotClient, Provider.MOONSHOT),
            (GeminiClient, Provider.GEMINI),
        )
        for client_cls, provider in expected:
            assert client_cls.provider == provider

    def test_client_models(self):
        """Each client's ``MODELS`` collection contains a known model id."""
        from minenasai.llm.clients import (
            AnthropicClient,
            DeepSeekClient,
            ZhipuClient,
        )

        expected = (
            (AnthropicClient, "claude-sonnet-4-20250514"),
            (DeepSeekClient, "deepseek-chat"),
            (ZhipuClient, "glm-4-plus"),
        )
        for client_cls, model_id in expected:
            assert model_id in client_cls.MODELS
class TestLLMManager:
    """Tests for ``LLMManager`` and its module-level accessor."""

    def test_import_manager(self):
        """``get_llm_manager()`` returns an ``LLMManager`` instance."""
        from minenasai.llm import LLMManager, get_llm_manager

        assert isinstance(get_llm_manager(), LLMManager)

    def test_no_api_keys(self):
        """Initialization succeeds and available providers are reported as a list.

        With no API keys configured there should be no available providers.
        The result still depends on the environment, so only the return type
        is asserted.
        """
        from minenasai.llm import LLMManager

        manager = LLMManager()
        manager.initialize()

        # NOTE(review): the provider list depends on which API keys exist in
        # the environment, so it may be non-empty on a developer machine.
        # Isolating the environment (e.g. monkeypatch.delenv) would allow
        # asserting an empty list. Confirm the env var names the manager reads first.
        available = manager.get_available_providers()
        assert isinstance(available, list)
class TestOpenAICompatClientMock:
    """Mock-based tests for ``OpenAICompatClient``; the HTTP layer is faked."""

    @pytest.fixture
    def client(self):
        """Create a client pointed at a dummy endpoint with a dummy API key."""
        from minenasai.llm.clients import OpenAICompatClient

        return OpenAICompatClient(api_key="test-key", base_url="https://api.test.com/v1")

    @staticmethod
    def _ok_response(payload):
        """Return a fake HTTP response with status 200 whose ``json()`` yields *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    @pytest.mark.asyncio
    async def test_chat_mock(self, client):
        """A plain chat completion payload is parsed into an ``LLMResponse``."""
        mock_response = self._ok_response(
            {
                "id": "chatcmpl-123",
                "object": "chat.completion",
                "created": 1677652288,
                "model": "gpt-4o",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello! How can I help you?",
                        },
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30,
                },
            }
        )

        with patch.object(client, "_get_client") as mock_get_client:
            mock_http_client = AsyncMock()
            mock_http_client.post = AsyncMock(return_value=mock_response)
            mock_get_client.return_value = mock_http_client

            messages = [Message(role="user", content="Hello")]
            response = await client.chat(messages)

        # A single successful (200) reply should need exactly one awaited request.
        mock_http_client.post.assert_awaited_once()
        assert isinstance(response, LLMResponse)
        assert response.content == "Hello! How can I help you?"
        assert response.model == "gpt-4o"
        assert response.finish_reason == "stop"
        assert response.usage["total_tokens"] == 30

    @pytest.mark.asyncio
    async def test_chat_with_tools_mock(self, client):
        """A ``tool_calls`` reply is decoded into ``ToolCall`` objects with parsed JSON arguments."""
        mock_response = self._ok_response(
            {
                "id": "chatcmpl-456",
                "model": "gpt-4o",
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [
                                {
                                    "id": "call_123",
                                    "type": "function",
                                    "function": {
                                        "name": "read_file",
                                        "arguments": '{"path": "/tmp/test.txt"}',
                                    },
                                }
                            ],
                        },
                        "finish_reason": "tool_calls",
                    }
                ],
                "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
            }
        )

        with patch.object(client, "_get_client") as mock_get_client:
            mock_http_client = AsyncMock()
            mock_http_client.post = AsyncMock(return_value=mock_response)
            mock_get_client.return_value = mock_http_client

            messages = [Message(role="user", content="Read the test file")]
            tools = [
                ToolDefinition(
                    name="read_file",
                    description="Read a file",
                    parameters={"type": "object", "properties": {"path": {"type": "string"}}},
                )
            ]
            response = await client.chat(messages, tools=tools)

        mock_http_client.post.assert_awaited_once()
        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "read_file"
        # The "arguments" JSON string must be decoded into a dict.
        assert response.tool_calls[0].arguments == {"path": "/tmp/test.txt"}

    @pytest.mark.asyncio
    async def test_close_client(self, client):
        """``close()`` awaits ``aclose()`` on the HTTP client and clears the reference."""
        mock_http_client = AsyncMock()
        client._client = mock_http_client

        await client.close()

        # assert_awaited_once (unlike assert_called_once) also fails when
        # aclose() is called but its coroutine is never awaited.
        mock_http_client.aclose.assert_awaited_once()
        assert client._client is None
class TestAnthropicClientMock:
    """Mock-based tests for ``AnthropicClient``; the HTTP layer is faked."""

    @pytest.fixture
    def client(self):
        """Build an Anthropic client with a dummy API key."""
        from minenasai.llm.clients import AnthropicClient

        return AnthropicClient(api_key="test-key")

    @pytest.mark.asyncio
    async def test_chat_mock(self, client):
        """A reply whose content is a list of text blocks is parsed into an ``LLMResponse``."""
        payload = {
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": "Hello from Claude!"}],
            "model": "claude-sonnet-4-20250514",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 15},
        }
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = payload

        fake_http = AsyncMock()
        fake_http.post = AsyncMock(return_value=fake_response)

        with patch.object(client, "_get_client", return_value=fake_http):
            reply = await client.chat([Message(role="user", content="Hello")])

        assert isinstance(reply, LLMResponse)
        assert reply.content == "Hello from Claude!"
        assert reply.provider == Provider.ANTHROPIC
class TestLLMResponse:
    """Tests for the ``LLMResponse`` model."""

    def test_response_basic(self):
        """Fields are stored as given and ``finish_reason`` defaults to "stop"."""
        resp = LLMResponse(content="Test response", model="gpt-4o", provider=Provider.OPENAI)

        assert resp.content == "Test response"
        assert resp.model == "gpt-4o"
        assert resp.provider == Provider.OPENAI
        assert resp.finish_reason == "stop"

    def test_response_with_tool_calls(self):
        """A response can carry several tool calls and an explicit finish reason."""
        first = ToolCall(id="tc_1", name="read_file", arguments={"path": "/test"})
        second = ToolCall(id="tc_2", name="list_dir", arguments={"path": "/"})

        resp = LLMResponse(
            content="",
            model="claude-sonnet-4-20250514",
            provider=Provider.ANTHROPIC,
            tool_calls=[first, second],
            finish_reason="tool_use",
        )

        assert len(resp.tool_calls) == 2
        assert resp.finish_reason == "tool_use"
class TestStreamChunk:
    """Tests for the ``StreamChunk`` model used in streaming responses."""

    def test_chunk_basic(self):
        """A chunk built from content alone is not marked final."""
        piece = StreamChunk(content="Hello")

        assert piece.content == "Hello"
        assert piece.is_final is False

    def test_chunk_final(self):
        """A final chunk can carry token usage information."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 20}
        piece = StreamChunk(content="", is_final=True, usage=token_usage)

        assert piece.is_final is True
        assert piece.usage is not None