test: 添加测试框架和全面的单元测试

- 添加 pytest 配置和测试依赖到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数以返回操作结果状态
This commit is contained in:
congsh
2026-03-31 22:36:18 +08:00
parent 61ce809634
commit 37d282c0a2
17 changed files with 1769 additions and 50 deletions

View File

@@ -27,6 +27,12 @@ class StorageConfig(BaseModel):
path: str = "./data/files"
class SchedulerConfig(BaseModel):
"""调度器配置"""
task_processing_limit: int = 5 # 每次处理的最大任务数
loop_interval_seconds: int = 30 # 主循环间隔(秒)
class QuotaRuleSeed(BaseModel):
"""config.yaml 中单条 QuotaRule 种子"""
rule_name: str
@@ -55,6 +61,7 @@ class AppConfig(BaseModel):
server: ServerConfig = Field(default_factory=ServerConfig)
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
storage: StorageConfig = Field(default_factory=StorageConfig)
scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
plans: list[PlanSeed] = Field(default_factory=list)

View File

@@ -306,6 +306,13 @@ async def update_quota_rule(rule_id: str, **fields) -> bool:
return True
async def delete_quota_rule(rule_id: str) -> bool:
db = await get_db()
cur = await db.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
await db.commit()
return cur.rowcount > 0
async def increment_quota_used(plan_id: str, token_count: int = 0):
"""请求完成后增加该 Plan 所有 Rule 的 quota_used"""
db = await get_db()
@@ -441,6 +448,18 @@ async def update_task(task_id: str, **fields) -> bool:
return True
async def get_task(task_id: str) -> dict | None:
db = await get_db()
cur = await db.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
row = await cur.fetchone()
if not row:
return None
t = row_to_dict(row)
t["request_payload"] = _parse_json(t["request_payload"], {})
t["response_payload"] = _parse_json(t.get("response_payload"))
return t
# ── 种子数据导入 ──────────────────────────────────────
async def seed_from_config():

View File

@@ -50,3 +50,9 @@ _static_dir = Path(__file__).parent / "static"
@app.get("/", include_in_schema=False)
async def serve_index():
return FileResponse(_static_dir / "index.html")
@app.get("/health", tags=["Health"])
async def health_check():
"""健康检查端点,用于容器健康检查"""
return {"status": "healthy", "service": "plan-manager"}

View File

