"""数据库模块测试 测试 SQLite 异步数据库操作 """ import pytest from minenasai.core.database import Database class TestDatabase: """数据库基本操作测试""" @pytest.fixture async def db(self, tmp_path): """创建临时数据库""" db_path = tmp_path / "test.db" db = Database(db_path) await db.connect() yield db await db.close() @pytest.mark.asyncio async def test_connect_and_close(self, tmp_path): """测试数据库连接和关闭""" db_path = tmp_path / "test.db" db = Database(db_path) # 连接前访问 conn 应该报错 with pytest.raises(RuntimeError, match="数据库未连接"): _ = db.conn # 连接 await db.connect() assert db.conn is not None # 关闭 await db.close() assert db._conn is None class TestAgentOperations: """Agent 操作测试""" @pytest.fixture async def db(self, tmp_path): """创建临时数据库""" db_path = tmp_path / "test.db" db = Database(db_path) await db.connect() yield db await db.close() @pytest.mark.asyncio async def test_create_agent(self, db): """测试创建 Agent""" agent = await db.create_agent( agent_id="test-agent-1", name="Test Agent", workspace_path="/tmp/workspace", ) assert agent["id"] == "test-agent-1" assert agent["name"] == "Test Agent" assert agent["workspace_path"] == "/tmp/workspace" assert agent["model"] == "claude-sonnet-4-20250514" assert agent["sandbox_mode"] == "workspace" assert "created_at" in agent assert "updated_at" in agent @pytest.mark.asyncio async def test_create_agent_custom_model(self, db): """测试使用自定义模型创建 Agent""" agent = await db.create_agent( agent_id="test-agent-2", name="Custom Agent", workspace_path="/tmp/workspace", model="gpt-4", sandbox_mode="strict", ) assert agent["model"] == "gpt-4" assert agent["sandbox_mode"] == "strict" @pytest.mark.asyncio async def test_get_agent(self, db): """测试获取 Agent""" # 创建 Agent await db.create_agent( agent_id="test-agent-3", name="Get Test Agent", workspace_path="/tmp/workspace", ) # 获取存在的 Agent agent = await db.get_agent("test-agent-3") assert agent is not None assert agent["name"] == "Get Test Agent" # 获取不存在的 Agent agent = await db.get_agent("nonexistent") assert agent is None @pytest.mark.asyncio async def test_list_agents(self, db): """测试列出所有 Agent""" # 初始为空 agents = await db.list_agents() assert len(agents) == 0 # 创建多个 Agent await db.create_agent("agent-1", "Agent 1", "/tmp/1") await db.create_agent("agent-2", "Agent 2", "/tmp/2") await db.create_agent("agent-3", "Agent 3", "/tmp/3") # 列出所有 agents = await db.list_agents() assert len(agents) == 3 class TestSessionOperations: """Session 操作测试""" @pytest.fixture async def db(self, tmp_path): """创建临时数据库并添加测试 Agent""" db_path = tmp_path / "test.db" db = Database(db_path) await db.connect() await db.create_agent("test-agent", "Test Agent", "/tmp/workspace") yield db await db.close() @pytest.mark.asyncio async def test_create_session(self, db): """测试创建会话""" session = await db.create_session( agent_id="test-agent", channel="websocket", peer_id="user-123", ) assert session["agent_id"] == "test-agent" assert session["channel"] == "websocket" assert session["peer_id"] == "user-123" assert session["status"] == "active" assert "session_key" in session @pytest.mark.asyncio async def test_create_session_with_metadata(self, db): """测试创建带元数据的会话""" metadata = {"client": "web", "version": "1.0"} session = await db.create_session( agent_id="test-agent", channel="websocket", metadata=metadata, ) assert session["metadata"] == metadata @pytest.mark.asyncio async def test_get_session(self, db): """测试获取会话""" # 创建会话 created = await db.create_session( agent_id="test-agent", channel="websocket", metadata={"test": "data"}, ) # 获取会话 session = await db.get_session(created["id"]) assert session is not None assert session["id"] == created["id"] assert session["metadata"] == {"test": "data"} # 获取不存在的会话 session = await db.get_session("nonexistent-id") assert session is None @pytest.mark.asyncio async def test_update_session_status(self, db): """测试更新会话状态""" # 创建会话 session = await db.create_session( agent_id="test-agent", channel="websocket", ) assert session["status"] == "active" # 更新状态 await db.update_session_status(session["id"], "closed") # 验证 updated = await db.get_session(session["id"]) assert updated["status"] == "closed" @pytest.mark.asyncio async def test_list_active_sessions(self, db): """测试列出活跃会话""" # 创建多个会话 session1 = await db.create_session("test-agent", "websocket") session2 = await db.create_session("test-agent", "feishu") await db.create_session("test-agent", "wework") # 关闭一个会话 await db.update_session_status(session2["id"], "closed") # 列出所有活跃会话 active = await db.list_active_sessions() assert len(active) == 2 # 按 agent_id 过滤 active = await db.list_active_sessions("test-agent") assert len(active) == 2 class TestMessageOperations: """Message 操作测试""" @pytest.fixture async def db(self, tmp_path): """创建临时数据库并添加测试会话""" db_path = tmp_path / "test.db" db = Database(db_path) await db.connect() await db.create_agent("test-agent", "Test Agent", "/tmp/workspace") session = await db.create_session("test-agent", "websocket") db._test_session_id = session["id"] yield db await db.close() @pytest.mark.asyncio async def test_add_message(self, db): """测试添加消息""" session_id = db._test_session_id message = await db.add_message( session_id=session_id, role="user", content="Hello, world!", ) assert message["session_id"] == session_id assert message["role"] == "user" assert message["content"] == "Hello, world!" assert message["tokens_used"] == 0 @pytest.mark.asyncio async def test_add_message_with_tool_calls(self, db): """测试添加带工具调用的消息""" session_id = db._test_session_id tool_calls = [ {"id": "call-1", "name": "read_file", "arguments": {"path": "/tmp/test"}} ] message = await db.add_message( session_id=session_id, role="assistant", content=None, tool_calls=tool_calls, tokens_used=100, ) assert message["tool_calls"] == tool_calls assert message["tokens_used"] == 100 @pytest.mark.asyncio async def test_get_messages(self, db): """测试获取会话消息""" session_id = db._test_session_id # 添加多条消息 await db.add_message(session_id, "user", "Message 1") await db.add_message(session_id, "assistant", "Response 1") await db.add_message(session_id, "user", "Message 2") await db.add_message(session_id, "assistant", "Response 2") # 获取所有消息 messages = await db.get_messages(session_id) assert len(messages) == 4 # 验证角色交替 roles = [m["role"] for m in messages] assert roles.count("user") == 2 assert roles.count("assistant") == 2 @pytest.mark.asyncio async def test_get_messages_with_limit(self, db): """测试获取消息数量限制""" session_id = db._test_session_id # 添加多条消息 for i in range(10): await db.add_message(session_id, "user", f"Message {i}") # 限制返回数量 messages = await db.get_messages(session_id, limit=5) assert len(messages) == 5 class TestAuditLog: """审计日志测试""" @pytest.fixture async def db(self, tmp_path): """创建临时数据库""" db_path = tmp_path / "test.db" db = Database(db_path) await db.connect() yield db await db.close() @pytest.mark.asyncio async def test_add_audit_log(self, db): """测试添加审计日志""" # 添加日志不应该报错 await db.add_audit_log( agent_id="test-agent", tool_name="read_file", danger_level="safe", params={"path": "/tmp/test.txt"}, result="success", duration_ms=50, ) @pytest.mark.asyncio async def test_add_audit_log_minimal(self, db): """测试添加最小审计日志""" await db.add_audit_log( agent_id=None, tool_name="python_eval", danger_level="low", ) class TestGlobalDatabase: """全局数据库实例测试""" @pytest.mark.asyncio async def test_import_functions(self): """测试导入全局函数""" from minenasai.core.database import close_database, get_database assert callable(get_database) assert callable(close_database)