"""数据库模块测试""" import pytest from app import database as db class TestPlanCRUD: """Plan CRUD 操作测试""" @pytest.mark.asyncio async def test_create_plan(self, db): """测试创建 Plan""" plan = await db.create_plan( name="Test Plan", provider_name="openai", api_key="sk-test", api_base="https://api.openai.com/v1", plan_type="coding", supported_models=["gpt-4"], ) assert "id" in plan assert plan["name"] == "Test Plan" assert plan["provider_name"] == "openai" @pytest.mark.asyncio async def test_get_plan(self, db): """测试获取 Plan""" created = await db.create_plan( name="Get Test", provider_name="openai", api_key="sk-test", ) plan = await db.get_plan(created["id"]) assert plan is not None assert plan["name"] == "Get Test" assert plan["enabled"] is True @pytest.mark.asyncio async def test_get_plan_not_found(self, db): """测试获取不存在的 Plan""" plan = await db.get_plan("nonexistent") assert plan is None @pytest.mark.asyncio async def test_list_plans(self, db): """测试列出 Plans""" await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1") await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2") await db.create_plan(name="Plan 3", provider_name="google", api_key="sk3", enabled=False) plans = await db.list_plans() assert len(plans) == 3 enabled_only = await db.list_plans(enabled_only=True) assert len(enabled_only) == 2 @pytest.mark.asyncio async def test_update_plan(self, db): """测试更新 Plan""" plan = await db.create_plan(name="Old Name", provider_name="openai", api_key="sk-test") ok = await db.update_plan(plan["id"], name="New Name", plan_type="token") assert ok is True updated = await db.get_plan(plan["id"]) assert updated["name"] == "New Name" assert updated["plan_type"] == "token" @pytest.mark.asyncio async def test_delete_plan(self, db): """测试删除 Plan""" plan = await db.create_plan(name="To Delete", provider_name="openai", api_key="sk-test") ok = await db.delete_plan(plan["id"]) assert ok is True deleted = await db.get_plan(plan["id"]) assert deleted is None class TestQuotaRuleCRUD: """QuotaRule CRUD 操作测试""" @pytest.mark.asyncio async def test_create_quota_rule(self, db): """测试创建 QuotaRule""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule( plan_id=plan["id"], rule_name="Daily Limit", quota_total=100, quota_unit="requests", refresh_type="calendar_cycle", calendar_unit="daily", calendar_anchor={"hour": 0}, ) assert "id" in rule assert rule["rule_name"] == "Daily Limit" @pytest.mark.asyncio async def test_list_quota_rules(self, db): """测试列出 QuotaRules""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 1", quota_total=100) await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 2", quota_total=200) rules = await db.list_quota_rules(plan["id"]) assert len(rules) == 2 @pytest.mark.asyncio async def test_update_quota_rule(self, db): """测试更新 QuotaRule""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100) ok = await db.update_quota_rule(rule["id"], quota_total=200) assert ok is True rules = await db.list_quota_rules(plan["id"]) assert rules[0]["quota_total"] == 200 @pytest.mark.asyncio async def test_delete_quota_rule(self, db): """测试删除 QuotaRule""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100) ok = await db.delete_quota_rule(rule["id"]) assert ok is True rules = await db.list_quota_rules(plan["id"]) assert len(rules) == 0 @pytest.mark.asyncio async def test_increment_quota_used(self, db): """测试增加已用额度""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100, quota_unit="requests") await db.create_quota_rule(plan_id=plan["id"], rule_name="Token Rule", quota_total=1000, quota_unit="tokens") await db.increment_quota_used(plan["id"], token_count=50) rules = await db.list_quota_rules(plan["id"]) # requests 类型 +1 assert rules[0]["quota_used"] == 1 # tokens 类型 +50 assert rules[1]["quota_used"] == 50 class TestCheckPlanAvailable: """Plan 可用性检查测试""" @pytest.mark.asyncio async def test_all_rules_available(self, db): """测试所有规则有余量""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100) available = await db.check_plan_available(plan["id"]) assert available is True @pytest.mark.asyncio async def test_one_rule_exhausted(self, db): """测试一个规则耗尽""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100) await db.update_quota_rule(rule["id"], quota_used=100) available = await db.check_plan_available(plan["id"]) assert available is False @pytest.mark.asyncio async def test_no_enabled_rules(self, db): """测试没有启用的规则""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") rule = await db.create_quota_rule( plan_id=plan["id"], rule_name="Rule", quota_total=100, enabled=False ) available = await db.check_plan_available(plan["id"]) assert available is False # 当前实现会返回 True,这是需要修复的 class TestModelRoute: """模型路由测试""" @pytest.mark.asyncio async def test_set_model_route(self, db): """测试设置模型路由""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") route = await db.set_model_route("gpt-4", plan["id"], priority=10) assert "id" in route assert route["model_name"] == "gpt-4" @pytest.mark.asyncio async def test_set_model_route_update_existing(self, db): """测试设置已存在的路由""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") route1 = await db.set_model_route("gpt-4", plan["id"], priority=10) route2 = await db.set_model_route("gpt-4", plan["id"], priority=20) # 原实现会创建新 ID,当前实现会更新但返回新 ID # 这里只验证不会抛出错误 assert route2["model_name"] == "gpt-4" @pytest.mark.asyncio async def test_resolve_model(self, db): """测试解析模型路由""" plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1") plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2") await db.set_model_route("gpt-4", plan1["id"], priority=10) await db.set_model_route("gpt-4", plan2["id"], priority=5) # 较低优先级 # 应该返回高优先级的 plan1 resolved = await db.resolve_model("gpt-4") assert resolved == plan1["id"] @pytest.mark.asyncio async def test_resolve_model_with_fallback(self, db): """测试额度耗尽时的 fallback""" plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1") plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2") await db.set_model_route("gpt-4", plan1["id"], priority=10) await db.set_model_route("gpt-4", plan2["id"], priority=5) # plan1 额度耗尽 rule = await db.create_quota_rule(plan_id=plan1["id"], rule_name="Rule", quota_total=100) await db.update_quota_rule(rule["id"], quota_used=100) # 应该 fallback 到 plan2 resolved = await db.resolve_model("gpt-4") assert resolved == plan2["id"] @pytest.mark.asyncio async def test_delete_model_route(self, db): """测试删除模型路由""" plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test") route = await db.set_model_route("gpt-4", plan["id"], priority=10) ok = await db.delete_model_route(route["id"]) assert ok is True routes = await db.list_model_routes() assert len(routes) == 0 class TestTaskQueue: """任务队列测试""" @pytest.mark.asyncio async def test_create_task(self, db): """测试创建任务""" task = await db.create_task( task_type="image", request_payload={"prompt": "A cat"}, priority=1, ) assert "id" in task assert task["status"] == "pending" @pytest.mark.asyncio async def test_list_tasks(self, db): """测试列出任务""" await db.create_task(task_type="image", request_payload={}) t2 = await db.create_task(task_type="voice", request_payload={}) await db.update_task(t2["id"], status="running") pending = await db.list_tasks(status="pending") assert len(pending) == 1 all_tasks = await db.list_tasks() assert len(all_tasks) == 2 @pytest.mark.asyncio async def test_update_task(self, db): """测试更新任务""" task = await db.create_task(task_type="image", request_payload={}) ok = await db.update_task(task["id"], status="completed") assert ok is True updated = await db.get_task(task["id"]) assert updated["status"] == "completed" @pytest.mark.asyncio async def test_get_task(self, db): """测试获取任务""" task = await db.create_task( task_type="image", request_payload={"prompt": "test"}, ) fetched = await db.get_task(task["id"]) assert fetched is not None assert fetched["task_type"] == "image" assert fetched["request_payload"] == {"prompt": "test"}