Files
planManage/app/services/scheduler.py
congsh 37d282c0a2 test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数以返回操作结果状态
2026-03-31 22:36:18 +08:00

263 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""后台调度器 -- 额度刷新 + 任务队列消费"""
from __future__ import annotations
import asyncio
import logging
from calendar import monthrange
from datetime import datetime, timedelta, timezone
from app import database as db
logger = logging.getLogger("scheduler")
_task: asyncio.Task | None = None
_running = False
def _now() -> datetime:
return datetime.now(timezone.utc)
def _parse_dt(s: str | None) -> datetime | None:
if not s:
return None
try:
return datetime.fromisoformat(s)
except (ValueError, TypeError):
return None
def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) -> datetime:
"""
计算下一个自然周期的刷新时间。
anchor 示例:
daily: {"hour": 0}
weekly: {"weekday": 1, "hour": 0} # 周一 0 点
monthly: {"day": 1, "hour": 0} # 每月 1 号
"""
hour = anchor.get("hour", 0)
if calendar_unit == "daily":
candidate = after.replace(hour=hour, minute=0, second=0, microsecond=0)
if candidate <= after:
candidate += timedelta(days=1)
return candidate
if calendar_unit == "weekly":
weekday = anchor.get("weekday", 1) # 1=Monday
days_ahead = (weekday - after.isoweekday()) % 7
candidate = (after + timedelta(days=days_ahead)).replace(
hour=hour, minute=0, second=0, microsecond=0
)
if candidate <= after:
candidate += timedelta(weeks=1)
return candidate
if calendar_unit == "monthly":
day = int(anchor.get("day", 1))
if day < 1:
day = 1
if day > 31:
day = 31
year, month = after.year, after.month
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = after.replace(
day=candidate_day, hour=hour, minute=0, second=0, microsecond=0
)
if candidate <= after:
month += 1
if month > 12:
month, year = 1, year + 1
month_last_day = monthrange(year, month)[1]
candidate_day = min(day, month_last_day)
candidate = candidate.replace(year=year, month=month, day=candidate_day)
return candidate
return after + timedelta(days=1)
async def _refresh_fixed_interval_rule(rule: dict, now: datetime, next_at: datetime | None) -> None:
    """Reset a ``fixed_interval`` rule's usage once ``interval_hours`` has elapsed."""
    interval = rule.get("interval_hours")
    if not interval or interval <= 0:
        # Missing/zero interval: nothing to schedule. Negative values would
        # otherwise reset the quota on every scheduler cycle.
        return
    last_at = _parse_dt(rule.get("last_refresh_at")) or now
    if next_at is None:
        # First time this rule is seen: derive and persist its next reset.
        next_at = last_at + timedelta(hours=interval)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
    if now >= next_at:
        # The new window is anchored at `now`, not the missed `next_at`.
        new_next = now + timedelta(hours=interval)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info("Refreshed rule %s (fixed_interval %sh)", rule["rule_name"], interval)


async def _refresh_calendar_rule(rule: dict, now: datetime, next_at: datetime | None) -> None:
    """Reset a ``calendar_cycle`` rule's usage at its next natural boundary."""
    cal_unit = rule.get("calendar_unit", "daily")
    anchor = rule.get("calendar_anchor") or {}
    if next_at is None:
        last_at = _parse_dt(rule.get("last_refresh_at")) or now
        next_at = _compute_next_calendar(cal_unit, anchor, last_at)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
    if now >= next_at:
        new_next = _compute_next_calendar(cal_unit, anchor, now)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)


async def _sync_api_quota_rule(rule: dict, now: datetime) -> None:
    """Pull current usage for an ``api_sync`` rule from its plan's provider."""
    last_at = _parse_dt(rule.get("last_refresh_at"))
    sync_interval = rule.get("sync_interval_seconds") or 600  # default: 10 minutes
    if last_at and (now - last_at).total_seconds() < sync_interval:
        return
    plan = await db.get_plan(rule["plan_id"])
    if not plan:
        return
    from app.providers import ProviderRegistry
    provider = ProviderRegistry.get(plan["provider_name"])
    if not provider:
        return
    try:
        info = await provider.query_quota(plan)
        if info:
            await db.update_quota_rule(
                rule["id"],
                quota_used=info.quota_used,
                last_refresh_at=now.isoformat(),
            )
            logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
    except Exception as e:
        # last_refresh_at is not advanced, so the sync is retried next cycle.
        logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)


async def _refresh_quota_rules():
    """
    Walk every QuotaRule and apply its refresh strategy.

    Supported ``refresh_type`` values are "manual" (skipped),
    "fixed_interval", "calendar_cycle" and "api_sync". Each rule is processed
    under its own guard. A malformed rule (bad anchor, unparsable timestamp,
    DB error) is logged and skipped. Without the guard it would abort the
    remaining rules and the rest of the scheduler cycle.
    """
    now = _now()
    rules = await db.get_all_quota_rules()
    for rule in rules:
        rt = rule.get("refresh_type", "manual")
        if rt == "manual":
            continue
        try:
            next_at = _parse_dt(rule.get("next_refresh_at"))
            if rt == "fixed_interval":
                await _refresh_fixed_interval_rule(rule, now, next_at)
            elif rt == "calendar_cycle":
                await _refresh_calendar_rule(rule, now, next_at)
            elif rt == "api_sync":
                await _sync_api_quota_rule(rule, now)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception(
                "Failed to refresh rule %s", rule.get("rule_name", rule.get("id"))
            )
