"""Tests for the database module.

Exercises the asynchronous SQLite database operations.
"""

import pytest
|
||
|
|
|
||
|
|
from minenasai.core.database import Database
|
||
|
|
|
||
|
|
|
||
|
|
class TestDatabase:
    """Basic database lifecycle tests (connect / close)."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Create a temporary, connected database and close it after the test."""
        db_path = tmp_path / "test.db"
        db = Database(db_path)
        await db.connect()
        yield db
        await db.close()

    @pytest.mark.asyncio
    async def test_connect_and_close(self, tmp_path):
        """``conn`` raises before connect; connect opens and close clears the connection."""
        db_path = tmp_path / "test.db"
        db = Database(db_path)

        # Accessing conn before connecting must raise.
        with pytest.raises(RuntimeError, match="数据库未连接"):
            _ = db.conn

        await db.connect()
        try:
            assert db.conn is not None
        finally:
            # Always release the connection, even if the assertion above
            # fails, so a failing test does not leak an open database handle.
            await db.close()

        assert db._conn is None


class TestAgentOperations:
    """Tests for creating, fetching and listing agents."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a freshly connected database backed by a temporary file."""
        database = Database(tmp_path / "test.db")
        await database.connect()
        yield database
        await database.close()

    @pytest.mark.asyncio
    async def test_create_agent(self, db):
        """A new agent echoes its inputs and receives default model/sandbox values."""
        created = await db.create_agent(
            agent_id="test-agent-1",
            name="Test Agent",
            workspace_path="/tmp/workspace",
        )

        expected = {
            "id": "test-agent-1",
            "name": "Test Agent",
            "workspace_path": "/tmp/workspace",
            "model": "claude-sonnet-4-20250514",
            "sandbox_mode": "workspace",
        }
        for key, value in expected.items():
            assert created[key] == value
        for timestamp_key in ("created_at", "updated_at"):
            assert timestamp_key in created

    @pytest.mark.asyncio
    async def test_create_agent_custom_model(self, db):
        """Explicit ``model`` and ``sandbox_mode`` arguments override the defaults."""
        created = await db.create_agent(
            agent_id="test-agent-2",
            name="Custom Agent",
            workspace_path="/tmp/workspace",
            model="gpt-4",
            sandbox_mode="strict",
        )

        assert (created["model"], created["sandbox_mode"]) == ("gpt-4", "strict")

    @pytest.mark.asyncio
    async def test_get_agent(self, db):
        """``get_agent`` returns a stored agent, or None for an unknown id."""
        await db.create_agent(
            agent_id="test-agent-3",
            name="Get Test Agent",
            workspace_path="/tmp/workspace",
        )

        found = await db.get_agent("test-agent-3")
        assert found is not None
        assert found["name"] == "Get Test Agent"

        missing = await db.get_agent("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_list_agents(self, db):
        """``list_agents`` starts empty and then reports every created agent."""
        assert len(await db.list_agents()) == 0

        for n in range(1, 4):
            await db.create_agent(f"agent-{n}", f"Agent {n}", f"/tmp/{n}")

        assert len(await db.list_agents()) == 3


class TestSessionOperations:
    """Tests for session creation, lookup, status updates and listing."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a connected temporary database that already contains ``test-agent``.

        The connection is closed in ``finally`` so it is released even when the
        agent-setup step raises before the fixture reaches ``yield``.
        """
        db = Database(tmp_path / "test.db")
        await db.connect()
        try:
            await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
            yield db
        finally:
            await db.close()

    @pytest.mark.asyncio
    async def test_create_session(self, db):
        """A new session records its agent, channel and peer, and starts active."""
        session = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            peer_id="user-123",
        )

        assert session["agent_id"] == "test-agent"
        assert session["channel"] == "websocket"
        assert session["peer_id"] == "user-123"
        assert session["status"] == "active"
        assert "session_key" in session

    @pytest.mark.asyncio
    async def test_create_session_with_metadata(self, db):
        """Metadata passed at creation is returned unchanged."""
        metadata = {"client": "web", "version": "1.0"}
        session = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            metadata=metadata,
        )

        assert session["metadata"] == metadata

    @pytest.mark.asyncio
    async def test_get_session(self, db):
        """``get_session`` round-trips metadata and returns None for unknown ids."""
        created = await db.create_session(
            agent_id="test-agent",
            channel="websocket",
            metadata={"test": "data"},
        )

        session = await db.get_session(created["id"])
        assert session is not None
        assert session["id"] == created["id"]
        assert session["metadata"] == {"test": "data"}

        # Unknown ids yield None rather than raising.
        assert await db.get_session("nonexistent-id") is None

    @pytest.mark.asyncio
    async def test_update_session_status(self, db):
        """``update_session_status`` persists the new status."""
        session = await db.create_session(agent_id="test-agent", channel="websocket")
        assert session["status"] == "active"

        await db.update_session_status(session["id"], "closed")

        updated = await db.get_session(session["id"])
        # Fail with a clear assertion instead of a TypeError on a None subscript.
        assert updated is not None
        assert updated["status"] == "closed"

    @pytest.mark.asyncio
    async def test_list_active_sessions(self, db):
        """Only active sessions are listed, and the agent_id filter actually filters."""
        # A second agent makes the agent_id filter observable: with a single
        # agent, a filtered list is indistinguishable from an unfiltered one.
        await db.create_agent("other-agent", "Other Agent", "/tmp/other")

        await db.create_session("test-agent", "websocket")
        feishu_session = await db.create_session("test-agent", "feishu")
        await db.create_session("test-agent", "wework")
        await db.create_session("other-agent", "websocket")

        # A closed session must drop out of the active list.
        await db.update_session_status(feishu_session["id"], "closed")

        # Unfiltered: 2 remaining test-agent sessions + 1 other-agent session.
        assert len(await db.list_active_sessions()) == 3

        # Filtered by agent_id.
        assert len(await db.list_active_sessions("test-agent")) == 2
        assert len(await db.list_active_sessions("other-agent")) == 1


