Files
planManage/app/services/scheduler.py
锦麟 王 61ce809634 feat: 多平台 Coding Plan 统一管理系统初始实现
- 支持 MiniMax/OpenAI/Google Gemini/智谱/Kimi 五个平台
- 插件化 Provider 架构,自动发现注册
- 多维度 QuotaRule 额度追踪(固定间隔/自然周期/API同步/手动)
- OpenAI + Anthropic 兼容 API 代理,SSE 流式转发
- Model 路由表 + 额度耗尽自动 fallback
- 多媒体任务队列(图片/语音/视频)
- Vue3 + Tailwind 单文件 Web 仪表盘
- Docker 一键部署

Made-with: Cursor
2026-03-31 15:50:42 +08:00

257 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""后台调度器 -- 额度刷新 + 任务队列消费"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
from app import database as db
# Module logger, registered under the name "scheduler".
logger = logging.getLogger("scheduler")
# Handle of the background loop task; None until start_scheduler() creates it.
_task: asyncio.Task | None = None
# Loop flag: _scheduler_loop keeps iterating while this is True;
# start_scheduler() sets it and stop_scheduler() clears it.
_running: bool = False
def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    This is the clock used for every timestamp this module writes.
    """
    utc_zone = timezone.utc
    current = datetime.now(tz=utc_zone)
    return current
def _parse_dt(s: str | None) -> datetime | None:
if not s:
return None
try:
return datetime.fromisoformat(s)
except (ValueError, TypeError):
return None
def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) -> datetime:
    """
    Compute the next natural-cycle refresh time strictly after ``after``.

    anchor examples:
        daily:   {"hour": 0}
        weekly:  {"weekday": 1, "hour": 0}   # Monday 00:00 (ISO weekday, 1=Mon .. 7=Sun)
        monthly: {"day": 1, "hour": 0}       # the 1st of every month

    Args:
        calendar_unit: "daily", "weekly" or "monthly". Any other value falls
            back to ``after + 1 day``.
        anchor: Anchor fields for the unit. Missing fields default to
            hour=0, weekday=1, day=1.
        after: Reference time. The result keeps its tzinfo.

    Returns:
        The first anchor occurrence strictly later than ``after``.

    For monthly cycles, an anchor day that a month does not have (e.g. day 31
    in April, day 30 in February) is clamped to that month's last day. The
    rule therefore fires exactly once per calendar month, instead of skipping
    a month or landing on an unrelated date such as the 1st.
    """
    import calendar

    hour = anchor.get("hour", 0)
    # Fields shared by every candidate: the anchor hour, truncated to the hour.
    reset = {"hour": hour, "minute": 0, "second": 0, "microsecond": 0}

    if calendar_unit == "daily":
        candidate = after.replace(**reset)
        if candidate <= after:
            candidate += timedelta(days=1)
        return candidate

    if calendar_unit == "weekly":
        weekday = anchor.get("weekday", 1)  # 1=Monday
        days_ahead = (weekday - after.isoweekday()) % 7
        candidate = (after + timedelta(days=days_ahead)).replace(**reset)
        if candidate <= after:
            candidate += timedelta(weeks=1)
        return candidate

    if calendar_unit == "monthly":
        day = anchor.get("day", 1)

        def in_month(year: int, month: int) -> datetime:
            # Anchor time inside the given month, with the day clamped to
            # [1, number of days in that month].
            last_day = calendar.monthrange(year, month)[1]
            return after.replace(
                year=year, month=month, day=max(1, min(day, last_day)), **reset
            )

        candidate = in_month(after.year, after.month)
        if candidate <= after:
            if after.month == 12:
                candidate = in_month(after.year + 1, 1)
            else:
                candidate = in_month(after.year, after.month + 1)
        return candidate

    return after + timedelta(days=1)
async def _refresh_quota_rules():
    """Walk every QuotaRule and apply its refresh strategy.

    Strategies, selected by ``refresh_type``:
        manual          -- never touched here (also the default when unset).
        fixed_interval  -- reset ``quota_used`` every ``interval_hours``.
        calendar_cycle  -- reset on natural day/week/month boundaries.
        api_sync        -- copy ``quota_used`` from the provider, at most
                           once every 10 minutes.

    Each rule is processed in isolation. An exception from one rule (bad
    data, a provider/network error) is logged with its traceback and the
    remaining rules are still processed. An ``api_sync`` rule only advances
    ``last_refresh_at`` on success, so without this isolation a persistently
    failing provider would abort the whole pass on every tick.
    """
    now = _now()
    rules = await db.get_all_quota_rules()
    for rule in rules:
        rt = rule.get("refresh_type", "manual")
        try:
            if rt == "fixed_interval":
                await _refresh_fixed_interval(rule, now)
            elif rt == "calendar_cycle":
                await _refresh_calendar_cycle(rule, now)
            elif rt == "api_sync":
                await _sync_rule_from_api(rule, now)
            # "manual" and unrecognised types are left untouched.
        except Exception:
            logger.exception(
                "Failed to refresh quota rule %s", rule.get("rule_name", rule.get("id"))
            )


async def _refresh_fixed_interval(rule: dict, now: datetime) -> None:
    """Reset a fixed_interval rule once its next_refresh_at has passed."""
    interval = rule.get("interval_hours")
    if not interval:
        return
    next_at = _parse_dt(rule.get("next_refresh_at"))
    if next_at is None:
        # No schedule stored yet: derive it from the last refresh (or now) and persist it.
        last_at = _parse_dt(rule.get("last_refresh_at")) or now
        next_at = last_at + timedelta(hours=interval)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
    if now >= next_at:
        new_next = now + timedelta(hours=interval)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info(
            "Refreshed rule %s (fixed_interval %sh)", rule.get("rule_name", rule["id"]), interval
        )


async def _refresh_calendar_cycle(rule: dict, now: datetime) -> None:
    """Reset a calendar_cycle rule when the next day/week/month boundary has passed."""
    cal_unit = rule.get("calendar_unit", "daily")
    anchor = rule.get("calendar_anchor") or {}
    if isinstance(anchor, str):
        # NOTE(review): tolerate an anchor stored as raw JSON text; confirm
        # whether the db layer already decodes this column.
        anchor = json.loads(anchor)
    next_at = _parse_dt(rule.get("next_refresh_at"))
    if next_at is None:
        # No schedule stored yet: derive it from the last refresh (or now) and persist it.
        last_at = _parse_dt(rule.get("last_refresh_at")) or now
        next_at = _compute_next_calendar(cal_unit, anchor, last_at)
        await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
    if now >= next_at:
        new_next = _compute_next_calendar(cal_unit, anchor, now)
        await db.update_quota_rule(
            rule["id"],
            quota_used=0,
            last_refresh_at=now.isoformat(),
            next_refresh_at=new_next.isoformat(),
        )
        logger.info("Refreshed rule %s (calendar %s)", rule.get("rule_name", rule["id"]), cal_unit)


async def _sync_rule_from_api(rule: dict, now: datetime) -> None:
    """Copy quota_used from the plan's provider, at most once every 10 minutes."""
    last_at = _parse_dt(rule.get("last_refresh_at"))
    if last_at and (now - last_at).total_seconds() < 600:
        return
    plan = await db.get_plan(rule["plan_id"])
    if not plan:
        return
    from app.providers import ProviderRegistry

    provider = ProviderRegistry.get(plan["provider_name"])
    if not provider:
        return
    info = await provider.query_quota(plan)
    if not info:
        return
    await db.update_quota_rule(
        rule["id"],
        quota_used=info.quota_used,
        last_refresh_at=now.isoformat(),
    )
    logger.info("API synced rule %s: used=%d", rule.get("rule_name", rule["id"]), info.quota_used)
