test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
This commit is contained in:
+39
-33
@@ -3,8 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from calendar import monthrange
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app import database as db
|
||||
@@ -55,28 +55,24 @@ def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) ->
|
||||
return candidate
|
||||
|
||||
if calendar_unit == "monthly":
|
||||
day = anchor.get("day", 1)
|
||||
day = int(anchor.get("day", 1))
|
||||
if day < 1:
|
||||
day = 1
|
||||
if day > 31:
|
||||
day = 31
|
||||
year, month = after.year, after.month
|
||||
try:
|
||||
candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
# 日期不存在时(如 2 月 30 号),跳到下月
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = after.replace(year=year, month=month, day=day,
|
||||
hour=hour, minute=0, second=0, microsecond=0)
|
||||
month_last_day = monthrange(year, month)[1]
|
||||
candidate_day = min(day, month_last_day)
|
||||
candidate = after.replace(
|
||||
day=candidate_day, hour=hour, minute=0, second=0, microsecond=0
|
||||
)
|
||||
if candidate <= after:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
try:
|
||||
candidate = candidate.replace(year=year, month=month)
|
||||
except ValueError:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = candidate.replace(year=year, month=month, day=1)
|
||||
month_last_day = monthrange(year, month)[1]
|
||||
candidate_day = min(day, month_last_day)
|
||||
candidate = candidate.replace(year=year, month=month, day=candidate_day)
|
||||
return candidate
|
||||
|
||||
return after + timedelta(days=1)
|
||||
@@ -131,9 +127,9 @@ async def _refresh_quota_rules():
|
||||
logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)
|
||||
|
||||
elif rt == "api_sync":
|
||||
# 每 10 分钟同步一次
|
||||
last_at = _parse_dt(rule.get("last_refresh_at"))
|
||||
if last_at and (now - last_at).total_seconds() < 600:
|
||||
sync_interval = rule.get("sync_interval_seconds") or 600 # 默认 10 分钟
|
||||
if last_at and (now - last_at).total_seconds() < sync_interval:
|
||||
continue
|
||||
plan = await db.get_plan(rule["plan_id"])
|
||||
if not plan:
|
||||
@@ -141,19 +137,24 @@ async def _refresh_quota_rules():
|
||||
from app.providers import ProviderRegistry
|
||||
provider = ProviderRegistry.get(plan["provider_name"])
|
||||
if provider:
|
||||
info = await provider.query_quota(plan)
|
||||
if info:
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=info.quota_used,
|
||||
last_refresh_at=now.isoformat(),
|
||||
)
|
||||
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
|
||||
try:
|
||||
info = await provider.query_quota(plan)
|
||||
if info:
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=info.quota_used,
|
||||
last_refresh_at=now.isoformat(),
|
||||
)
|
||||
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
|
||||
except Exception as e:
|
||||
logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)
|
||||
|
||||
|
||||
async def _process_task_queue():
|
||||
"""消费待处理任务"""
|
||||
tasks = await db.list_tasks(status="pending", limit=5)
|
||||
from app.config import settings
|
||||
limit = settings.scheduler.task_processing_limit
|
||||
tasks = await db.list_tasks(status="pending", limit=limit)
|
||||
for task in tasks:
|
||||
plan_id = task.get("plan_id")
|
||||
if plan_id and not await db.check_plan_available(plan_id):
|
||||
@@ -204,10 +205,13 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
|
||||
payload = task.get("request_payload", {})
|
||||
|
||||
if tt == "image":
|
||||
return await provider.generate_image(payload.get("prompt", ""), plan, **payload)
|
||||
kwargs = dict(payload)
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
return await provider.generate_image(prompt, plan, **kwargs)
|
||||
elif tt == "voice":
|
||||
audio = await provider.generate_voice(payload.get("text", ""), plan, **payload)
|
||||
# 保存到文件
|
||||
kwargs = dict(payload)
|
||||
text = kwargs.pop("text", "")
|
||||
audio = await provider.generate_voice(text, plan, **kwargs)
|
||||
from pathlib import Path
|
||||
from app.config import settings
|
||||
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
|
||||
@@ -215,7 +219,9 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
|
||||
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
|
||||
return {"file": str(fpath)}
|
||||
elif tt == "video":
|
||||
return await provider.generate_video(payload.get("prompt", ""), plan, **payload)
|
||||
kwargs = dict(payload)
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
return await provider.generate_video(prompt, plan, **kwargs)
|
||||
else:
|
||||
return {"error": f"Unknown task type: {tt}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user