# 提交说明 (commit notes):
# - 添加 pytest 配置和测试依赖到 requirements.txt
# - 创建测试包结构和 fixtures (conftest.py)
# - 添加数据库模块的 CRUD 操作测试 (test_database.py)
# - 添加 Provider 插件系统测试 (test_providers.py)
# - 添加调度器模块测试 (test_scheduler.py)
# - 添加 API 路由测试 (test_api.py)
# - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
# - 添加健康检查端点用于容器监控
# - 修复调度器中的日历计算逻辑和任务执行参数处理
# - 更新数据库函数以返回操作结果状态
# (source view: 299 lines, 11 KiB, Python)
"""Tests for the ``app.database`` module.

Covers CRUD for plans, quota rules, model routes and tasks, plus the
plan-availability check.

NOTE(review): every test method takes a ``db`` fixture argument (defined
outside this file, presumably in conftest.py). That argument shadows the
``db`` module alias imported below, so the test bodies call whatever the
fixture provides. Presumably the fixture yields this same module after
pointing it at a test database; confirm in conftest.py.
"""

import pytest

from app import database as db
class TestPlanCRUD:
    """Create / read / list / update / delete tests for plans."""

    @pytest.mark.asyncio
    async def test_create_plan(self, db):
        """create_plan returns the new record, including an ``id`` and the given fields."""
        fields = dict(
            name="Test Plan",
            provider_name="openai",
            api_key="sk-test",
            api_base="https://api.openai.com/v1",
            plan_type="coding",
            supported_models=["gpt-4"],
        )
        new_plan = await db.create_plan(**fields)

        assert "id" in new_plan
        assert new_plan["name"] == fields["name"]
        assert new_plan["provider_name"] == fields["provider_name"]

    @pytest.mark.asyncio
    async def test_get_plan(self, db):
        """get_plan fetches a stored plan by id; a plan created without ``enabled`` is enabled."""
        created = await db.create_plan(
            name="Get Test", provider_name="openai", api_key="sk-test"
        )
        plan_id = created["id"]

        fetched = await db.get_plan(plan_id)
        assert fetched is not None
        assert fetched["name"] == "Get Test"
        assert fetched["enabled"] is True

    @pytest.mark.asyncio
    async def test_get_plan_not_found(self, db):
        """get_plan returns None for an id that was never stored."""
        assert await db.get_plan("nonexistent") is None

    @pytest.mark.asyncio
    async def test_list_plans(self, db):
        """list_plans returns every plan, or only enabled ones when asked."""
        seed = [
            {"name": "Plan 1", "provider_name": "openai", "api_key": "sk1"},
            {"name": "Plan 2", "provider_name": "kimi", "api_key": "sk2"},
            {"name": "Plan 3", "provider_name": "google", "api_key": "sk3", "enabled": False},
        ]
        for kwargs in seed:
            await db.create_plan(**kwargs)

        assert len(await db.list_plans()) == 3
        assert len(await db.list_plans(enabled_only=True)) == 2

    @pytest.mark.asyncio
    async def test_update_plan(self, db):
        """update_plan reports success and the changed fields are persisted."""
        created = await db.create_plan(
            name="Old Name", provider_name="openai", api_key="sk-test"
        )
        plan_id = created["id"]

        assert await db.update_plan(plan_id, name="New Name", plan_type="token") is True

        refreshed = await db.get_plan(plan_id)
        assert refreshed["name"] == "New Name"
        assert refreshed["plan_type"] == "token"

    @pytest.mark.asyncio
    async def test_delete_plan(self, db):
        """delete_plan reports success and the plan can no longer be fetched."""
        created = await db.create_plan(
            name="To Delete", provider_name="openai", api_key="sk-test"
        )
        plan_id = created["id"]

        assert await db.delete_plan(plan_id) is True
        assert await db.get_plan(plan_id) is None
class TestQuotaRuleCRUD:
    """CRUD tests for per-plan quota rules."""

    @pytest.mark.asyncio
    async def test_create_quota_rule(self, db):
        """create_quota_rule returns the new rule (daily calendar-cycle here) with an id."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily Limit",
            quota_total=100,
            quota_unit="requests",
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )
        assert "id" in rule
        assert rule["rule_name"] == "Daily Limit"

    @pytest.mark.asyncio
    async def test_list_quota_rules(self, db):
        """list_quota_rules returns every rule attached to a plan."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 1", quota_total=100)
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 2", quota_total=200)

        rules = await db.list_quota_rules(plan["id"])
        assert len(rules) == 2

    @pytest.mark.asyncio
    async def test_update_quota_rule(self, db):
        """update_quota_rule reports success and the new quota_total is persisted."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        ok = await db.update_quota_rule(rule["id"], quota_total=200)
        assert ok is True

        # The plan has exactly one rule, so indexing [0] is unambiguous here.
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_total"] == 200

    @pytest.mark.asyncio
    async def test_delete_quota_rule(self, db):
        """delete_quota_rule reports success and the rule disappears from the plan."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        ok = await db.delete_quota_rule(rule["id"])
        assert ok is True

        rules = await db.list_quota_rules(plan["id"])
        assert len(rules) == 0

    @pytest.mark.asyncio
    async def test_increment_quota_used(self, db):
        """increment_quota_used adds 1 to request-unit rules and token_count to token-unit rules."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Rule", quota_total=100, quota_unit="requests"
        )
        await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Token Rule", quota_total=1000, quota_unit="tokens"
        )

        await db.increment_quota_used(plan["id"], token_count=50)

        # Look the rules up by name rather than list position: nothing here
        # guarantees list_quota_rules returns rules in insertion order.
        rules = await db.list_quota_rules(plan["id"])
        used_by_name = {r["rule_name"]: r["quota_used"] for r in rules}
        # "requests" unit: +1 per call
        assert used_by_name["Rule"] == 1
        # "tokens" unit: +token_count (50)
        assert used_by_name["Token Rule"] == 50
class TestCheckPlanAvailable:
    """Tests for check_plan_available across different quota-rule states."""

    @pytest.mark.asyncio
    async def test_all_rules_available(self, db):
        """A plan whose only rule still has unused quota is available."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        available = await db.check_plan_available(plan["id"])
        assert available is True

    @pytest.mark.asyncio
    async def test_one_rule_exhausted(self, db):
        """A plan is unavailable once one of its rules has used its whole quota."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        await db.update_quota_rule(rule["id"], quota_used=100)

        available = await db.check_plan_available(plan["id"])
        assert available is False

    @pytest.mark.asyncio
    async def test_no_enabled_rules(self, db):
        """A plan whose only quota rule is disabled is reported unavailable."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Rule", quota_total=100, enabled=False
        )

        available = await db.check_plan_available(plan["id"])
        # NOTE(review): the original author noted that the implementation
        # returned True in this case and needed fixing. If check_plan_available
        # has not been fixed yet, this assertion fails.
        assert available is False
