Files
MineNasAI/tests/test_database.py

346 lines
10 KiB
Python

"""数据库模块测试
测试 SQLite 异步数据库操作
"""
import pytest
from minenasai.core.database import Database
class TestDatabase:
    """Basic Database lifecycle tests (connect / close)."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a connected Database backed by a temporary file, closing it afterwards."""
        db_path = tmp_path / "test.db"
        db = Database(db_path)
        await db.connect()
        yield db
        await db.close()

    @pytest.mark.asyncio
    async def test_connect_and_close(self, tmp_path):
        """Accessing ``conn`` before connecting raises; connect/close toggle the connection."""
        db_path = tmp_path / "test.db"
        db = Database(db_path)

        # Accessing conn before connect() must raise.
        with pytest.raises(RuntimeError, match="数据库未连接"):
            _ = db.conn

        await db.connect()
        # close() runs in finally so a failing assertion does not leak the
        # open connection.
        try:
            assert db.conn is not None
        finally:
            await db.close()

        assert db._conn is None
class TestAgentOperations:
    """Tests for creating, fetching and listing agents."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Provide a freshly connected database in a temp directory; close it on teardown."""
        database = Database(tmp_path / "test.db")
        await database.connect()
        yield database
        await database.close()

    @pytest.mark.asyncio
    async def test_create_agent(self, db):
        """Creating an agent with only required fields fills in the default model and sandbox mode."""
        created = await db.create_agent(
            agent_id="test-agent-1",
            name="Test Agent",
            workspace_path="/tmp/workspace",
        )

        expected_fields = {
            "id": "test-agent-1",
            "name": "Test Agent",
            "workspace_path": "/tmp/workspace",
            "model": "claude-sonnet-4-20250514",
            "sandbox_mode": "workspace",
        }
        for field, value in expected_fields.items():
            assert created[field] == value

        # Timestamps are generated by the database layer; only their presence is checked.
        for stamp in ("created_at", "updated_at"):
            assert stamp in created

    @pytest.mark.asyncio
    async def test_create_agent_custom_model(self, db):
        """Explicit model and sandbox_mode arguments are stored as given."""
        created = await db.create_agent(
            agent_id="test-agent-2",
            name="Custom Agent",
            workspace_path="/tmp/workspace",
            model="gpt-4",
            sandbox_mode="strict",
        )

        assert (created["model"], created["sandbox_mode"]) == ("gpt-4", "strict")

    @pytest.mark.asyncio
    async def test_get_agent(self, db):
        """get_agent returns the stored agent, or None for an unknown id."""
        await db.create_agent(
            agent_id="test-agent-3",
            name="Get Test Agent",
            workspace_path="/tmp/workspace",
        )

        found = await db.get_agent("test-agent-3")
        assert found is not None
        assert found["name"] == "Get Test Agent"

        missing = await db.get_agent("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_list_agents(self, db):
        """list_agents starts empty and then reports every created agent."""
        before = await db.list_agents()
        assert len(before) == 0

        for n in (1, 2, 3):
            await db.create_agent(f"agent-{n}", f"Agent {n}", f"/tmp/{n}")

        after = await db.list_agents()
        assert len(after) == 3
class TestSessionOperations:
    """Tests for session creation, lookup, status updates and listing."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a connected temporary database that already contains agent ``test-agent``.

        The agent is created up front because every test here creates sessions
        owned by it.
        """
        db_path = tmp_path / "test.db"
        db = Database(db_path)
        await db.connect()
        await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
        yield db
        await db.close()

    @pytest.mark.asyncio
    async def test_create_session(self, db):
        """A new session echoes its agent/channel/peer and starts in ``active`` status."""
        session = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            peer_id="user-123",
        )
        assert session["agent_id"] == "test-agent"
        assert session["channel"] == "websocket"
        assert session["peer_id"] == "user-123"
        assert session["status"] == "active"
        assert "session_key" in session

    @pytest.mark.asyncio
    async def test_create_session_with_metadata(self, db):
        """Metadata passed at creation is returned unchanged."""
        metadata = {"client": "web", "version": "1.0"}
        session = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            metadata=metadata,
        )
        assert session["metadata"] == metadata

    @pytest.mark.asyncio
    async def test_get_session(self, db):
        """get_session round-trips id and metadata; unknown ids yield None."""
        created = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            metadata={"test": "data"},
        )

        session = await db.get_session(created["id"])
        assert session is not None
        assert session["id"] == created["id"]
        assert session["metadata"] == {"test": "data"}

        session = await db.get_session("nonexistent-id")
        assert session is None

    @pytest.mark.asyncio
    async def test_update_session_status(self, db):
        """update_session_status persists the new status."""
        session = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
        )
        assert session["status"] == "active"

        await db.update_session_status(session["id"], "closed")

        updated = await db.get_session(session["id"])
        # Guard so a missing row fails as an assertion, not a TypeError on indexing.
        assert updated is not None
        assert updated["status"] == "closed"

    @pytest.mark.asyncio
    async def test_list_active_sessions(self, db):
        """Only active sessions are listed, and the agent_id filter excludes other agents."""
        # A second agent gives the filtered query something to exclude; with a
        # single agent, filtered and unfiltered results are indistinguishable.
        await db.create_agent("other-agent", "Other Agent", "/tmp/other")

        await db.create_session("test-agent", "websocket")
        closed = await db.create_session("test-agent", "feishu")
        await db.create_session("test-agent", "wework")
        await db.create_session("other-agent", "websocket")

        await db.update_session_status(closed["id"], "closed")

        # Unfiltered: 2 active for test-agent + 1 for other-agent; the closed one is excluded.
        active = await db.list_active_sessions()
        assert len(active) == 3
        assert closed["id"] not in {s["id"] for s in active}

        # Filtered by agent_id: only test-agent's two active sessions.
        active = await db.list_active_sessions("test-agent")
        assert len(active) == 2
        assert all(s["agent_id"] == "test-agent" for s in active)
