feat: 多平台 Coding Plan 统一管理系统初始实现
- 支持 MiniMax/OpenAI/Google Gemini/智谱/Kimi 五个平台 - 插件化 Provider 架构,自动发现注册 - 多维度 QuotaRule 额度追踪(固定间隔/自然周期/API同步/手动) - OpenAI + Anthropic 兼容 API 代理,SSE 流式转发 - Model 路由表 + 额度耗尽自动 fallback - 多媒体任务队列(图片/语音/视频) - Vue3 + Tailwind 单文件 Web 仪表盘 - Docker 一键部署 Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,256 @@
|
||||
"""后台调度器 -- 额度刷新 + 任务队列消费"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app import database as db
|
||||
|
||||
logger = logging.getLogger("scheduler")
|
||||
|
||||
_task: asyncio.Task | None = None
|
||||
_running = False
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _parse_dt(s: str | None) -> datetime | None:
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(s)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) -> datetime:
|
||||
"""
|
||||
计算下一个自然周期的刷新时间。
|
||||
anchor 示例:
|
||||
daily: {"hour": 0}
|
||||
weekly: {"weekday": 1, "hour": 0} # 周一 0 点
|
||||
monthly: {"day": 1, "hour": 0} # 每月 1 号
|
||||
"""
|
||||
hour = anchor.get("hour", 0)
|
||||
|
||||
if calendar_unit == "daily":
|
||||
candidate = after.replace(hour=hour, minute=0, second=0, microsecond=0)
|
||||
if candidate <= after:
|
||||
candidate += timedelta(days=1)
|
||||
return candidate
|
||||
|
||||
if calendar_unit == "weekly":
|
||||
weekday = anchor.get("weekday", 1) # 1=Monday
|
||||
days_ahead = (weekday - after.isoweekday()) % 7
|
||||
candidate = (after + timedelta(days=days_ahead)).replace(
|
||||
hour=hour, minute=0, second=0, microsecond=0
|
||||
)
|
||||
if candidate <= after:
|
||||
candidate += timedelta(weeks=1)
|
||||
return candidate
|
||||
|
||||
if calendar_unit == "monthly":
|
||||
day = anchor.get("day", 1)
|
||||
year, month = after.year, after.month
|
||||
try:
|
||||
candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
# 日期不存在时(如 2 月 30 号),跳到下月
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = after.replace(year=year, month=month, day=day,
|
||||
hour=hour, minute=0, second=0, microsecond=0)
|
||||
if candidate <= after:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
try:
|
||||
candidate = candidate.replace(year=year, month=month)
|
||||
except ValueError:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = candidate.replace(year=year, month=month, day=1)
|
||||
return candidate
|
||||
|
||||
return after + timedelta(days=1)
|
||||
|
||||
|
||||
async def _refresh_quota_rules():
|
||||
"""遍历所有 QuotaRule,按刷新策略处理"""
|
||||
now = _now()
|
||||
rules = await db.get_all_quota_rules()
|
||||
|
||||
for rule in rules:
|
||||
rt = rule.get("refresh_type", "manual")
|
||||
|
||||
if rt == "manual":
|
||||
continue
|
||||
|
||||
next_at = _parse_dt(rule.get("next_refresh_at"))
|
||||
|
||||
if rt == "fixed_interval":
|
||||
interval = rule.get("interval_hours")
|
||||
if not interval:
|
||||
continue
|
||||
last_at = _parse_dt(rule.get("last_refresh_at")) or now
|
||||
if next_at is None:
|
||||
next_at = last_at + timedelta(hours=interval)
|
||||
await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
|
||||
if now >= next_at:
|
||||
new_next = now + timedelta(hours=interval)
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=0,
|
||||
last_refresh_at=now.isoformat(),
|
||||
next_refresh_at=new_next.isoformat(),
|
||||
)
|
||||
logger.info("Refreshed rule %s (fixed_interval %sh)", rule["rule_name"], interval)
|
||||
|
||||
elif rt == "calendar_cycle":
|
||||
cal_unit = rule.get("calendar_unit", "daily")
|
||||
anchor = rule.get("calendar_anchor") or {}
|
||||
if next_at is None:
|
||||
last_at = _parse_dt(rule.get("last_refresh_at")) or now
|
||||
next_at = _compute_next_calendar(cal_unit, anchor, last_at)
|
||||
await db.update_quota_rule(rule["id"], next_refresh_at=next_at.isoformat())
|
||||
if now >= next_at:
|
||||
new_next = _compute_next_calendar(cal_unit, anchor, now)
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=0,
|
||||
last_refresh_at=now.isoformat(),
|
||||
next_refresh_at=new_next.isoformat(),
|
||||
)
|
||||
logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)
|
||||
|
||||
elif rt == "api_sync":
|
||||
# 每 10 分钟同步一次
|
||||
last_at = _parse_dt(rule.get("last_refresh_at"))
|
||||
if last_at and (now - last_at).total_seconds() < 600:
|
||||
continue
|
||||
plan = await db.get_plan(rule["plan_id"])
|
||||
if not plan:
|
||||
continue
|
||||
from app.providers import ProviderRegistry
|
||||
provider = ProviderRegistry.get(plan["provider_name"])
|
||||
if provider:
|
||||
info = await provider.query_quota(plan)
|
||||
if info:
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=info.quota_used,
|
||||
last_refresh_at=now.isoformat(),
|
||||
)
|
||||
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
|
||||
|
||||
|
||||
async def _process_task_queue():
|
||||
"""消费待处理任务"""
|
||||
tasks = await db.list_tasks(status="pending", limit=5)
|
||||
for task in tasks:
|
||||
plan_id = task.get("plan_id")
|
||||
if plan_id and not await db.check_plan_available(plan_id):
|
||||
continue # 额度不足,跳过
|
||||
|
||||
await db.update_task(task["id"], status="running", started_at=_now().isoformat())
|
||||
|
||||
try:
|
||||
if plan_id:
|
||||
plan = await db.get_plan(plan_id)
|
||||
if plan:
|
||||
from app.providers import ProviderRegistry
|
||||
provider = ProviderRegistry.get(plan["provider_name"])
|
||||
if provider:
|
||||
result = await _execute_task(provider, plan, task)
|
||||
await db.update_task(
|
||||
task["id"],
|
||||
status="completed",
|
||||
response_payload=result,
|
||||
completed_at=_now().isoformat(),
|
||||
)
|
||||
await db.increment_quota_used(plan_id, token_count=0)
|
||||
continue
|
||||
|
||||
await db.update_task(
|
||||
task["id"],
|
||||
status="failed",
|
||||
response_payload={"error": "No provider available"},
|
||||
completed_at=_now().isoformat(),
|
||||
)
|
||||
except Exception as e:
|
||||
retry = task.get("retry_count", 0) + 1
|
||||
max_r = task.get("max_retries", 3)
|
||||
new_status = "pending" if retry < max_r else "failed"
|
||||
await db.update_task(
|
||||
task["id"],
|
||||
status=new_status,
|
||||
retry_count=retry,
|
||||
response_payload={"error": str(e)},
|
||||
completed_at=_now().isoformat() if new_status == "failed" else None,
|
||||
)
|
||||
logger.error("Task %s failed: %s", task["id"], e)
|
||||
|
||||
|
||||
async def _execute_task(provider, plan: dict, task: dict) -> dict:
|
||||
"""根据 task_type 调用对应的 Provider 方法"""
|
||||
tt = task["task_type"]
|
||||
payload = task.get("request_payload", {})
|
||||
|
||||
if tt == "image":
|
||||
return await provider.generate_image(payload.get("prompt", ""), plan, **payload)
|
||||
elif tt == "voice":
|
||||
audio = await provider.generate_voice(payload.get("text", ""), plan, **payload)
|
||||
# 保存到文件
|
||||
from pathlib import Path
|
||||
from app.config import settings
|
||||
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
|
||||
fpath.write_bytes(audio)
|
||||
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
|
||||
return {"file": str(fpath)}
|
||||
elif tt == "video":
|
||||
return await provider.generate_video(payload.get("prompt", ""), plan, **payload)
|
||||
else:
|
||||
return {"error": f"Unknown task type: {tt}"}
|
||||
|
||||
|
||||
async def _scheduler_loop():
|
||||
"""主调度循环"""
|
||||
global _running
|
||||
while _running:
|
||||
try:
|
||||
await _refresh_quota_rules()
|
||||
await _process_task_queue()
|
||||
except Exception as e:
|
||||
logger.error("Scheduler error: %s", e)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
||||
async def start_scheduler():
|
||||
global _task, _running
|
||||
# 注册 Provider
|
||||
from app.providers import ProviderRegistry
|
||||
ProviderRegistry.auto_discover()
|
||||
logger.info("Providers: %s", list(ProviderRegistry.all().keys()))
|
||||
|
||||
_running = True
|
||||
_task = asyncio.create_task(_scheduler_loop())
|
||||
logger.info("Scheduler started")
|
||||
|
||||
|
||||
async def stop_scheduler():
|
||||
global _task, _running
|
||||
_running = False
|
||||
if _task:
|
||||
_task.cancel()
|
||||
try:
|
||||
await _task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Scheduler stopped")
|
||||
Reference in New Issue
Block a user