"""后台调度器 -- 额度刷新 + 任务队列消费""" from __future__ import annotations import asyncio import json import logging from datetime import datetime, timedelta, timezone from app import database as db logger = logging.getLogger("scheduler") _task: asyncio.Task | None = None _running = False def _now() -> datetime: return datetime.now(timezone.utc) def _parse_dt(s: str | None) -> datetime | None: if not s: return None try: return datetime.fromisoformat(s) except (ValueError, TypeError): return None def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) -> datetime: """ 计算下一个自然周期的刷新时间。 anchor 示例: daily: {"hour": 0} weekly: {"weekday": 1, "hour": 0} # 周一 0 点 monthly: {"day": 1, "hour": 0} # 每月 1 号 """ hour = anchor.get("hour", 0) if calendar_unit == "daily": candidate = after.replace(hour=hour, minute=0, second=0, microsecond=0) if candidate <= after: candidate += timedelta(days=1) return candidate if calendar_unit == "weekly": weekday = anchor.get("weekday", 1) # 1=Monday days_ahead = (weekday - after.isoweekday()) % 7 candidate = (after + timedelta(days=days_ahead)).replace( hour=hour, minute=0, second=0, microsecond=0 ) if candidate <= after: candidate += timedelta(weeks=1) return candidate if calendar_unit == "monthly": day = anchor.get("day", 1) year, month = after.year, after.month try: candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0) except ValueError: # 日期不存在时(如 2 月 30 号),跳到下月 month += 1 if month > 12: month, year = 1, year + 1 candidate = after.replace(year=year, month=month, day=day, hour=hour, minute=0, second=0, microsecond=0) if candidate <= after: month += 1 if month > 12: month, year = 1, year + 1 try: candidate = candidate.replace(year=year, month=month) except ValueError: month += 1 if month > 12: month, year = 1, year + 1 candidate = candidate.replace(year=year, month=month, day=1) return candidate return after + timedelta(days=1) async def _refresh_quota_rules(): """遍历所有 QuotaRule,按刷新策略处理""" now = _now() rules = await db.get_all_quota_rules() for rule in rules: rt = rule.get("refresh_type", "manual") if rt == "manual": continue next_at = _parse_dt(rule.get("next_refresh_at")) if rt == "fixed_interval": interval = rule.get("interval_hours") if not interval: continue last_at = _parse_dt(rule.get("last_refresh_at")) or now if next_at is None: next_at = last_at + timedelta(hours=interval) await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat()) if now >= next_at: new_next = now + timedelta(hours=interval) await db.update_quota_rule( rule["id"], quota_used=0, last_refresh_at=now.isoformat(), next_refresh_at=new_next.isoformat(), ) logger.info("Refreshed rule %s (fixed_interval %sh)", rule["rule_name"], interval) elif rt == "calendar_cycle": cal_unit = rule.get("calendar_unit", "daily") anchor = rule.get("calendar_anchor") or {} if next_at is None: last_at = _parse_dt(rule.get("last_refresh_at")) or now next_at = _compute_next_calendar(cal_unit, anchor, last_at) await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat()) if now >= next_at: new_next = _compute_next_calendar(cal_unit, anchor, now) await db.update_quota_rule( rule["id"], quota_used=0, last_refresh_at=now.isoformat(), next_refresh_at=new_next.isoformat(), ) logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit) elif rt == "api_sync": # 每 10 分钟同步一次 last_at = _parse_dt(rule.get("last_refresh_at")) if last_at and (now - last_at).total_seconds() < 600: continue plan = await db.get_plan(rule["plan_id"]) if not plan: continue from app.providers import ProviderRegistry provider = ProviderRegistry.get(plan["provider_name"]) if provider: info = await provider.query_quota(plan) if info: await db.update_quota_rule( rule["id"], quota_used=info.quota_used, last_refresh_at=now.isoformat(), ) logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used) async def _process_task_queue(): """消费待处理任务""" tasks = await db.list_tasks(status="pending", limit=5) for task in tasks: plan_id = task.get("plan_id") if plan_id and not await db.check_plan_available(plan_id): continue # 额度不足,跳过 await db.update_task(task["id"], status="running", started_at=_now().isoformat()) try: if plan_id: plan = await db.get_plan(plan_id) if plan: from app.providers import ProviderRegistry provider = ProviderRegistry.get(plan["provider_name"]) if provider: result = await _execute_task(provider, plan, task) await db.update_task( task["id"], status="completed", response_payload=result, completed_at=_now().isoformat(), ) await db.increment_quota_used(plan_id, token_count=0) continue await db.update_task( task["id"], status="failed", response_payload={"error": "No provider available"}, completed_at=_now().isoformat(), ) except Exception as e: retry = task.get("retry_count", 0) + 1 max_r = task.get("max_retries", 3) new_status = "pending" if retry < max_r else "failed" await db.update_task( task["id"], status=new_status, retry_count=retry, response_payload={"error": str(e)}, completed_at=_now().isoformat() if new_status == "failed" else None, ) logger.error("Task %s failed: %s", task["id"], e) async def _execute_task(provider, plan: dict, task: dict) -> dict: """根据 task_type 调用对应的 Provider 方法""" tt = task["task_type"] payload = task.get("request_payload", {}) if tt == "image": return await provider.generate_image(payload.get("prompt", ""), plan, **payload) elif tt == "voice": audio = await provider.generate_voice(payload.get("text", ""), plan, **payload) # 保存到文件 from pathlib import Path from app.config import settings fpath = Path(settings.storage.path) / f"{task['id']}.mp3" fpath.write_bytes(audio) await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3") return {"file": str(fpath)} elif tt == "video": return await provider.generate_video(payload.get("prompt", ""), plan, **payload) else: return {"error": f"Unknown task type: {tt}"} async def _scheduler_loop(): """主调度循环""" global _running while _running: try: await _refresh_quota_rules() await _process_task_queue() except Exception as e: logger.error("Scheduler error: %s", e) await asyncio.sleep(30) async def start_scheduler(): global _task, _running # 注册 Provider from app.providers import ProviderRegistry ProviderRegistry.auto_discover() logger.info("Providers: %s", list(ProviderRegistry.all().keys())) _running = True _task = asyncio.create_task(_scheduler_loop()) logger.info("Scheduler started") async def stop_scheduler(): global _task, _running _running = False if _task: _task.cancel() try: await _task except asyncio.CancelledError: pass logger.info("Scheduler stopped")