class TestMessageOperations:
    """Tests for adding and reading messages within a session."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a connected temporary database that contains agent ``test-agent``."""
        db_path = tmp_path / "test.db"
        db = Database(db_path)
        await db.connect()
        await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
        yield db
        await db.close()

    @pytest.fixture
    async def session_id(self, db):
        """Create a websocket session for ``test-agent`` and return its id.

        This is a separate fixture so test data is not stashed as an ad-hoc
        attribute on the Database object under test.
        """
        session = await db.create_session("test-agent", "websocket")
        return session["id"]

    @pytest.mark.asyncio
    async def test_add_message(self, db, session_id):
        """A plain user message is stored with zero tokens_used by default."""
        message = await db.add_message(
            session_id=session_id,
            role="user",
            content="Hello, world!",
        )
        assert message["session_id"] == session_id
        assert message["role"] == "user"
        assert message["content"] == "Hello, world!"
        assert message["tokens_used"] == 0

    @pytest.mark.asyncio
    async def test_add_message_with_tool_calls(self, db, session_id):
        """tool_calls and tokens_used are returned as passed in."""
        tool_calls = [
            {"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
        ]
        message = await db.add_message(
            session_id=session_id,
            role="assistant",
            content=None,
            tool_calls=tool_calls,
            tokens_used=100,
        )
        assert message["tool_calls"] == tool_calls
        assert message["tokens_used"] == 100

    @pytest.mark.asyncio
    async def test_get_messages(self, db, session_id):
        """get_messages returns every message added to the session."""
        sent = [
            ("user", "Message 1"),
            ("assistant", "Response 1"),
            ("user", "Message 2"),
            ("assistant", "Response 2"),
        ]
        for role, content in sent:
            await db.add_message(session_id, role, content)

        messages = await db.get_messages(session_id)
        assert len(messages) == 4

        # Checks role counts and the set of contents; the return order is not
        # asserted because this test does not pin get_messages' sort order.
        roles = [m["role"] for m in messages]
        assert roles.count("user") == 2
        assert roles.count("assistant") == 2
        assert sorted(m["content"] for m in messages) == sorted(c for _, c in sent)

    @pytest.mark.asyncio
    async def test_get_messages_with_limit(self, db, session_id):
        """The limit argument caps the number of returned messages."""
        for i in range(10):
            await db.add_message(session_id, "user", f"Message {i}")

        messages = await db.get_messages(session_id, limit=5)
        assert len(messages) == 5
        assert all(m["session_id"] == session_id for m in messages)
class TestAuditLog:
    """Audit-log write tests.

    These only check that writes complete without raising; nothing is read back.
    """

    @pytest.fixture
    async def db(self, tmp_path):
        """Provide a freshly connected database in a temp directory; close it on teardown."""
        database = Database(tmp_path / "test.db")
        await database.connect()
        yield database
        await database.close()

    @pytest.mark.asyncio
    async def test_add_audit_log(self, db):
        """A fully populated audit entry is accepted."""
        entry = {
            "agent_id": "test-agent",
            "tool_name": "read_file",
            "danger_level": "safe",
            "params": {"path": "/tmp/test.txt"},
            "result": "success",
            "duration_ms": 50,
        }
        # Success criterion: the call does not raise.
        await db.add_audit_log(**entry)

    @pytest.mark.asyncio
    async def test_add_audit_log_minimal(self, db):
        """An entry with no agent that omits params, result and duration_ms is accepted."""
        entry = {
            "agent_id": None,
            "tool_name": "python_eval",
            "danger_level": "low",
        }
        await db.add_audit_log(**entry)
class TestGlobalDatabase:
    """Tests for the module-level global database helpers."""

    def test_import_functions(self):
        """get_database and close_database are importable and callable.

        This is a plain synchronous test: it awaits nothing, so it needs neither
        ``async def`` nor the asyncio marker (and thus no event loop).
        """
        # Imported inside the test on purpose: importability is what is being checked.
        from minenasai.core.database import close_database, get_database

        assert callable(get_database)
        assert callable(close_database)