- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
263 lines
9.2 KiB
Python
263 lines
9.2 KiB
Python
"""Background scheduler -- quota refresh + task queue consumption."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
from calendar import monthrange
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from app import database as db
|
||
|
||
# Module-level logger; note the logger name is the literal "scheduler".
logger = logging.getLogger("scheduler")

# Handle of the asyncio task running the scheduler loop; assigned by
# start_scheduler and cancelled by stop_scheduler.
_task: asyncio.Task | None = None
# Flag polled by the scheduler loop on every iteration; the loop exits
# once this becomes False.
_running = False
|
||
|
||
|
||
def _now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    current = datetime.now(tz=timezone.utc)
    return current
|
||
|
||
|
||
def _parse_dt(s: str | None) -> datetime | None:
    """Parse an ISO-8601 timestamp into a timezone-aware datetime.

    Args:
        s: ISO-8601 text, or None / empty string.

    Returns:
        The parsed datetime, or None when ``s`` is empty or cannot be
        parsed. A value without an offset is tagged as UTC so the result
        can always be compared against the aware value from ``_now()``;
        comparing a naive and an aware datetime raises TypeError.
    """
    if not s:
        return None
    try:
        # datetime.fromisoformat only accepts a trailing "Z" from Python
        # 3.11 onwards; rewrite it as an explicit UTC offset.
        text = s[:-1] + "+00:00" if s.endswith("Z") else s
        parsed = datetime.fromisoformat(text)
    except (ValueError, TypeError, AttributeError):
        # AttributeError covers truthy non-string input (no .endswith).
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
|
||
|
||
|
||
def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) -> datetime:
    """Compute the next calendar-cycle refresh time strictly after ``after``.

    Anchor examples:
        daily:   {"hour": 0}
        weekly:  {"weekday": 1, "hour": 0}   # Monday at 00:00
        monthly: {"day": 1, "hour": 0}       # 1st of each month

    Args:
        calendar_unit: "daily", "weekly" or "monthly".
        anchor: Mapping with optional "hour", "weekday" (ISO, 1 = Monday)
            and "day" entries. Values are coerced to int; missing or
            non-numeric values fall back to the defaults (hour 0,
            weekday 1, day 1). Hour is clamped to 0..23 and day to 1..31.
        after: Reference moment; the result is always later than this and
            keeps its tzinfo.

    Returns:
        The next refresh moment. For an unrecognised ``calendar_unit``
        the result is ``after`` plus one day.
    """
    # Clamp like the monthly day: an out-of-range or string hour would
    # otherwise make datetime.replace raise.
    hour = min(max(_anchor_int(anchor, "hour", 0), 0), 23)

    if calendar_unit == "daily":
        candidate = after.replace(hour=hour, minute=0, second=0, microsecond=0)
        if candidate <= after:
            candidate += timedelta(days=1)
        return candidate

    if calendar_unit == "weekly":
        weekday = _anchor_int(anchor, "weekday", 1)  # ISO weekday, 1 = Monday
        # Modulo keeps the offset in 0..6 whatever the weekday value is.
        days_ahead = (weekday - after.isoweekday()) % 7
        candidate = (after + timedelta(days=days_ahead)).replace(
            hour=hour, minute=0, second=0, microsecond=0
        )
        if candidate <= after:
            candidate += timedelta(weeks=1)
        return candidate

    if calendar_unit == "monthly":
        day = min(max(_anchor_int(anchor, "day", 1), 1), 31)
        year, month = after.year, after.month
        # Months shorter than the anchor day fire on their last day.
        candidate = after.replace(
            day=min(day, monthrange(year, month)[1]),
            hour=hour, minute=0, second=0, microsecond=0,
        )
        if candidate <= after:
            # Roll to the following month, wrapping December into January.
            year, month = (year + 1, 1) if month == 12 else (year, month + 1)
            candidate = candidate.replace(
                year=year, month=month, day=min(day, monthrange(year, month)[1])
            )
        return candidate

    return after + timedelta(days=1)


def _anchor_int(anchor: dict, key: str, default: int) -> int:
    """Read ``anchor[key]`` as an int, using ``default`` when missing or invalid."""
    try:
        return int(anchor.get(key, default))
    except (TypeError, ValueError):
        return default
|
||
|
||
|
||
async def _refresh_quota_rules():
    """Walk every quota rule and apply its refresh strategy.

    Rules whose ``refresh_type`` is "manual" (the default when the key is
    absent) or is not one of "fixed_interval", "calendar_cycle" or
    "api_sync" are skipped.

    Each rule is handled in isolation: an exception raised while handling
    one rule is logged and the remaining rules are still processed, so a
    single malformed rule cannot block every other refresh. A failure to
    load the rule list itself still propagates to the caller.
    """
    now = _now()
    rules = await db.get_all_quota_rules()

    handlers = {
        "fixed_interval": _refresh_fixed_interval_rule,
        "calendar_cycle": _refresh_calendar_rule,
        "api_sync": _sync_rule_from_api,
    }

    for rule in rules:
        handler = handlers.get(rule.get("refresh_type", "manual"))
        if handler is None:
            continue
        try:
            await handler(rule, now)
        except Exception as e:
            logger.error("Refresh failed for rule %s: %s", rule.get("rule_name"), e)


async def _refresh_fixed_interval_rule(rule: dict, now: datetime) -> None:
    """Reset usage of a fixed-interval rule once ``interval_hours`` has elapsed."""
    interval = rule.get("interval_hours")
    if not interval:
        return

    next_at = _parse_dt(rule.get("next_refresh_at"))
    if next_at is None:
        # No schedule stored yet: derive it from the last refresh (or now)
        # and persist it.
        last_at = _parse_dt(rule.get("last_refresh_at")) or now
        next_at = last_at + timedelta(hours=interval)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())

    if now >= next_at:
        new_next = now + timedelta(hours=interval)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info("Refreshed rule %s (fixed_interval %sh)", rule["rule_name"], interval)


async def _refresh_calendar_rule(rule: dict, now: datetime) -> None:
    """Reset usage of a calendar-cycle rule once its next boundary has passed."""
    cal_unit = rule.get("calendar_unit", "daily")
    anchor = rule.get("calendar_anchor") or {}

    next_at = _parse_dt(rule.get("next_refresh_at"))
    if next_at is None:
        last_at = _parse_dt(rule.get("last_refresh_at")) or now
        next_at = _compute_next_calendar(cal_unit, anchor, last_at)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())

    if now >= next_at:
        new_next = _compute_next_calendar(cal_unit, anchor, now)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)


async def _sync_rule_from_api(rule: dict, now: datetime) -> None:
    """Copy the provider-reported usage into an api_sync rule."""
    last_at = _parse_dt(rule.get("last_refresh_at"))
    sync_interval = rule.get("sync_interval_seconds") or 600  # default: 10 minutes
    if last_at and (now - last_at).total_seconds() < sync_interval:
        return

    plan = await db.get_plan(rule["plan_id"])
    if not plan:
        return

    from app.providers import ProviderRegistry

    provider = ProviderRegistry.get(plan["provider_name"])
    if not provider:
        return

    try:
        info = await provider.query_quota(plan)
        if info:
            await db.update_quota_rule(
                rule["id"],
                quota_used=info.quota_used,
                last_refresh_at=now.isoformat(),
            )
            # %s rather than %d: the provider value is not validated here
            # and a non-integer would break the log formatting.
            logger.info("API synced rule %s: used=%s", rule["rule_name"], info.quota_used)
    except Exception as e:
        logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)
|
||
|
||
|
||
async def _process_task_queue():
    """Consume pending tasks, up to the configured per-tick limit.

    For each pending task:
      * if it has a plan whose quota check fails, it is left pending;
      * otherwise it is marked "running" and dispatched to the provider
        registered for its plan;
      * with no plan, no plan record or no provider it is marked "failed"
        with a "No provider available" payload;
      * if execution raises, ``retry_count`` is incremented and the task
        returns to "pending" until ``max_retries`` (default 3) is reached,
        after which it is marked "failed".

    Quota accounting runs only after the task has been stored as
    completed, and a failure there is logged rather than routed through
    the retry path. Otherwise an already-completed task would be reset to
    "pending" and executed a second time.
    """
    from app.config import settings
    from app.providers import ProviderRegistry

    limit = settings.scheduler.task_processing_limit
    tasks = await db.list_tasks(status="pending", limit=limit)

    for task in tasks:
        plan_id = task.get("plan_id")
        if plan_id and not await db.check_plan_available(plan_id):
            continue  # quota exhausted: leave the task pending

        await db.update_task(task["id"], status="running", started_at=_now().isoformat())

        try:
            plan = await db.get_plan(plan_id) if plan_id else None
            provider = ProviderRegistry.get(plan["provider_name"]) if plan else None

            if provider:
                result = await _execute_task(provider, plan, task)
                await db.update_task(
                    task["id"],
                    status="completed",
                    response_payload=result,
                    completed_at=_now().isoformat(),
                )
            else:
                await db.update_task(
                    task["id"],
                    status="failed",
                    response_payload={"error": "No provider available"},
                    completed_at=_now().isoformat(),
                )
        except Exception as e:
            # Stored counters may be None; treat that as "unset" instead of
            # raising TypeError inside the error handler.
            retry = (task.get("retry_count") or 0) + 1
            max_r = task.get("max_retries")
            if max_r is None:
                max_r = 3
            new_status = "pending" if retry < max_r else "failed"
            await db.update_task(
                task["id"],
                status=new_status,
                retry_count=retry,
                response_payload={"error": str(e)},
                completed_at=_now().isoformat() if new_status == "failed" else None,
            )
            logger.error("Task %s failed: %s", task["id"], e)
        else:
            if provider:
                try:
                    await db.increment_quota_used(plan_id, token_count=0)
                except Exception as e:
                    logger.error(
                        "Quota accounting failed for task %s: %s", task["id"], e
                    )
|
||
|
||
|
||
async def _execute_task(provider, plan: dict, task: dict) -> dict:
    """Dispatch a task to the provider method matching its ``task_type``.

    Args:
        provider: Object exposing the awaitable ``generate_image``,
            ``generate_voice`` and ``generate_video`` methods.
        plan: Plan record passed through to the provider unchanged.
        task: Task record. Reads ``task_type``, ``request_payload`` and,
            for voice tasks, ``id``.

    Returns:
        The provider's result for "image" and "video" tasks;
        ``{"file": <path>}`` for "voice" tasks, after the audio has been
        written under the configured storage path and the task row
        updated; ``{"error": ...}`` for an unrecognised task type.

    Raises:
        KeyError: If ``task`` has no ``task_type``.
        Exception: Anything raised by the provider or the storage write.
    """
    tt = task["task_type"]
    if tt not in ("image", "voice", "video"):
        return {"error": f"Unknown task type: {tt}"}

    # The key may be present with a None value; `or {}` covers that as
    # well as a missing key. Copy so the task's own payload is not mutated.
    kwargs = dict(task.get("request_payload") or {})

    if tt == "image":
        prompt = kwargs.pop("prompt", "")
        return await provider.generate_image(prompt, plan, **kwargs)

    if tt == "video":
        prompt = kwargs.pop("prompt", "")
        return await provider.generate_video(prompt, plan, **kwargs)

    # tt == "voice"
    text = kwargs.pop("text", "")
    audio = await provider.generate_voice(text, plan, **kwargs)

    from pathlib import Path

    from app.config import settings

    fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
    # Create the storage directory on first use rather than failing the task.
    fpath.parent.mkdir(parents=True, exist_ok=True)
    fpath.write_bytes(audio)
    await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
    return {"file": str(fpath)}
|
||
|
||
|
||
async def _scheduler_loop():
    """Main scheduler loop.

    While the module-level ``_running`` flag is set, refresh the quota
    rules and then drain the task queue, logging (not propagating) any
    exception, and sleep 30 seconds between iterations.
    """
    while True:
        if not _running:
            break
        try:
            await _refresh_quota_rules()
            await _process_task_queue()
        except Exception as exc:
            logger.error("Scheduler error: %s", exc)
        await asyncio.sleep(30)
|
||
|
||
|
||
async def start_scheduler():
    """Discover providers and launch the background scheduler loop.

    Calling this while a scheduler task is still active is a no-op (a
    warning is logged); otherwise a second loop task would be created and
    the first one orphaned with no handle left to cancel it.
    """
    global _task, _running

    if _task is not None and not _task.done():
        logger.warning("Scheduler already running; ignoring duplicate start")
        return

    # Register providers
    from app.providers import ProviderRegistry

    ProviderRegistry.auto_discover()
    logger.info("Providers: %s", list(ProviderRegistry.all().keys()))

    _running = True
    _task = asyncio.create_task(_scheduler_loop())
    logger.info("Scheduler started")
|
||
|
||
|
||
async def stop_scheduler():
    """Stop the background scheduler loop.

    Clears the ``_running`` flag, cancels the loop task when one exists
    and waits for the cancellation to take effect, then logs the stop.
    """
    global _task, _running
    _running = False

    loop_task = _task
    if loop_task:
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            # Expected outcome of cancelling the loop task.
            pass

    logger.info("Scheduler stopped")
|