async def _process_task_queue():
    """Consume up to 5 pending tasks per scheduler tick.

    For each task:
      * leave it pending when its plan reports no available quota;
      * mark it running, resolve the plan's provider and execute the task;
      * with no plan/provider, mark it failed ("No provider available");
      * on success, store the result, mark it completed and call
        ``db.increment_quota_used(plan_id, token_count=0)``;
      * on an execution error, put it back to pending until ``max_retries``
        (default 3) attempts have failed, then mark it failed.

    Recording a successful result happens outside the execution ``try``. A
    database error at that point therefore cannot send an already executed
    task back to the queue, which would run (and bill) it a second time.
    """
    tasks = await db.list_tasks(status="pending", limit=5)
    for task in tasks:
        task_id = task["id"]
        plan_id = task.get("plan_id")
        if plan_id and not await db.check_plan_available(plan_id):
            continue  # Quota exhausted: leave the task pending for a later tick.

        await db.update_task(task_id, status="running", started_at=_now().isoformat())
        try:
            plan = await db.get_plan(plan_id) if plan_id else None
            provider = None
            if plan:
                from app.providers import ProviderRegistry

                provider = ProviderRegistry.get(plan["provider_name"])
            if not provider:
                await db.update_task(
                    task_id,
                    status="failed",
                    response_payload={"error": "No provider available"},
                    completed_at=_now().isoformat(),
                )
                continue
            result = await _execute_task(provider, plan, task)
        except Exception as e:
            await _record_task_failure(task, e)
            continue

        try:
            await db.update_task(
                task_id,
                status="completed",
                response_payload=result,
                completed_at=_now().isoformat(),
            )
            await db.increment_quota_used(plan_id, token_count=0)
        except Exception:
            # The work already ran; do not re-queue it. Surface the problem instead.
            logger.exception("Task %s ran but its result could not be recorded", task_id)


