test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
from app import database as db
|
||||
from app.config import settings
|
||||
from app.routers import proxy
|
||||
from app.services import scheduler
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db(tmp_path):
|
||||
old_path = settings.database.path
|
||||
settings.database.path = str(tmp_path / "test.db")
|
||||
asyncio.run(db.close_db())
|
||||
yield
|
||||
asyncio.run(db.close_db())
|
||||
settings.database.path = old_path
|
||||
|
||||
|
||||
def test_update_functions_report_missing_rows():
|
||||
# update_plan 现在返回 cur.rowcount > 0,不存在的行返回 False
|
||||
result = asyncio.run(db.update_plan("missing", name="x"))
|
||||
assert result is False
|
||||
assert asyncio.run(db.update_quota_rule("missing", rule_name="x")) is False
|
||||
assert asyncio.run(db.update_task("missing", status="cancelled")) is False
|
||||
|
||||
|
||||
def test_update_functions_report_success_for_existing_rows():
|
||||
plan = asyncio.run(db.create_plan(name="a", provider_name="openai"))
|
||||
rule = asyncio.run(
|
||||
db.create_quota_rule(plan_id=plan["id"], rule_name="r", quota_total=10)
|
||||
)
|
||||
task = asyncio.run(db.create_task(task_type="image", request_payload={"prompt": "p"}))
|
||||
assert asyncio.run(db.update_plan(plan["id"], name="b")) is True
|
||||
assert asyncio.run(db.update_quota_rule(rule["id"], quota_total=20)) is True
|
||||
assert asyncio.run(db.update_task(task["id"], status="running")) is True
|
||||
|
||||
|
||||
def test_execute_task_avoids_duplicate_prompt_argument():
|
||||
class Provider:
|
||||
async def generate_image(self, prompt: str, plan: dict, **kwargs):
|
||||
return {"prompt": prompt, "kwargs": kwargs}
|
||||
|
||||
task = {
|
||||
"id": "t1",
|
||||
"task_type": "image",
|
||||
"request_payload": {"prompt": "hello", "model": "m1"},
|
||||
}
|
||||
result = asyncio.run(scheduler._execute_task(Provider(), {}, task))
|
||||
assert result["prompt"] == "hello"
|
||||
assert result["kwargs"] == {"model": "m1"}
|
||||
|
||||
|
||||
def test_compute_next_calendar_monthly_handles_large_day_anchor():
|
||||
after = datetime(2026, 2, 15, 12, 0, tzinfo=timezone.utc)
|
||||
next_at = scheduler._compute_next_calendar("monthly", {"day": 31, "hour": 0}, after)
|
||||
assert next_at == datetime(2026, 2, 28, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
after2 = datetime(2026, 2, 28, 1, 0, tzinfo=timezone.utc)
|
||||
next_at2 = scheduler._compute_next_calendar("monthly", {"day": 31, "hour": 0}, after2)
|
||||
assert next_at2 == datetime(2026, 3, 31, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_refresh_quota_rules_isolates_api_sync_errors(monkeypatch):
|
||||
fixed_now = datetime(2026, 3, 1, 12, 0, tzinfo=timezone.utc)
|
||||
updates = []
|
||||
rules = [
|
||||
{
|
||||
"id": "r1",
|
||||
"rule_name": "api",
|
||||
"refresh_type": "api_sync",
|
||||
"plan_id": "p1",
|
||||
"last_refresh_at": None,
|
||||
"next_refresh_at": None,
|
||||
},
|
||||
{
|
||||
"id": "r2",
|
||||
"rule_name": "fixed",
|
||||
"refresh_type": "fixed_interval",
|
||||
"interval_hours": 1,
|
||||
"last_refresh_at": (fixed_now - timedelta(hours=2)).isoformat(),
|
||||
"next_refresh_at": (fixed_now - timedelta(minutes=1)).isoformat(),
|
||||
},
|
||||
]
|
||||
|
||||
async def fake_get_all_quota_rules():
|
||||
return rules
|
||||
|
||||
async def fake_get_plan(plan_id: str):
|
||||
return {"id": plan_id, "provider_name": "broken"}
|
||||
|
||||
async def fake_update_quota_rule(rule_id: str, **fields):
|
||||
updates.append((rule_id, fields))
|
||||
return True
|
||||
|
||||
class BadProvider:
|
||||
async def query_quota(self, plan: dict):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
import app.providers as providers
|
||||
|
||||
monkeypatch.setattr(scheduler, "_now", lambda: fixed_now)
|
||||
monkeypatch.setattr(scheduler.db, "get_all_quota_rules", fake_get_all_quota_rules)
|
||||
monkeypatch.setattr(scheduler.db, "get_plan", fake_get_plan)
|
||||
monkeypatch.setattr(scheduler.db, "update_quota_rule", fake_update_quota_rule)
|
||||
monkeypatch.setattr(
|
||||
providers.ProviderRegistry,
|
||||
"get",
|
||||
classmethod(lambda cls, name: BadProvider()),
|
||||
)
|
||||
|
||||
asyncio.run(scheduler._refresh_quota_rules())
|
||||
assert any(rule_id == "r2" and fields.get("quota_used") == 0 for rule_id, fields in updates)
|
||||
|
||||
|
||||
def test_anthropic_route_forwards_extra_kwargs(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class Provider:
|
||||
async def chat(self, messages, model, plan, stream=True, **kwargs):
|
||||
captured["kwargs"] = kwargs
|
||||
payload = {
|
||||
"choices": [{"message": {"content": "ok"}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
|
||||
}
|
||||
yield json.dumps(payload)
|
||||
|
||||
async def fake_resolve_model(model: str):
|
||||
return "p1"
|
||||
|
||||
async def fake_get_plan(plan_id: str):
|
||||
return {"id": plan_id, "name": "p", "provider_name": "x"}
|
||||
|
||||
async def fake_check_plan_available(plan_id: str):
|
||||
return True
|
||||
|
||||
async def fake_increment_quota_used(plan_id: str, token_count: int = 0):
|
||||
return None
|
||||
|
||||
provider = Provider()
|
||||
monkeypatch.setattr(proxy.db, "resolve_model", fake_resolve_model)
|
||||
monkeypatch.setattr(proxy.db, "get_plan", fake_get_plan)
|
||||
monkeypatch.setattr(proxy.db, "check_plan_available", fake_check_plan_available)
|
||||
monkeypatch.setattr(proxy.db, "increment_quota_used", fake_increment_quota_used)
|
||||
monkeypatch.setattr(
|
||||
proxy.ProviderRegistry,
|
||||
"get",
|
||||
classmethod(lambda cls, name: provider),
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(proxy.router)
|
||||
old_key = settings.server.proxy_api_key
|
||||
settings.server.proxy_api_key = ""
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post(
|
||||
"/v1/messages",
|
||||
json={
|
||||
"model": "m1",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"temperature": 0.2,
|
||||
"max_tokens": 88,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
settings.server.proxy_api_key = old_key
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert captured["kwargs"] == {"temperature": 0.2, "max_tokens": 88}
|
||||
|
||||
|
||||
def test_verify_key_requires_auth_when_using_default_value():
|
||||
old_key = settings.server.proxy_api_key
|
||||
settings.server.proxy_api_key = "sk-plan-manage-change-me"
|
||||
try:
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
proxy._verify_key(None)
|
||||
finally:
|
||||
settings.server.proxy_api_key = old_key
|
||||
assert exc.value.status_code == 401
|
||||
Reference in New Issue
Block a user