@@ -105,10 +105,8 @@ async def update_rule(rule_id: str, body: QuotaRuleUpdate):
@router.delete("/rules/{rule_id}")
async def delete_rule(rule_id: str):
d = await db.get_db()
cur = await d.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
await d.commit()
if cur.rowcount == 0:
ok = await db.delete_quota_rule(rule_id)
if not ok:
raise HTTPException(404, "Rule not found")
return {"ok": True}

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import json
import time
from typing import Any
from fastapi import APIRouter, Header, HTTPException, Request
@@ -18,8 +17,8 @@ router = APIRouter()
def _verify_key(authorization: str | None):
expected = settings.server.proxy_api_key
if not expected or expected == "sk-plan-manage-change-me":
return # 未配置则跳过鉴权
if not expected:
return
if not authorization:
raise HTTPException(401, "Missing Authorization header")
token = authorization.removeprefix("Bearer ").strip()
@@ -150,13 +149,17 @@ async def anthropic_messages(
if not provider:
raise HTTPException(500, f"Provider '{plan['provider_name']}' not registered")
extra_kwargs = {k: v for k, v in body.items() if k not in ("model", "messages", "stream", "system")}
if stream:
async def anthropic_stream():
"""将 OpenAI SSE 格式转换为 Anthropic SSE 格式"""
yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': 'msg_proxy', 'type': 'message', 'role': 'assistant', 'model': model, 'content': []}})}\n\n"
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
async for chunk_data in _stream_and_count(provider, oai_messages, model, plan, True):
async for chunk_data in _stream_and_count(
provider, oai_messages, model, plan, True, **extra_kwargs
):
if chunk_data.startswith("data: [DONE]"):
break
if chunk_data.startswith("data: "):
@@ -179,7 +182,9 @@ async def anthropic_messages(
)
else:
chunks = []
async for c in _stream_and_count(provider, oai_messages, model, plan, False):
async for c in _stream_and_count(
provider, oai_messages, model, plan, False, **extra_kwargs
):
chunks.append(c)
oai_resp = json.loads(chunks[0]) if chunks else {}
# OpenAI 响应 -> Anthropic 响应

View File

@@ -27,14 +27,9 @@ async def create_task(body: TaskCreate):
@router.get("/{task_id}", response_model=TaskOut)
async def get_task(task_id: str):
d = await db.get_db()
cur = await d.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
row = await cur.fetchone()
if not row:
t = await db.get_task(task_id)
if not t:
raise HTTPException(404, "Task not found")
t = db.row_to_dict(row)
t["request_payload"] = db._parse_json(t["request_payload"], {})
t["response_payload"] = db._parse_json(t.get("response_payload"))
return t

View File

@@ -3,8 +3,8 @@
from __future__ import annotations
import asyncio
import json
import logging
from calendar import monthrange
from datetime import datetime, timedelta, timezone
from app import database as db
@@ -55,28 +55,24 @@ def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) ->
return candidate
if calendar_unit == "monthly":
day = anchor.get("day", 1)
day = int(anchor.get("day", 1))
if day < 1:
day = 1
if day > 31:
day = 31
year, month = after.year, after.month
try:
candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0)
except ValueError:
# 日期不存在时(如 2 月 30 号),跳到下月
month += 1
if month > 12:
month, year = 1, year + 1
candidate = after.replace(year=year, month=month, day=day,
hour=hour, minute=0, second=0, microsecond=0)
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = after.replace(
day=candidate_day, hour=hour, minute=0, second=0, microsecond=0
)
if candidate <= after:
month += 1
if month > 12:
month, year = 1, year + 1
try:
candidate = candidate.replace(year=year, month=month)
except ValueError:
month += 1
if month > 12:
month, year = 1, year + 1
candidate = candidate.replace(year=year, month=month, day=1)
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = candidate.replace(year=year, month=month, day=candidate_day)
return candidate
return after + timedelta(days=1)
@@ -131,9 +127,9 @@ async def _refresh_quota_rules():
logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)
elif rt == "api_sync":
# 每 10 分钟同步一次
last_at = _parse_dt(rule.get("last_refresh_at"))
if last_at and (now - last_at).total_seconds() < 600:
sync_interval = rule.get("sync_interval_seconds") or 600 # 默认 10 分钟
if last_at and (now - last_at).total_seconds() < sync_interval:
continue
plan = await db.get_plan(rule["plan_id"])
if not plan:
@@ -141,19 +137,24 @@ async def _refresh_quota_rules():
from app.providers import ProviderRegistry
provider = ProviderRegistry.get(plan["provider_name"])
if provider:
info = await provider.query_quota(plan)
if info:
await db.update_quota_rule(
rule["id"],
quota_used=info.quota_used,
last_refresh_at=now.isoformat(),
)
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
try:
info = await provider.query_quota(plan)
if info:
await db.update_quota_rule(
rule["id"],
quota_used=info.quota_used,
last_refresh_at=now.isoformat(),
)
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
except Exception as e:
logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)
async def _process_task_queue():
"""消费待处理任务"""
tasks = await db.list_tasks(status="pending", limit=5)
from app.config import settings
limit = settings.scheduler.task_processing_limit
tasks = await db.list_tasks(status="pending", limit=limit)
for task in tasks:
plan_id = task.get("plan_id")
if plan_id and not await db.check_plan_available(plan_id):
@@ -204,10 +205,13 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
payload = task.get("request_payload", {})
if tt == "image":
return await provider.generate_image(payload.get("prompt", ""), plan, **payload)
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_image(prompt, plan, **kwargs)
elif tt == "voice":
audio = await provider.generate_voice(payload.get("text", ""), plan, **payload)
# 保存到文件
kwargs = dict(payload)
text = kwargs.pop("text", "")
audio = await provider.generate_voice(text, plan, **kwargs)
from pathlib import Path
from app.config import settings
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
@@ -215,7 +219,9 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
return {"file": str(fpath)}
elif tt == "video":
return await provider.generate_video(payload.get("prompt", ""), plan, **payload)
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_video(prompt, plan, **kwargs)
else:
return {"error": f"Unknown task type: {tt}"}