async def _record_task_failure(task: dict, error: Exception) -> None:
    """Re-queue a failed task, or mark it failed once its retries are used up."""
    retry = (task.get("retry_count") or 0) + 1
    max_r = task.get("max_retries")
    if max_r is None:
        max_r = 3
    new_status = "pending" if retry < max_r else "failed"
    logger.error(
        "Task %s failed (attempt %d/%d): %s", task["id"], retry, max_r, error, exc_info=error
    )
    await db.update_task(
        task["id"],
        status=new_status,
        retry_count=retry,
        response_payload={"error": str(error)},
        completed_at=_now().isoformat() if new_status == "failed" else None,
    )
async def _execute_task(provider, plan: dict, task: dict) -> dict:
    """Dispatch a task to the provider method matching its ``task_type``.

    Args:
        provider: Provider object exposing ``generate_image`` /
            ``generate_voice`` / ``generate_video`` coroutines.
        plan: Plan row the task runs under; passed through to the provider.
        task: Task row. ``request_payload`` (a dict, JSON text, or None)
            holds the provider arguments.

    Returns:
        The provider's result dict. For voice tasks, the returned audio bytes
        are written to ``<settings.storage.path>/<task id>.mp3``, recorded on
        the task, and ``{"file": <path>}`` is returned.

    Raises:
        ValueError: for an unknown ``task_type``. The caller then records the
            task as failed, rather than completed and billed.
    """
    tt = task["task_type"]
    payload = task.get("request_payload") or {}
    if isinstance(payload, str):
        # NOTE(review): tolerate a payload stored as raw JSON text.
        payload = json.loads(payload)

    def extra_kwargs(consumed: str) -> dict:
        # The primary argument and the plan are passed positionally. Forwarding
        # them again through **kwargs raises "got multiple values for
        # argument" when the provider uses the same parameter names.
        return {k: v for k, v in payload.items() if k not in (consumed, "plan")}

    if tt == "image":
        return await provider.generate_image(
            payload.get("prompt", ""), plan, **extra_kwargs("prompt")
        )
    if tt == "voice":
        audio = await provider.generate_voice(
            payload.get("text", ""), plan, **extra_kwargs("text")
        )
        # Save the audio to storage.
        from pathlib import Path
        from app.config import settings

        storage_dir = Path(settings.storage.path)
        storage_dir.mkdir(parents=True, exist_ok=True)
        fpath = storage_dir / f"{task['id']}.mp3"
        fpath.write_bytes(audio)
        await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
        return {"file": str(fpath)}
    if tt == "video":
        return await provider.generate_video(
            payload.get("prompt", ""), plan, **extra_kwargs("prompt")
        )
    raise ValueError(f"Unknown task type: {tt}")
async def _scheduler_loop(interval: float = 30.0):
    """Main scheduler loop: refresh quota rules, then drain the task queue.

    Repeats every ``interval`` seconds while the module flag ``_running`` is
    True, or until the task is cancelled. The two steps are isolated, so a
    failure while refreshing quota rules does not also skip task processing
    for that tick. Errors are logged with their traceback and the loop keeps
    going.

    Args:
        interval: Seconds to sleep between ticks (default 30).
    """
    while _running:
        for step in (_refresh_quota_rules, _process_task_queue):
            try:
                await step()
            except Exception:
                logger.exception("Scheduler error in %s", step.__name__)
        await asyncio.sleep(interval)
async def start_scheduler():
    """Discover providers and launch the background scheduler loop.

    If a scheduler task is already running, this logs a warning and returns.
    Starting a second loop would let both loops pick up, and execute, the
    same pending tasks concurrently.
    """
    global _task, _running
    if _task is not None and not _task.done():
        logger.warning("Scheduler already running; ignoring start request")
        return
    # Register providers.
    from app.providers import ProviderRegistry

    ProviderRegistry.auto_discover()
    logger.info("Providers: %s", list(ProviderRegistry.all().keys()))
    _running = True
    _task = asyncio.create_task(_scheduler_loop())
    logger.info("Scheduler started")
async def stop_scheduler():
    """Stop the background scheduler loop and wait for it to finish.

    Clears ``_running`` so the loop exits at its next check, then cancels the
    loop task, which interrupts any in-progress sleep or step. It waits for
    the task to finish and drops the task handle, so module state no longer
    points at a finished task. Safe to call when the scheduler was never
    started or is already stopped.
    """
    global _task, _running
    _running = False
    task, _task = _task, None
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    logger.info("Scheduler stopped")