class TestModelRoute:
    """Tests for model-name -> plan routing: set, resolve (with fallback) and delete."""

    @pytest.mark.asyncio
    async def test_set_model_route(self, db):
        """set_model_route returns the new route with an id and the model name."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        route = await db.set_model_route("gpt-4", plan["id"], priority=10)

        assert "id" in route
        assert route["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_set_model_route_update_existing(self, db):
        """Setting a route again for the same (model, plan) pair must not fail."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.set_model_route("gpt-4", plan["id"], priority=10)
        route = await db.set_model_route("gpt-4", plan["id"], priority=20)

        # Per the original author's note: an earlier implementation created a new
        # id here, and the current one updates the existing route but still
        # returns a new id. So this test only checks that the second call succeeds.
        # TODO(review): once the id semantics are settled, also assert that
        # list_model_routes() holds a single route for gpt-4 with priority 20.
        assert route["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_resolve_model(self, db):
        """resolve_model picks the plan whose route has the higher priority value."""
        plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1")
        plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2")

        await db.set_model_route("gpt-4", plan1["id"], priority=10)
        await db.set_model_route("gpt-4", plan2["id"], priority=5)  # lower priority

        # The higher-priority plan1 should win.
        resolved = await db.resolve_model("gpt-4")
        assert resolved == plan1["id"]

    @pytest.mark.asyncio
    async def test_resolve_model_with_fallback(self, db):
        """resolve_model falls back to the next route when the preferred plan is exhausted."""
        plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1")
        plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2")

        await db.set_model_route("gpt-4", plan1["id"], priority=10)
        await db.set_model_route("gpt-4", plan2["id"], priority=5)

        # Exhaust plan1's only quota rule.
        rule = await db.create_quota_rule(plan_id=plan1["id"], rule_name="Rule", quota_total=100)
        await db.update_quota_rule(rule["id"], quota_used=100)

        # Resolution should fall back to plan2.
        resolved = await db.resolve_model("gpt-4")
        assert resolved == plan2["id"]

    @pytest.mark.asyncio
    async def test_delete_model_route(self, db):
        """delete_model_route reports success and the route list becomes empty."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        route = await db.set_model_route("gpt-4", plan["id"], priority=10)

        ok = await db.delete_model_route(route["id"])
        assert ok is True

        routes = await db.list_model_routes()
        assert len(routes) == 0
class TestTaskQueue:
    """Tests for the task-queue helpers: create, list, update and fetch tasks."""

    @pytest.mark.asyncio
    async def test_create_task(self, db):
        """A newly created task carries an id and starts in the ``pending`` status."""
        new_task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "A cat"},
            priority=1,
        )
        assert "id" in new_task
        assert new_task["status"] == "pending"

    @pytest.mark.asyncio
    async def test_list_tasks(self, db):
        """list_tasks filters by status when given one, otherwise returns every task."""
        await db.create_task(task_type="image", request_payload={})
        voice_task = await db.create_task(task_type="voice", request_payload={})
        await db.update_task(voice_task["id"], status="running")

        pending_count = len(await db.list_tasks(status="pending"))
        assert pending_count == 1

        total_count = len(await db.list_tasks())
        assert total_count == 2

    @pytest.mark.asyncio
    async def test_update_task(self, db):
        """update_task reports success and the new status is persisted."""
        created = await db.create_task(task_type="image", request_payload={})
        task_id = created["id"]

        assert await db.update_task(task_id, status="completed") is True

        reloaded = await db.get_task(task_id)
        assert reloaded["status"] == "completed"

    @pytest.mark.asyncio
    async def test_get_task(self, db):
        """get_task returns the stored task with its request payload round-tripped."""
        created = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
        )
        fetched = await db.get_task(created["id"])

        assert fetched is not None
        assert fetched["task_type"] == "image"
        assert fetched["request_payload"] == {"prompt": "test"}