async def _resolve_task_provider(plan_id):
    """Return ``(plan, provider)`` for a task's plan; either may be None."""
    if not plan_id:
        return None, None
    plan = await db.get_plan(plan_id)
    if not plan:
        return plan, None
    from app.providers import ProviderRegistry
    return plan, ProviderRegistry.get(plan["provider_name"])


async def _record_task_error(task: dict, exc: Exception) -> None:
    """Re-queue a task whose execution raised, or fail it once retries are spent."""
    # NULL columns fall back to defaults so this handler itself cannot raise
    # TypeError and leave the task stuck in "running".
    retry = (task.get("retry_count") or 0) + 1
    max_r = task.get("max_retries")
    if max_r is None:
        max_r = 3
    new_status = "pending" if retry < max_r else "failed"
    logger.error("Task %s failed: %s", task["id"], exc, exc_info=exc)
    await db.update_task(
        task["id"],
        status=new_status,
        retry_count=retry,
        response_payload={"error": str(exc)},
        completed_at=_now().isoformat() if new_status == "failed" else None,
    )


async def _process_task_queue():
    """
    Consume pending tasks, at most ``settings.scheduler.task_processing_limit``
    per call.

    Tasks whose plan has no quota left stay pending. Execution errors go
    through retry handling. After the provider call succeeds, bookkeeping
    failures are only logged, so a finished task is never re-queued and run
    a second time.
    """
    from app.config import settings

    limit = settings.scheduler.task_processing_limit
    tasks = await db.list_tasks(status="pending", limit=limit)
    for task in tasks:
        plan_id = task.get("plan_id")
        if plan_id and not await db.check_plan_available(plan_id):
            continue  # quota exhausted: leave pending for a later cycle
        await db.update_task(task["id"], status="running", started_at=_now().isoformat())

        try:
            plan, provider = await _resolve_task_provider(plan_id)
            if not (plan and provider):
                await db.update_task(
                    task["id"],
                    status="failed",
                    response_payload={"error": "No provider available"},
                    completed_at=_now().isoformat(),
                )
                continue
            result = await _execute_task(provider, plan, task)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            await _record_task_error(task, e)
            continue

        try:
            await db.update_task(
                task["id"],
                status="completed",
                response_payload=result,
                completed_at=_now().isoformat(),
            )
            await db.increment_quota_used(plan_id, token_count=0)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Task %s completed but bookkeeping failed", task["id"])
async def _execute_task(provider, plan: dict, task: dict) -> dict:
"""根据 task_type 调用对应的 Provider 方法"""
tt = task["task_type"]
payload = task.get("request_payload", {})
if tt == "image":
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_image(prompt, plan, **kwargs)
elif tt == "voice":
kwargs = dict(payload)
text = kwargs.pop("text", "")
audio = await provider.generate_voice(text, plan, **kwargs)
from pathlib import Path
from app.config import settings
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
fpath.write_bytes(audio)
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
return {"file": str(fpath)}
elif tt == "video":
kwargs = dict(payload)
prompt = kwargs.pop("prompt", "")
return await provider.generate_video(prompt, plan, **kwargs)
else:
return {"error": f"Unknown task type: {tt}"}
async def _scheduler_loop():
    """
    Main scheduler loop. Every 30 seconds it refreshes quota rules and then
    drains the pending-task queue, until ``_running`` is cleared.

    Each phase is guarded separately, so an error while refreshing quotas does
    not also skip task processing for that cycle. Failures are logged with a
    traceback. Cancellation is always propagated so stop_scheduler() can end
    the loop. On Python 3.7 ``asyncio.CancelledError`` still subclasses
    ``Exception`` and would otherwise be swallowed.
    """
    while _running:
        for phase in (_refresh_quota_rules, _process_task_queue):
            try:
                await phase()
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Scheduler error in %s", phase.__name__)
        await asyncio.sleep(30)
async def start_scheduler():
    """
    Discover providers and launch the background scheduler loop.

    Idempotent: if the loop task is still running, this returns without
    starting a second loop. Two concurrent loops would process the same
    pending tasks twice and orphan the first task handle.
    """
    global _task, _running
    if _task is not None and not _task.done():
        logger.info("Scheduler already running")
        return
    # Register providers.
    from app.providers import ProviderRegistry
    ProviderRegistry.auto_discover()
    logger.info("Providers: %s", list(ProviderRegistry.all().keys()))
    _running = True
    _task = asyncio.create_task(_scheduler_loop())
    logger.info("Scheduler started")
async def stop_scheduler():
    """
    Stop the background loop and wait for its task to finish.

    The module state is reset (``_task = None``) so start_scheduler() can
    start again later. If the loop task had already died with an error, that
    error is logged instead of being raised from shutdown.
    """
    global _task, _running
    _running = False
    task, _task = _task, None
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("Scheduler task ended with an error")
    logger.info("Scheduler stopped")