class TestMessageOperations:
    """Tests for adding and reading session messages."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Yield a connected temporary database that already contains ``test-agent``.

        The connection is closed in ``finally`` so it is released even when the
        agent-setup step raises before the fixture reaches ``yield``.
        """
        db = Database(tmp_path / "test.db")
        await db.connect()
        try:
            await db.create_agent("test-agent", "Test Agent", "/tmp/workspace")
            yield db
        finally:
            await db.close()

    @pytest.fixture
    async def session_id(self, db):
        """Create a websocket session for ``test-agent`` and return its id.

        A dedicated fixture, rather than an ad-hoc attribute stashed on the
        Database object under test.
        """
        session = await db.create_session("test-agent", "websocket")
        return session["id"]

    @pytest.mark.asyncio
    async def test_add_message(self, db, session_id):
        """A plain user message is stored with zero tokens used by default."""
        message = await db.add_message(
            session_id=session_id,
            role="user",
            content="Hello, world!",
        )

        assert message["session_id"] == session_id
        assert message["role"] == "user"
        assert message["content"] == "Hello, world!"
        assert message["tokens_used"] == 0

    @pytest.mark.asyncio
    async def test_add_message_with_tool_calls(self, db, session_id):
        """Tool calls and an explicit token count are returned unchanged."""
        tool_calls = [
            {"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}}
        ]
        message = await db.add_message(
            session_id=session_id,
            role="assistant",
            content=None,
            tool_calls=tool_calls,
            tokens_used=100,
        )

        assert message["tool_calls"] == tool_calls
        assert message["tokens_used"] == 100

    @pytest.mark.asyncio
    async def test_get_messages(self, db, session_id):
        """``get_messages`` returns every message stored for the session."""
        await db.add_message(session_id, "user", "Message 1")
        await db.add_message(session_id, "assistant", "Response 1")
        await db.add_message(session_id, "user", "Message 2")
        await db.add_message(session_id, "assistant", "Response 2")

        messages = await db.get_messages(session_id)
        assert len(messages) == 4

        # Every stored message must come back. This is compared
        # order-independently: the test does not assume a particular ordering
        # from get_messages.
        assert sorted(m["content"] for m in messages) == [
            "Message 1",
            "Message 2",
            "Response 1",
            "Response 2",
        ]

        # Role counts (not alternation, which would depend on ordering).
        roles = [m["role"] for m in messages]
        assert roles.count("user") == 2
        assert roles.count("assistant") == 2

    @pytest.mark.asyncio
    async def test_get_messages_with_limit(self, db, session_id):
        """``limit`` caps the number of returned messages."""
        for i in range(10):
            await db.add_message(session_id, "user", f"Message {i}")

        messages = await db.get_messages(session_id, limit=5)
        assert len(messages) == 5


class TestAuditLog:
    """Tests for writing audit-log entries."""

    @pytest.fixture
    async def db(self, tmp_path):
        """Provide a connected database stored in a temporary directory."""
        database = Database(tmp_path / "test.db")
        await database.connect()
        yield database
        await database.close()

    @pytest.mark.asyncio
    async def test_add_audit_log(self, db):
        """Writing a fully populated audit entry must not raise."""
        entry = {
            "agent_id": "test-agent",
            "tool_name": "read_file",
            "danger_level": "safe",
            "params": {"path": "/tmp/test.txt"},
            "result": "success",
            "duration_ms": 50,
        }
        await db.add_audit_log(**entry)

    @pytest.mark.asyncio
    async def test_add_audit_log_minimal(self, db):
        """Writing an entry with only agent_id (None), tool_name and danger_level must not raise."""
        await db.add_audit_log(agent_id=None, tool_name="python_eval", danger_level="low")


class TestGlobalDatabase:
    """Tests for the module-level database accessor functions."""

    def test_import_functions(self):
        """``get_database`` and ``close_database`` are importable and callable.

        Nothing here is awaited, so this is a plain synchronous test and needs
        neither an asyncio marker nor an event loop.
        """
        from minenasai.core.database import close_database, get_database

        assert callable(get_database)
        assert callable(close_database)