299 lines
11 KiB
Python
299 lines
11 KiB
Python
|
|
"""数据库模块测试"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from app import database as db
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestPlanCRUD:
    """Tests for the Plan create / read / update / delete operations.

    Every test receives a ``db`` argument (presumably a pytest fixture
    defined outside this file -- TODO confirm) and awaits its coroutines.
    """

    @pytest.mark.asyncio
    async def test_create_plan(self, db):
        """Creating a plan returns a record with an id and the given fields."""
        created = await db.create_plan(
            name="Test Plan",
            provider_name="openai",
            api_key="sk-test",
            api_base="https://api.openai.com/v1",
            plan_type="coding",
            supported_models=["gpt-4"],
        )

        assert "id" in created
        assert created["name"] == "Test Plan"
        assert created["provider_name"] == "openai"

    @pytest.mark.asyncio
    async def test_get_plan(self, db):
        """A stored plan can be fetched by id and is reported as enabled."""
        seed = await db.create_plan(
            name="Get Test",
            provider_name="openai",
            api_key="sk-test",
        )

        fetched = await db.get_plan(seed["id"])

        assert fetched is not None
        assert fetched["name"] == "Get Test"
        assert fetched["enabled"] is True

    @pytest.mark.asyncio
    async def test_get_plan_not_found(self, db):
        """Looking up an unknown plan id yields ``None``."""
        missing = await db.get_plan("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_list_plans(self, db):
        """Listing returns every plan, or only enabled ones when requested."""
        plan_specs = [
            {"name": "Plan 1", "provider_name": "openai", "api_key": "sk1"},
            {"name": "Plan 2", "provider_name": "kimi", "api_key": "sk2"},
            {"name": "Plan 3", "provider_name": "google", "api_key": "sk3", "enabled": False},
        ]
        for spec in plan_specs:
            await db.create_plan(**spec)

        everything = await db.list_plans()
        assert len(everything) == 3

        # Plan 3 was created disabled, so it must be filtered out here.
        active = await db.list_plans(enabled_only=True)
        assert len(active) == 2

    @pytest.mark.asyncio
    async def test_update_plan(self, db):
        """Updating a plan reports success and the new values are readable."""
        original = await db.create_plan(name="Old Name", provider_name="openai", api_key="sk-test")

        succeeded = await db.update_plan(original["id"], name="New Name", plan_type="token")
        assert succeeded is True

        refreshed = await db.get_plan(original["id"])
        assert refreshed["name"] == "New Name"
        assert refreshed["plan_type"] == "token"

    @pytest.mark.asyncio
    async def test_delete_plan(self, db):
        """Deleting a plan reports success and the plan can no longer be fetched."""
        doomed = await db.create_plan(name="To Delete", provider_name="openai", api_key="sk-test")

        succeeded = await db.delete_plan(doomed["id"])
        assert succeeded is True

        after = await db.get_plan(doomed["id"])
        assert after is None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestQuotaRuleCRUD:
    """Tests for the QuotaRule create / list / update / delete operations.

    Every test receives a ``db`` argument (presumably a pytest fixture
    defined outside this file -- TODO confirm) and first creates a plan to
    attach the quota rules to.
    """

    @pytest.mark.asyncio
    async def test_create_quota_rule(self, db):
        """Creating a quota rule returns a record with an id and its name."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily Limit",
            quota_total=100,
            quota_unit="requests",
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )
        assert "id" in rule
        assert rule["rule_name"] == "Daily Limit"

    @pytest.mark.asyncio
    async def test_list_quota_rules(self, db):
        """Listing by plan id returns every rule created for that plan."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 1", quota_total=100)
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule 2", quota_total=200)

        rules = await db.list_quota_rules(plan["id"])
        assert len(rules) == 2

    @pytest.mark.asyncio
    async def test_update_quota_rule(self, db):
        """Updating ``quota_total`` reports success and is visible when listing."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        ok = await db.update_quota_rule(rule["id"], quota_total=200)
        assert ok is True

        # Only one rule exists for this plan, so index 0 is unambiguous.
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_total"] == 200

    @pytest.mark.asyncio
    async def test_delete_quota_rule(self, db):
        """Deleting a rule reports success and leaves the plan with no rules."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        ok = await db.delete_quota_rule(rule["id"])
        assert ok is True

        rules = await db.list_quota_rules(plan["id"])
        assert len(rules) == 0

    @pytest.mark.asyncio
    async def test_increment_quota_used(self, db):
        """Incrementing usage adds 1 to request rules and the token count to token rules."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100, quota_unit="requests")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Token Rule", quota_total=1000, quota_unit="tokens")

        await db.increment_quota_used(plan["id"], token_count=50)

        rules = await db.list_quota_rules(plan["id"])
        assert len(rules) == 2

        # Identify each rule by its quota_total (100 = requests rule,
        # 1000 = tokens rule) instead of by list position, so the test does
        # not depend on the order in which list_quota_rules returns rows.
        used_by_total = {rule["quota_total"]: rule["quota_used"] for rule in rules}

        # The "requests" rule is expected to grow by one per call.
        assert used_by_total[100] == 1
        # The "tokens" rule is expected to grow by token_count (50).
        assert used_by_total[1000] == 50
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestCheckPlanAvailable:
    """Tests for the plan availability check (``check_plan_available``).

    Every test receives a ``db`` argument (presumably a pytest fixture
    defined outside this file -- TODO confirm).
    """

    @pytest.mark.asyncio
    async def test_all_rules_available(self, db):
        """A plan whose only rule has unused quota is reported as available."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        # The created rule is not needed afterwards, so its return value is not bound.
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        available = await db.check_plan_available(plan["id"])
        assert available is True

    @pytest.mark.asyncio
    async def test_one_rule_exhausted(self, db):
        """A plan with a fully consumed rule is reported as unavailable."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        # Consume the whole quota: quota_used == quota_total.
        await db.update_quota_rule(rule["id"], quota_used=100)

        available = await db.check_plan_available(plan["id"])
        assert available is False

    @pytest.mark.asyncio
    async def test_no_enabled_rules(self, db):
        """A plan whose only rule is disabled is expected to be unavailable."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        # The created rule is not needed afterwards, so its return value is not bound.
        await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Rule", quota_total=100, enabled=False
        )

        available = await db.check_plan_available(plan["id"])
        # NOTE: per the original author, the current implementation returns
        # True here; that is the behavior that needs to be fixed.
        assert available is False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestModelRoute:
    """Tests for model routing: setting, resolving and deleting routes.

    Every test receives a ``db`` argument (presumably a pytest fixture
    defined outside this file -- TODO confirm).
    """

    @pytest.mark.asyncio
    async def test_set_model_route(self, db):
        """Setting a route returns a record with an id and the model name."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        route = await db.set_model_route("gpt-4", plan["id"], priority=10)

        assert "id" in route
        assert route["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_set_model_route_update_existing(self, db):
        """Setting the same model/plan route twice must not raise."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        # The first call only seeds the route; its return value is not inspected.
        await db.set_model_route("gpt-4", plan["id"], priority=10)
        route2 = await db.set_model_route("gpt-4", plan["id"], priority=20)

        # Per the original author: the earlier implementation created a new
        # ID; the current one updates the route but still returns a new ID.
        # This test therefore only verifies that no error is raised.
        assert route2["model_name"] == "gpt-4"

    @pytest.mark.asyncio
    async def test_resolve_model(self, db):
        """Resolving a model returns the plan with the higher priority value."""
        plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1")
        plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2")

        await db.set_model_route("gpt-4", plan1["id"], priority=10)
        await db.set_model_route("gpt-4", plan2["id"], priority=5)  # lower priority

        # The higher-priority plan1 should be returned.
        resolved = await db.resolve_model("gpt-4")
        assert resolved == plan1["id"]

    @pytest.mark.asyncio
    async def test_resolve_model_with_fallback(self, db):
        """When the preferred plan's quota is exhausted, the next plan is used."""
        plan1 = await db.create_plan(name="Plan 1", provider_name="openai", api_key="sk1")
        plan2 = await db.create_plan(name="Plan 2", provider_name="kimi", api_key="sk2")

        await db.set_model_route("gpt-4", plan1["id"], priority=10)
        await db.set_model_route("gpt-4", plan2["id"], priority=5)

        # Exhaust plan1's quota: quota_used == quota_total.
        rule = await db.create_quota_rule(plan_id=plan1["id"], rule_name="Rule", quota_total=100)
        await db.update_quota_rule(rule["id"], quota_used=100)

        # Resolution should fall back to plan2.
        resolved = await db.resolve_model("gpt-4")
        assert resolved == plan2["id"]

    @pytest.mark.asyncio
    async def test_delete_model_route(self, db):
        """Deleting a route reports success and leaves no routes listed."""
        plan = await db.create_plan(name="Test Plan", provider_name="openai", api_key="sk-test")
        route = await db.set_model_route("gpt-4", plan["id"], priority=10)

        ok = await db.delete_model_route(route["id"])
        assert ok is True

        routes = await db.list_model_routes()
        assert len(routes) == 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestTaskQueue:
    """Tests for the task queue: creating, listing, updating and fetching tasks.

    Every test receives a ``db`` argument (presumably a pytest fixture
    defined outside this file -- TODO confirm).
    """

    @pytest.mark.asyncio
    async def test_create_task(self, db):
        """A newly created task has an id and starts in the "pending" status."""
        created = await db.create_task(
            task_type="image",
            request_payload={"prompt": "A cat"},
            priority=1,
        )

        assert "id" in created
        assert created["status"] == "pending"

    @pytest.mark.asyncio
    async def test_list_tasks(self, db):
        """Listing can be filtered by status or return every task."""
        await db.create_task(task_type="image", request_payload={})
        voice_task = await db.create_task(task_type="voice", request_payload={})
        # Move the second task out of "pending" so the filter has something to exclude.
        await db.update_task(voice_task["id"], status="running")

        still_pending = await db.list_tasks(status="pending")
        assert len(still_pending) == 1

        everything = await db.list_tasks()
        assert len(everything) == 2

    @pytest.mark.asyncio
    async def test_update_task(self, db):
        """Updating a task's status reports success and the change is readable."""
        created = await db.create_task(task_type="image", request_payload={})

        succeeded = await db.update_task(created["id"], status="completed")
        assert succeeded is True

        refreshed = await db.get_task(created["id"])
        assert refreshed["status"] == "completed"

    @pytest.mark.asyncio
    async def test_get_task(self, db):
        """A fetched task carries its type and the original request payload."""
        created = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
        )

        loaded = await db.get_task(created["id"])

        assert loaded is not None
        assert loaded["task_type"] == "image"
        assert loaded["request_payload"] == {"prompt": "test"}
|