feat: 更新模块文档,添加详细说明和使用示例
This commit is contained in:
@@ -4,8 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from minenasai.core import Settings, get_settings, load_config, reset_settings
|
||||
from minenasai.core.config import expand_path
|
||||
|
||||
|
||||
345
tests/test_database.py
Normal file
345
tests/test_database.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""数据库模块测试
|
||||
|
||||
测试 SQLite 异步数据库操作
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from minenasai.core.database import Database
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""数据库基本操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_close(self, tmp_path):
|
||||
"""测试数据库连接和关闭"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
|
||||
# 连接前访问 conn 应该报错
|
||||
with pytest.raises(RuntimeError, match="数据库未连接"):
|
||||
_ = db.conn
|
||||
|
||||
# 连接
|
||||
await db.connect()
|
||||
assert db.conn is not None
|
||||
|
||||
# 关闭
|
||||
await db.close()
|
||||
assert db._conn is None
|
||||
|
||||
|
||||
class TestAgentOperations:
|
||||
"""Agent 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent(self, db):
|
||||
"""测试创建 Agent"""
|
||||
agent = await db.create_agent(
|
||||
agent_id="test-agent-1",
|
||||
name="Test Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
)
|
||||
|
||||
assert agent["id"] == "test-agent-1"
|
||||
assert agent["name"] == "Test Agent"
|
||||
assert agent["workspace_path"] == "/tmp/workspace"
|
||||
assert agent["model"] == "claude-sonnet-4-20250514"
|
||||
assert agent["sandbox_mode"] == "workspace"
|
||||
assert "created_at" in agent
|
||||
assert "updated_at" in agent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_custom_model(self, db):
|
||||
"""测试使用自定义模型创建 Agent"""
|
||||
agent = await db.create_agent(
|
||||
agent_id="test-agent-2",
|
||||
name="Custom Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
model="gpt-4",
|
||||
sandbox_mode="strict",
|
||||
)
|
||||
|
||||
assert agent["model"] == "gpt-4"
|
||||
assert agent["sandbox_mode"] == "strict"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent(self, db):
|
||||
"""测试获取 Agent"""
|
||||
# 创建 Agent
|
||||
await db.create_agent(
|
||||
agent_id="test-agent-3",
|
||||
name="Get Test Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
)
|
||||
|
||||
# 获取存在的 Agent
|
||||
agent = await db.get_agent("test-agent-3")
|
||||
assert agent is not None
|
||||
assert agent["name"] == "Get Test Agent"
|
||||
|
||||
# 获取不存在的 Agent
|
||||
agent = await db.get_agent("nonexistent")
|
||||
assert agent is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents(self, db):
|
||||
"""测试列出所有 Agent"""
|
||||
# 初始为空
|
||||
agents = await db.list_agents()
|
||||
assert len(agents) == 0
|
||||
|
||||
# 创建多个 Agent
|
||||
await db.create_agent("agent-1", "Agent 1", "/tmp/1")
|
||||
await db.create_agent("agent-2", "Agent 2", "/tmp/2")
|
||||
await db.create_agent("agent-3", "Agent 3", "/tmp/3")
|
||||
|
||||
# 列出所有
|
||||
agents = await db.list_agents()
|
||||
assert len(agents) == 3
|
||||
|
||||
|
||||
class TestSessionOperations:
|
||||
"""Session 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库并添加测试 Agent"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, db):
|
||||
"""测试创建会话"""
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
peer_id="user-123",
|
||||
)
|
||||
|
||||
assert session["agent_id"] == "test-agent"
|
||||
assert session["channel"] == "websocket"
|
||||
assert session["peer_id"] == "user-123"
|
||||
assert session["status"] == "active"
|
||||
assert "session_key" in session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_with_metadata(self, db):
|
||||
"""测试创建带元数据的会话"""
|
||||
metadata = {"client": "web", "version": "1.0"}
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert session["metadata"] == metadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session(self, db):
|
||||
"""测试获取会话"""
|
||||
# 创建会话
|
||||
created = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
metadata={"test": "data"},
|
||||
)
|
||||
|
||||
# 获取会话
|
||||
session = await db.get_session(created["id"])
|
||||
assert session is not None
|
||||
assert session["id"] == created["id"]
|
||||
assert session["metadata"] == {"test": "data"}
|
||||
|
||||
# 获取不存在的会话
|
||||
session = await db.get_session("nonexistent-id")
|
||||
assert session is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_session_status(self, db):
|
||||
"""测试更新会话状态"""
|
||||
# 创建会话
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
)
|
||||
assert session["status"] == "active"
|
||||
|
||||
# 更新状态
|
||||
await db.update_session_status(session["id"], "closed")
|
||||
|
||||
# 验证
|
||||
updated = await db.get_session(session["id"])
|
||||
assert updated["status"] == "closed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_active_sessions(self, db):
|
||||
"""测试列出活跃会话"""
|
||||
# 创建多个会话
|
||||
session1 = await db.create_session("test-agent", "websocket")
|
||||
session2 = await db.create_session("test-agent", "feishu")
|
||||
await db.create_session("test-agent", "wework")
|
||||
|
||||
# 关闭一个会话
|
||||
await db.update_session_status(session2["id"], "closed")
|
||||
|
||||
# 列出所有活跃会话
|
||||
active = await db.list_active_sessions()
|
||||
assert len(active) == 2
|
||||
|
||||
# 按 agent_id 过滤
|
||||
active = await db.list_active_sessions("test-agent")
|
||||
assert len(active) == 2
|
||||
|
||||
|
||||
class TestMessageOperations:
|
||||
"""Message 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库并添加测试会话"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||
session = await db.create_session("test-agent", "websocket")
|
||||
db._test_session_id = session["id"]
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message(self, db):
|
||||
"""测试添加消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
message = await db.add_message(
|
||||
session_id=session_id,
|
||||
role="user",
|
||||
content="Hello, world!",
|
||||
)
|
||||
|
||||
assert message["session_id"] == session_id
|
||||
assert message["role"] == "user"
|
||||
assert message["content"] == "Hello, world!"
|
||||
assert message["tokens_used"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message_with_tool_calls(self, db):
|
||||
"""测试添加带工具调用的消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
tool_calls = [
|
||||
{"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
|
||||
]
|
||||
message = await db.add_message(
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=tool_calls,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
assert message["tool_calls"] == tool_calls
|
||||
assert message["tokens_used"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages(self, db):
|
||||
"""测试获取会话消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
# 添加多条消息
|
||||
await db.add_message(session_id, "user", "Message 1")
|
||||
await db.add_message(session_id, "assistant", "Response 1")
|
||||
await db.add_message(session_id, "user", "Message 2")
|
||||
await db.add_message(session_id, "assistant", "Response 2")
|
||||
|
||||
# 获取所有消息
|
||||
messages = await db.get_messages(session_id)
|
||||
assert len(messages) == 4
|
||||
# 验证角色交替
|
||||
roles = [m["role"] for m in messages]
|
||||
assert roles.count("user") == 2
|
||||
assert roles.count("assistant") == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_with_limit(self, db):
|
||||
"""测试获取消息数量限制"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
# 添加多条消息
|
||||
for i in range(10):
|
||||
await db.add_message(session_id, "user", f"Message {i}")
|
||||
|
||||
# 限制返回数量
|
||||
messages = await db.get_messages(session_id, limit=5)
|
||||
assert len(messages) == 5
|
||||
|
||||
|
||||
class TestAuditLog:
|
||||
"""审计日志测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_audit_log(self, db):
|
||||
"""测试添加审计日志"""
|
||||
# 添加日志不应该报错
|
||||
await db.add_audit_log(
|
||||
agent_id="test-agent",
|
||||
tool_name="read_file",
|
||||
danger_level="safe",
|
||||
params={"path": "/tmp/test.txt"},
|
||||
result="success",
|
||||
duration_ms=50,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_audit_log_minimal(self, db):
|
||||
"""测试添加最小审计日志"""
|
||||
await db.add_audit_log(
|
||||
agent_id=None,
|
||||
tool_name="python_eval",
|
||||
danger_level="low",
|
||||
)
|
||||
|
||||
|
||||
class TestGlobalDatabase:
|
||||
"""全局数据库实例测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_functions(self):
|
||||
"""测试导入全局函数"""
|
||||
from minenasai.core.database import close_database, get_database
|
||||
|
||||
assert callable(get_database)
|
||||
assert callable(close_database)
|
||||
@@ -142,3 +142,159 @@ class TestRouterEdgeCases:
|
||||
result = self.router.evaluate("查看 /tmp/test.txt 文件内容")
|
||||
|
||||
assert result["complexity"] in [TaskComplexity.SIMPLE, TaskComplexity.MEDIUM]
|
||||
|
||||
|
||||
class TestConnectionManager:
|
||||
"""WebSocket 连接管理器测试"""
|
||||
|
||||
def test_import_manager(self):
|
||||
"""测试导入连接管理器"""
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
assert manager.active_connections == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_disconnect(self):
|
||||
"""测试连接和断开"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Mock WebSocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.accept = AsyncMock()
|
||||
|
||||
# 连接
|
||||
await manager.connect(mock_ws, "client-1")
|
||||
assert "client-1" in manager.active_connections
|
||||
mock_ws.accept.assert_called_once()
|
||||
|
||||
# 断开
|
||||
manager.disconnect("client-1")
|
||||
assert "client-1" not in manager.active_connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_nonexistent(self):
|
||||
"""测试断开不存在的连接"""
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
# 不应该抛出异常
|
||||
manager.disconnect("nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message(self):
|
||||
"""测试发送消息"""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Mock WebSocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.accept = AsyncMock()
|
||||
mock_ws.send_json = AsyncMock()
|
||||
|
||||
# 连接
|
||||
await manager.connect(mock_ws, "client-1")
|
||||
|
||||
# 发送消息
|
||||
await manager.send_message("client-1", {"type": "test"})
|
||||
mock_ws.send_json.assert_called_once_with({"type": "test"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_to_nonexistent(self):
|
||||
"""测试发送消息给不存在的客户端"""
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
# 不应该抛出异常
|
||||
await manager.send_message("nonexistent", {"type": "test"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast(self):
|
||||
"""测试广播消息"""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from minenasai.gateway.server import ConnectionManager
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
# Mock 多个 WebSocket
|
||||
mock_ws1 = AsyncMock()
|
||||
mock_ws1.accept = AsyncMock()
|
||||
mock_ws1.send_json = AsyncMock()
|
||||
|
||||
mock_ws2 = AsyncMock()
|
||||
mock_ws2.accept = AsyncMock()
|
||||
mock_ws2.send_json = AsyncMock()
|
||||
|
||||
# 连接
|
||||
await manager.connect(mock_ws1, "client-1")
|
||||
await manager.connect(mock_ws2, "client-2")
|
||||
|
||||
# 广播
|
||||
await manager.broadcast({"type": "broadcast"})
|
||||
mock_ws1.send_json.assert_called_once_with({"type": "broadcast"})
|
||||
mock_ws2.send_json.assert_called_once_with({"type": "broadcast"})
|
||||
|
||||
|
||||
class TestGatewayServer:
|
||||
"""Gateway 服务器测试"""
|
||||
|
||||
def test_import_app(self):
|
||||
"""测试导入应用"""
|
||||
from minenasai.gateway.server import app
|
||||
|
||||
assert app is not None
|
||||
assert app.title == "MineNASAI Gateway"
|
||||
|
||||
def test_import_endpoints(self):
|
||||
"""测试导入端点函数"""
|
||||
from minenasai.gateway.server import list_agents, list_sessions, root
|
||||
|
||||
assert callable(root)
|
||||
assert callable(list_agents)
|
||||
assert callable(list_sessions)
|
||||
|
||||
|
||||
class TestMessageTypes:
|
||||
"""消息类型测试"""
|
||||
|
||||
def test_status_message(self):
|
||||
"""测试状态消息"""
|
||||
from minenasai.gateway.protocol import StatusMessage
|
||||
|
||||
msg = StatusMessage(status="thinking", message="处理中...")
|
||||
assert msg.type == MessageType.STATUS
|
||||
assert msg.status == "thinking"
|
||||
assert msg.message == "处理中..."
|
||||
|
||||
def test_response_message(self):
|
||||
"""测试响应消息"""
|
||||
from minenasai.gateway.protocol import ResponseMessage
|
||||
|
||||
msg = ResponseMessage(content="Hello!", in_reply_to="msg-123")
|
||||
assert msg.type == MessageType.RESPONSE
|
||||
assert msg.content == "Hello!"
|
||||
assert msg.in_reply_to == "msg-123"
|
||||
|
||||
def test_error_message(self):
|
||||
"""测试错误消息"""
|
||||
from minenasai.gateway.protocol import ErrorMessage
|
||||
|
||||
msg = ErrorMessage(message="Something went wrong", code="ERR_001")
|
||||
assert msg.type == MessageType.ERROR
|
||||
assert msg.message == "Something went wrong"
|
||||
assert msg.code == "ERR_001"
|
||||
|
||||
def test_pong_message(self):
|
||||
"""测试心跳响应消息"""
|
||||
from minenasai.gateway.protocol import PongMessage
|
||||
|
||||
msg = PongMessage()
|
||||
assert msg.type == MessageType.PONG
|
||||
|
||||
@@ -2,9 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from minenasai.llm.base import Message, Provider, ToolCall, ToolDefinition
|
||||
from minenasai.llm.base import (
|
||||
LLMResponse,
|
||||
Message,
|
||||
Provider,
|
||||
StreamChunk,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
|
||||
class TestProvider:
|
||||
@@ -133,3 +142,208 @@ class TestLLMManager:
|
||||
providers = manager.get_available_providers()
|
||||
# 可能为空,取决于环境变量
|
||||
assert isinstance(providers, list)
|
||||
|
||||
|
||||
class TestOpenAICompatClientMock:
|
||||
"""OpenAI 兼容客户端 Mock 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""创建测试客户端"""
|
||||
from minenasai.llm.clients import OpenAICompatClient
|
||||
|
||||
return OpenAICompatClient(api_key="test-key", base_url="https://api.test.com/v1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_mock(self, client):
|
||||
"""测试聊天功能(Mock)"""
|
||||
# Mock HTTP 响应
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-4o",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you?",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
}
|
||||
|
||||
with patch.object(client, "_get_client") as mock_get_client:
|
||||
mock_http_client = AsyncMock()
|
||||
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_get_client.return_value = mock_http_client
|
||||
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
response = await client.chat(messages)
|
||||
|
||||
assert isinstance(response, LLMResponse)
|
||||
assert response.content == "Hello! How can I help you?"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.finish_reason == "stop"
|
||||
assert response.usage["total_tokens"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tools_mock(self, client):
|
||||
"""测试带工具调用的聊天(Mock)"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"id": "chatcmpl-456",
|
||||
"model": "gpt-4o",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"arguments": '{"path": "/tmp/test.txt"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
|
||||
}
|
||||
|
||||
with patch.object(client, "_get_client") as mock_get_client:
|
||||
mock_http_client = AsyncMock()
|
||||
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_get_client.return_value = mock_http_client
|
||||
|
||||
messages = [Message(role="user", content="Read the test file")]
|
||||
tools = [
|
||||
ToolDefinition(
|
||||
name="read_file",
|
||||
description="Read a file",
|
||||
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
|
||||
)
|
||||
]
|
||||
response = await client.chat(messages, tools=tools)
|
||||
|
||||
assert response.tool_calls is not None
|
||||
assert len(response.tool_calls) == 1
|
||||
assert response.tool_calls[0].name == "read_file"
|
||||
assert response.tool_calls[0].arguments == {"path": "/tmp/test.txt"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_client(self, client):
|
||||
"""测试关闭客户端"""
|
||||
# 创建 mock 客户端
|
||||
mock_http_client = AsyncMock()
|
||||
client._client = mock_http_client
|
||||
|
||||
await client.close()
|
||||
|
||||
mock_http_client.aclose.assert_called_once()
|
||||
assert client._client is None
|
||||
|
||||
|
||||
class TestAnthropicClientMock:
|
||||
"""Anthropic 客户端 Mock 测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""创建测试客户端"""
|
||||
from minenasai.llm.clients import AnthropicClient
|
||||
|
||||
return AnthropicClient(api_key="test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_mock(self, client):
|
||||
"""测试聊天功能(Mock)"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello from Claude!"}],
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 15},
|
||||
}
|
||||
|
||||
with patch.object(client, "_get_client") as mock_get_client:
|
||||
mock_http_client = AsyncMock()
|
||||
mock_http_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_get_client.return_value = mock_http_client
|
||||
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
response = await client.chat(messages)
|
||||
|
||||
assert isinstance(response, LLMResponse)
|
||||
assert response.content == "Hello from Claude!"
|
||||
assert response.provider == Provider.ANTHROPIC
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
"""LLM 响应测试"""
|
||||
|
||||
def test_response_basic(self):
|
||||
"""测试基本响应"""
|
||||
response = LLMResponse(
|
||||
content="Test response",
|
||||
model="gpt-4o",
|
||||
provider=Provider.OPENAI,
|
||||
)
|
||||
assert response.content == "Test response"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.provider == Provider.OPENAI
|
||||
assert response.finish_reason == "stop"
|
||||
|
||||
def test_response_with_tool_calls(self):
|
||||
"""测试带工具调用的响应"""
|
||||
tool_calls = [
|
||||
ToolCall(id="tc_1", name="read_file", arguments={"path": "/test"}),
|
||||
ToolCall(id="tc_2", name="list_dir", arguments={"path": "/"}),
|
||||
]
|
||||
response = LLMResponse(
|
||||
content="",
|
||||
model="claude-sonnet-4-20250514",
|
||||
provider=Provider.ANTHROPIC,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_use",
|
||||
)
|
||||
assert len(response.tool_calls) == 2
|
||||
assert response.finish_reason == "tool_use"
|
||||
|
||||
|
||||
class TestStreamChunk:
|
||||
"""流式响应块测试"""
|
||||
|
||||
def test_chunk_basic(self):
|
||||
"""测试基本响应块"""
|
||||
chunk = StreamChunk(content="Hello")
|
||||
assert chunk.content == "Hello"
|
||||
assert chunk.is_final is False
|
||||
|
||||
def test_chunk_final(self):
|
||||
"""测试最终响应块"""
|
||||
chunk = StreamChunk(
|
||||
content="",
|
||||
is_final=True,
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 20},
|
||||
)
|
||||
assert chunk.is_final is True
|
||||
assert chunk.usage is not None
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestToolPermission:
|
||||
name="test_tool",
|
||||
danger_level=DangerLevel.SAFE,
|
||||
)
|
||||
|
||||
|
||||
assert perm.requires_confirmation is False
|
||||
assert perm.rate_limit is None
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestPermissionManager:
|
||||
def test_get_default_permission(self):
|
||||
"""测试获取默认权限"""
|
||||
perm = self.manager.get_permission("read_file")
|
||||
|
||||
|
||||
assert perm is not None
|
||||
assert perm.danger_level == DangerLevel.SAFE
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestPermissionManager:
|
||||
description="自定义工具",
|
||||
)
|
||||
self.manager.register_tool(perm)
|
||||
|
||||
|
||||
result = self.manager.get_permission("custom_tool")
|
||||
assert result is not None
|
||||
assert result.danger_level == DangerLevel.MEDIUM
|
||||
@@ -65,13 +65,13 @@ class TestPermissionManager:
|
||||
def test_check_permission_allowed(self):
|
||||
"""测试权限检查 - 允许"""
|
||||
allowed, reason = self.manager.check_permission("read_file")
|
||||
|
||||
|
||||
assert allowed is True
|
||||
|
||||
def test_check_permission_unknown_tool(self):
|
||||
"""测试权限检查 - 未知工具"""
|
||||
allowed, reason = self.manager.check_permission("unknown_tool")
|
||||
|
||||
|
||||
assert allowed is False
|
||||
assert "未知工具" in reason
|
||||
|
||||
@@ -83,12 +83,12 @@ class TestPermissionManager:
|
||||
denied_paths=["/etc/", "/root/"],
|
||||
)
|
||||
self.manager.register_tool(perm)
|
||||
|
||||
|
||||
allowed, reason = self.manager.check_permission(
|
||||
"restricted_read",
|
||||
params={"path": "/etc/passwd"},
|
||||
)
|
||||
|
||||
|
||||
assert allowed is False
|
||||
assert "禁止访问" in reason
|
||||
|
||||
@@ -96,7 +96,7 @@ class TestPermissionManager:
|
||||
"""测试确认要求 - 按等级"""
|
||||
# HIGH 级别需要确认
|
||||
assert self.manager.requires_confirmation("delete_file") is True
|
||||
|
||||
|
||||
# SAFE 级别不需要确认
|
||||
assert self.manager.requires_confirmation("read_file") is False
|
||||
|
||||
@@ -108,7 +108,7 @@ class TestPermissionManager:
|
||||
requires_confirmation=True,
|
||||
)
|
||||
self.manager.register_tool(perm)
|
||||
|
||||
|
||||
assert self.manager.requires_confirmation("explicit_confirm") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -119,9 +119,9 @@ class TestPermissionManager:
|
||||
tool_name="delete_file",
|
||||
params={"path": "/test.txt"},
|
||||
)
|
||||
|
||||
|
||||
assert request.status == ConfirmationStatus.PENDING
|
||||
|
||||
|
||||
# 批准
|
||||
self.manager.approve_confirmation("req-1")
|
||||
assert request.status == ConfirmationStatus.APPROVED
|
||||
@@ -134,7 +134,7 @@ class TestPermissionManager:
|
||||
tool_name="delete_file",
|
||||
params={"path": "/test.txt"},
|
||||
)
|
||||
|
||||
|
||||
self.manager.deny_confirmation("req-2")
|
||||
assert request.status == ConfirmationStatus.DENIED
|
||||
|
||||
@@ -146,7 +146,7 @@ class TestPermissionManager:
|
||||
tool_name="test",
|
||||
params={},
|
||||
)
|
||||
|
||||
|
||||
pending = self.manager.get_pending_confirmations()
|
||||
assert len(pending) >= 1
|
||||
|
||||
@@ -157,48 +157,48 @@ class TestToolRegistry:
|
||||
def test_import_registry(self):
|
||||
"""测试导入注册中心"""
|
||||
from minenasai.agent import ToolRegistry, get_tool_registry
|
||||
|
||||
|
||||
registry = get_tool_registry()
|
||||
assert isinstance(registry, ToolRegistry)
|
||||
|
||||
def test_register_builtin_tools(self):
|
||||
"""测试注册内置工具"""
|
||||
from minenasai.agent import get_tool_registry, register_builtin_tools
|
||||
|
||||
|
||||
registry = get_tool_registry()
|
||||
initial_count = len(registry.list_tools())
|
||||
|
||||
|
||||
register_builtin_tools()
|
||||
|
||||
|
||||
# 应该有更多工具
|
||||
new_count = len(registry.list_tools())
|
||||
assert new_count >= initial_count
|
||||
|
||||
def test_tool_decorator(self):
|
||||
"""测试工具装饰器"""
|
||||
from minenasai.agent import tool, get_tool_registry
|
||||
|
||||
from minenasai.agent import get_tool_registry, tool
|
||||
|
||||
@tool(name="decorated_tool", description="装饰器测试")
|
||||
async def decorated_tool(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
|
||||
registry = get_tool_registry()
|
||||
tool_obj = registry.get("decorated_tool")
|
||||
|
||||
|
||||
assert tool_obj is not None
|
||||
assert tool_obj.description == "装饰器测试"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool(self):
|
||||
"""测试执行工具"""
|
||||
from minenasai.agent import get_tool_registry, DangerLevel
|
||||
|
||||
from minenasai.agent import DangerLevel, get_tool_registry
|
||||
|
||||
registry = get_tool_registry()
|
||||
|
||||
|
||||
# 注册测试工具
|
||||
async def echo(message: str) -> dict:
|
||||
return {"echo": message}
|
||||
|
||||
|
||||
registry.register(
|
||||
name="echo",
|
||||
description="回显消息",
|
||||
@@ -210,18 +210,18 @@ class TestToolRegistry:
|
||||
},
|
||||
danger_level=DangerLevel.SAFE,
|
||||
)
|
||||
|
||||
|
||||
result = await registry.execute("echo", {"message": "hello"})
|
||||
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["result"]["echo"] == "hello"
|
||||
|
||||
def test_get_stats(self):
|
||||
"""测试获取统计"""
|
||||
from minenasai.agent import get_tool_registry
|
||||
|
||||
|
||||
registry = get_tool_registry()
|
||||
stats = registry.get_stats()
|
||||
|
||||
|
||||
assert "total_tools" in stats
|
||||
assert "categories" in stats
|
||||
|
||||
@@ -15,7 +15,7 @@ class TestCronParser:
|
||||
def test_parse_all_stars(self):
|
||||
"""测试全星号表达式"""
|
||||
result = CronParser.parse("* * * * *")
|
||||
|
||||
|
||||
assert len(result["minute"]) == 60
|
||||
assert len(result["hour"]) == 24
|
||||
assert len(result["day"]) == 31
|
||||
@@ -25,39 +25,39 @@ class TestCronParser:
|
||||
def test_parse_specific_values(self):
|
||||
"""测试具体值"""
|
||||
result = CronParser.parse("30 8 * * *")
|
||||
|
||||
|
||||
assert result["minute"] == {30}
|
||||
assert result["hour"] == {8}
|
||||
|
||||
def test_parse_range(self):
|
||||
"""测试范围"""
|
||||
result = CronParser.parse("0 9-17 * * *")
|
||||
|
||||
|
||||
assert result["hour"] == {9, 10, 11, 12, 13, 14, 15, 16, 17}
|
||||
|
||||
def test_parse_step(self):
|
||||
"""测试步进"""
|
||||
result = CronParser.parse("*/15 * * * *")
|
||||
|
||||
|
||||
assert result["minute"] == {0, 15, 30, 45}
|
||||
|
||||
def test_parse_list(self):
|
||||
"""测试列表"""
|
||||
result = CronParser.parse("0 8,12,18 * * *")
|
||||
|
||||
|
||||
assert result["hour"] == {8, 12, 18}
|
||||
|
||||
def test_parse_preset_daily(self):
|
||||
"""测试预定义表达式 @daily"""
|
||||
result = CronParser.parse("@daily")
|
||||
|
||||
|
||||
assert result["minute"] == {0}
|
||||
assert result["hour"] == {0}
|
||||
|
||||
def test_parse_preset_hourly(self):
|
||||
"""测试预定义表达式 @hourly"""
|
||||
result = CronParser.parse("@hourly")
|
||||
|
||||
|
||||
assert result["minute"] == {0}
|
||||
assert len(result["hour"]) == 24
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestCronParser:
|
||||
"0 * * * *",
|
||||
after=datetime(2026, 1, 1, 10, 30)
|
||||
)
|
||||
|
||||
|
||||
assert next_run.minute == 0
|
||||
assert next_run.hour == 11
|
||||
|
||||
@@ -89,7 +89,7 @@ class TestCronJob:
|
||||
schedule="*/5 * * * *",
|
||||
task="测试",
|
||||
)
|
||||
|
||||
|
||||
assert job.id == "test-job"
|
||||
assert job.enabled is True
|
||||
assert job.last_status == JobStatus.PENDING
|
||||
@@ -113,7 +113,7 @@ class TestCronScheduler:
|
||||
schedule="*/5 * * * *",
|
||||
callback=task,
|
||||
)
|
||||
|
||||
|
||||
assert job.id == "test-1"
|
||||
assert job.next_run is not None
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestCronScheduler:
|
||||
pass
|
||||
|
||||
self.scheduler.add_job("test-1", "测试", "* * * * *", task)
|
||||
|
||||
|
||||
assert self.scheduler.remove_job("test-1") is True
|
||||
assert self.scheduler.get_job("test-1") is None
|
||||
|
||||
@@ -133,12 +133,12 @@ class TestCronScheduler:
|
||||
pass
|
||||
|
||||
self.scheduler.add_job("test-1", "测试", "* * * * *", task)
|
||||
|
||||
|
||||
assert self.scheduler.disable_job("test-1") is True
|
||||
job = self.scheduler.get_job("test-1")
|
||||
assert job.enabled is False
|
||||
assert job.last_status == JobStatus.DISABLED
|
||||
|
||||
|
||||
assert self.scheduler.enable_job("test-1") is True
|
||||
assert job.enabled is True
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestCronScheduler:
|
||||
|
||||
self.scheduler.add_job("test-1", "任务1", "* * * * *", task)
|
||||
self.scheduler.add_job("test-2", "任务2", "*/5 * * * *", task)
|
||||
|
||||
|
||||
jobs = self.scheduler.list_jobs()
|
||||
assert len(jobs) == 2
|
||||
|
||||
@@ -159,7 +159,7 @@ class TestCronScheduler:
|
||||
pass
|
||||
|
||||
self.scheduler.add_job("test-1", "任务1", "* * * * *", task)
|
||||
|
||||
|
||||
stats = self.scheduler.get_stats()
|
||||
assert stats["total_jobs"] == 1
|
||||
assert stats["enabled_jobs"] == 1
|
||||
|
||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from minenasai.webtui.auth import AuthManager, AuthToken
|
||||
|
||||
|
||||
@@ -45,7 +43,7 @@ class TestAuthManager:
|
||||
def test_generate_token(self):
|
||||
"""测试生成令牌"""
|
||||
token = self.manager.generate_token("user1")
|
||||
|
||||
|
||||
assert token is not None
|
||||
assert len(token) > 20
|
||||
|
||||
@@ -53,7 +51,7 @@ class TestAuthManager:
|
||||
"""测试验证令牌"""
|
||||
token = self.manager.generate_token("user1")
|
||||
auth_token = self.manager.verify_token(token)
|
||||
|
||||
|
||||
assert auth_token is not None
|
||||
assert auth_token.user_id == "user1"
|
||||
|
||||
@@ -65,17 +63,17 @@ class TestAuthManager:
|
||||
def test_verify_expired_token(self):
|
||||
"""测试验证过期令牌"""
|
||||
token = self.manager.generate_token("user1", expires_in=0)
|
||||
|
||||
|
||||
# 等待过期
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
auth_token = self.manager.verify_token(token)
|
||||
assert auth_token is None
|
||||
|
||||
def test_revoke_token(self):
|
||||
"""测试撤销令牌"""
|
||||
token = self.manager.generate_token("user1")
|
||||
|
||||
|
||||
assert self.manager.revoke_token(token) is True
|
||||
assert self.manager.verify_token(token) is None
|
||||
|
||||
@@ -88,9 +86,9 @@ class TestAuthManager:
|
||||
self.manager.generate_token("user1")
|
||||
self.manager.generate_token("user1")
|
||||
self.manager.generate_token("user2")
|
||||
|
||||
|
||||
count = self.manager.revoke_user_tokens("user1")
|
||||
|
||||
|
||||
assert count == 2
|
||||
assert self.manager.get_stats()["total_tokens"] == 1
|
||||
|
||||
@@ -98,7 +96,7 @@ class TestAuthManager:
|
||||
"""测试刷新令牌"""
|
||||
old_token = self.manager.generate_token("user1")
|
||||
new_token = self.manager.refresh_token(old_token)
|
||||
|
||||
|
||||
assert new_token is not None
|
||||
assert new_token != old_token
|
||||
assert self.manager.verify_token(old_token) is None
|
||||
@@ -108,9 +106,9 @@ class TestAuthManager:
|
||||
"""测试令牌元数据"""
|
||||
metadata = {"channel": "wework", "task_id": "123"}
|
||||
token = self.manager.generate_token("user1", metadata=metadata)
|
||||
|
||||
|
||||
auth_token = self.manager.verify_token(token)
|
||||
|
||||
|
||||
assert auth_token is not None
|
||||
assert auth_token.metadata == metadata
|
||||
|
||||
@@ -118,10 +116,10 @@ class TestAuthManager:
|
||||
"""测试清理过期令牌"""
|
||||
self.manager.generate_token("user1", expires_in=0)
|
||||
self.manager.generate_token("user2", expires_in=3600)
|
||||
|
||||
|
||||
time.sleep(0.1)
|
||||
count = self.manager.cleanup_expired()
|
||||
|
||||
|
||||
assert count == 1
|
||||
assert self.manager.get_stats()["total_tokens"] == 1
|
||||
|
||||
@@ -132,17 +130,17 @@ class TestSSHManager:
|
||||
def test_import_ssh_manager(self):
|
||||
"""测试导入 SSH 管理器"""
|
||||
from minenasai.webtui import SSHManager, get_ssh_manager
|
||||
|
||||
|
||||
manager = get_ssh_manager()
|
||||
assert isinstance(manager, SSHManager)
|
||||
|
||||
def test_ssh_manager_stats(self):
|
||||
"""测试 SSH 管理器统计"""
|
||||
from minenasai.webtui import SSHManager
|
||||
|
||||
|
||||
manager = SSHManager()
|
||||
stats = manager.get_stats()
|
||||
|
||||
|
||||
assert "active_sessions" in stats
|
||||
assert stats["active_sessions"] == 0
|
||||
|
||||
@@ -153,6 +151,6 @@ class TestWebTUIServer:
|
||||
def test_import_server(self):
|
||||
"""测试导入服务器"""
|
||||
from minenasai.webtui.server import app
|
||||
|
||||
|
||||
assert app is not None
|
||||
assert app.title == "MineNASAI Web TUI"
|
||||
|
||||
Reference in New Issue
Block a user