test: 添加测试框架和全面的单元测试

- 添加 pytest 配置和测试依赖到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数以返回操作结果状态
This commit is contained in:
congsh
2026-03-31 22:36:18 +08:00
parent 61ce809634
commit 37d282c0a2
17 changed files with 1769 additions and 50 deletions
+39 -33
View File
@@ -3,8 +3,8 @@
from __future__ import annotations
import asyncio
import json
import logging
from calendar import monthrange
from datetime import datetime, timedelta, timezone
from app import database as db
@@ -55,28 +55,24 @@ def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) ->
return candidate
if calendar_unit == "monthly":
day = anchor.get("day", 1)
day = int(anchor.get("day", 1))
if day < 1:
day = 1
if day > 31:
day = 31
year, month = after.year, after.month
try:
candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0)
except ValueError:
# 日期不存在时(如 2 月 30 号),跳到下月
month += 1
if month > 12:
month, year = 1, year + 1
candidate = after.replace(year=year, month=month, day=day,
hour=hour, minute=0, second=0, microsecond=0)
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = after.replace(
day=candidate_day, hour=hour, minute=0, second=0, microsecond=0
)
if candidate <= after:
month += 1
if month > 12:
month, year = 1, year + 1
try:
candidate = candidate.replace(year=year, month=month)
except ValueError:
month += 1
if month > 12:
month, year = 1, year + 1
candidate = candidate.replace(year=year, month=month, day=1)
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = candidate.replace(year=year, month=month, day=candidate_day)
return candidate
return after + timedelta(days=1)
@@ -131,9 +127,9 @@ async def _refresh_quota_rules():
logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)
elif rt == "api_sync":
# 每 10 分钟同步一次
last_at = _parse_dt(rule.get("last_refresh_at"))
if last_at and (now - last_at).total_seconds() < 600:
sync_interval = rule.get("sync_interval_seconds") or 600 # 默认 10 分钟
if last_at and (now - last_at).total_seconds() < sync_interval:
continue
plan = await db.get_plan(rule["plan_id"])
if not plan:
@@ -141,19 +137,24 @@ async def _refresh_quota_rules():
from app.providers import ProviderRegistry
provider = ProviderRegistry.get(plan["provider_name"])
if provider:
info = await provider.query_quota(plan)
if info:
await db.update_quota_rule(
rule["id"],
quota_used=info.quota_used,
last_refresh_at=now.isoformat(),
)
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
try:
info = await provider.query_quota(plan)
if info:
await db.update_quota_rule(
rule["id"],
quota_used=info.quota_used,
last_refresh_at=now.isoformat(),
)
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
except Exception as e:
logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)
async def _process_task_queue():
"""消费待处理任务"""
tasks = await db.list_tasks(status="pending", limit=5)
from app.config import settings
limit = settings.scheduler.task_processing_limit
tasks = await db.list_tasks(status="pending", limit=limit)
for task in tasks:
plan_id = task.get("plan_id")
if plan_id and not await db.check_plan_available(plan_id):
@@ -204,10 +205,13 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
payload = task.get("request_payload", {})
if tt == "image":
return await provider.generate_image(payload.get("prompt", ""), plan, **payload)
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_image(prompt, plan, **kwargs)
elif tt == "voice":
audio = await provider.generate_voice(payload.get("text", ""), plan, **payload)
# 保存到文件
kwargs = dict(payload)
text = kwargs.pop("text", "")
audio = await provider.generate_voice(text, plan, **kwargs)
from pathlib import Path
from app.config import settings
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
@@ -215,7 +219,9 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
return {"file": str(fpath)}
elif tt == "video":
return await provider.generate_video(payload.get("prompt", ""), plan, **payload)
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_video(prompt, plan, **kwargs)
else:
return {"error": f"Unknown task type: {tt}"}