feat: 更新模块文档,添加详细说明和使用示例
This commit is contained in:
345
tests/test_database.py
Normal file
345
tests/test_database.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""数据库模块测试
|
||||
|
||||
测试 SQLite 异步数据库操作
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from minenasai.core.database import Database
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""数据库基本操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_and_close(self, tmp_path):
|
||||
"""测试数据库连接和关闭"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
|
||||
# 连接前访问 conn 应该报错
|
||||
with pytest.raises(RuntimeError, match="数据库未连接"):
|
||||
_ = db.conn
|
||||
|
||||
# 连接
|
||||
await db.connect()
|
||||
assert db.conn is not None
|
||||
|
||||
# 关闭
|
||||
await db.close()
|
||||
assert db._conn is None
|
||||
|
||||
|
||||
class TestAgentOperations:
|
||||
"""Agent 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent(self, db):
|
||||
"""测试创建 Agent"""
|
||||
agent = await db.create_agent(
|
||||
agent_id="test-agent-1",
|
||||
name="Test Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
)
|
||||
|
||||
assert agent["id"] == "test-agent-1"
|
||||
assert agent["name"] == "Test Agent"
|
||||
assert agent["workspace_path"] == "/tmp/workspace"
|
||||
assert agent["model"] == "claude-sonnet-4-20250514"
|
||||
assert agent["sandbox_mode"] == "workspace"
|
||||
assert "created_at" in agent
|
||||
assert "updated_at" in agent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_custom_model(self, db):
|
||||
"""测试使用自定义模型创建 Agent"""
|
||||
agent = await db.create_agent(
|
||||
agent_id="test-agent-2",
|
||||
name="Custom Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
model="gpt-4",
|
||||
sandbox_mode="strict",
|
||||
)
|
||||
|
||||
assert agent["model"] == "gpt-4"
|
||||
assert agent["sandbox_mode"] == "strict"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent(self, db):
|
||||
"""测试获取 Agent"""
|
||||
# 创建 Agent
|
||||
await db.create_agent(
|
||||
agent_id="test-agent-3",
|
||||
name="Get Test Agent",
|
||||
workspace_path="/tmp/workspace",
|
||||
)
|
||||
|
||||
# 获取存在的 Agent
|
||||
agent = await db.get_agent("test-agent-3")
|
||||
assert agent is not None
|
||||
assert agent["name"] == "Get Test Agent"
|
||||
|
||||
# 获取不存在的 Agent
|
||||
agent = await db.get_agent("nonexistent")
|
||||
assert agent is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents(self, db):
|
||||
"""测试列出所有 Agent"""
|
||||
# 初始为空
|
||||
agents = await db.list_agents()
|
||||
assert len(agents) == 0
|
||||
|
||||
# 创建多个 Agent
|
||||
await db.create_agent("agent-1", "Agent 1", "/tmp/1")
|
||||
await db.create_agent("agent-2", "Agent 2", "/tmp/2")
|
||||
await db.create_agent("agent-3", "Agent 3", "/tmp/3")
|
||||
|
||||
# 列出所有
|
||||
agents = await db.list_agents()
|
||||
assert len(agents) == 3
|
||||
|
||||
|
||||
class TestSessionOperations:
|
||||
"""Session 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库并添加测试 Agent"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, db):
|
||||
"""测试创建会话"""
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
peer_id="user-123",
|
||||
)
|
||||
|
||||
assert session["agent_id"] == "test-agent"
|
||||
assert session["channel"] == "websocket"
|
||||
assert session["peer_id"] == "user-123"
|
||||
assert session["status"] == "active"
|
||||
assert "session_key" in session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_with_metadata(self, db):
|
||||
"""测试创建带元数据的会话"""
|
||||
metadata = {"client": "web", "version": "1.0"}
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert session["metadata"] == metadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session(self, db):
|
||||
"""测试获取会话"""
|
||||
# 创建会话
|
||||
created = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
metadata={"test": "data"},
|
||||
)
|
||||
|
||||
# 获取会话
|
||||
session = await db.get_session(created["id"])
|
||||
assert session is not None
|
||||
assert session["id"] == created["id"]
|
||||
assert session["metadata"] == {"test": "data"}
|
||||
|
||||
# 获取不存在的会话
|
||||
session = await db.get_session("nonexistent-id")
|
||||
assert session is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_session_status(self, db):
|
||||
"""测试更新会话状态"""
|
||||
# 创建会话
|
||||
session = await db.create_session(
|
||||
agent_id="test-agent",
|
||||
channel="websocket",
|
||||
)
|
||||
assert session["status"] == "active"
|
||||
|
||||
# 更新状态
|
||||
await db.update_session_status(session["id"], "closed")
|
||||
|
||||
# 验证
|
||||
updated = await db.get_session(session["id"])
|
||||
assert updated["status"] == "closed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_active_sessions(self, db):
|
||||
"""测试列出活跃会话"""
|
||||
# 创建多个会话
|
||||
session1 = await db.create_session("test-agent", "websocket")
|
||||
session2 = await db.create_session("test-agent", "feishu")
|
||||
await db.create_session("test-agent", "wework")
|
||||
|
||||
# 关闭一个会话
|
||||
await db.update_session_status(session2["id"], "closed")
|
||||
|
||||
# 列出所有活跃会话
|
||||
active = await db.list_active_sessions()
|
||||
assert len(active) == 2
|
||||
|
||||
# 按 agent_id 过滤
|
||||
active = await db.list_active_sessions("test-agent")
|
||||
assert len(active) == 2
|
||||
|
||||
|
||||
class TestMessageOperations:
|
||||
"""Message 操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库并添加测试会话"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
|
||||
session = await db.create_session("test-agent", "websocket")
|
||||
db._test_session_id = session["id"]
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message(self, db):
|
||||
"""测试添加消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
message = await db.add_message(
|
||||
session_id=session_id,
|
||||
role="user",
|
||||
content="Hello, world!",
|
||||
)
|
||||
|
||||
assert message["session_id"] == session_id
|
||||
assert message["role"] == "user"
|
||||
assert message["content"] == "Hello, world!"
|
||||
assert message["tokens_used"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message_with_tool_calls(self, db):
|
||||
"""测试添加带工具调用的消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
tool_calls = [
|
||||
{"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
|
||||
]
|
||||
message = await db.add_message(
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=tool_calls,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
assert message["tool_calls"] == tool_calls
|
||||
assert message["tokens_used"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages(self, db):
|
||||
"""测试获取会话消息"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
# 添加多条消息
|
||||
await db.add_message(session_id, "user", "Message 1")
|
||||
await db.add_message(session_id, "assistant", "Response 1")
|
||||
await db.add_message(session_id, "user", "Message 2")
|
||||
await db.add_message(session_id, "assistant", "Response 2")
|
||||
|
||||
# 获取所有消息
|
||||
messages = await db.get_messages(session_id)
|
||||
assert len(messages) == 4
|
||||
# 验证角色交替
|
||||
roles = [m["role"] for m in messages]
|
||||
assert roles.count("user") == 2
|
||||
assert roles.count("assistant") == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_with_limit(self, db):
|
||||
"""测试获取消息数量限制"""
|
||||
session_id = db._test_session_id
|
||||
|
||||
# 添加多条消息
|
||||
for i in range(10):
|
||||
await db.add_message(session_id, "user", f"Message {i}")
|
||||
|
||||
# 限制返回数量
|
||||
messages = await db.get_messages(session_id, limit=5)
|
||||
assert len(messages) == 5
|
||||
|
||||
|
||||
class TestAuditLog:
|
||||
"""审计日志测试"""
|
||||
|
||||
@pytest.fixture
|
||||
async def db(self, tmp_path):
|
||||
"""创建临时数据库"""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
await db.connect()
|
||||
yield db
|
||||
await db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_audit_log(self, db):
|
||||
"""测试添加审计日志"""
|
||||
# 添加日志不应该报错
|
||||
await db.add_audit_log(
|
||||
agent_id="test-agent",
|
||||
tool_name="read_file",
|
||||
danger_level="safe",
|
||||
params={"path": "/tmp/test.txt"},
|
||||
result="success",
|
||||
duration_ms=50,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_audit_log_minimal(self, db):
|
||||
"""测试添加最小审计日志"""
|
||||
await db.add_audit_log(
|
||||
agent_id=None,
|
||||
tool_name="python_eval",
|
||||
danger_level="low",
|
||||
)
|
||||
|
||||
|
||||
class TestGlobalDatabase:
|
||||
"""全局数据库实例测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_functions(self):
|
||||
"""测试导入全局函数"""
|
||||
from minenasai.core.database import close_database, get_database
|
||||
|
||||
assert callable(get_database)
|
||||
assert callable(close_database)
|
||||
Reference in New Issue
Block a user