"""SQLite 数据库管理 -- 异步连接 + 自动建表""" from __future__ import annotations import json import os import uuid from datetime import datetime, timezone from pathlib import Path from typing import Any import aiosqlite from app.config import settings _db: aiosqlite.Connection | None = None SQL_CREATE_TABLES = """ CREATE TABLE IF NOT EXISTS plans ( id TEXT PRIMARY KEY, name TEXT NOT NULL, provider_name TEXT NOT NULL, api_key TEXT DEFAULT '', api_base TEXT DEFAULT '', plan_type TEXT DEFAULT 'coding', supported_models TEXT DEFAULT '[]', extra_headers TEXT DEFAULT '{}', extra_config TEXT DEFAULT '{}', enabled INTEGER DEFAULT 1, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS quota_rules ( id TEXT PRIMARY KEY, plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE, rule_name TEXT NOT NULL, quota_total INTEGER NOT NULL DEFAULT 0, quota_used INTEGER NOT NULL DEFAULT 0, quota_unit TEXT DEFAULT 'requests', refresh_type TEXT DEFAULT 'calendar_cycle', interval_hours REAL, calendar_unit TEXT, calendar_anchor TEXT, last_refresh_at TEXT, next_refresh_at TEXT, enabled INTEGER DEFAULT 1 ); CREATE TABLE IF NOT EXISTS quota_snapshots ( id TEXT PRIMARY KEY, rule_id TEXT NOT NULL REFERENCES quota_rules(id) ON DELETE CASCADE, quota_used INTEGER NOT NULL, quota_remaining INTEGER NOT NULL, checked_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS model_routes ( id TEXT PRIMARY KEY, model_name TEXT NOT NULL, plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE, priority INTEGER DEFAULT 0 ); CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, plan_id TEXT REFERENCES plans(id) ON DELETE SET NULL, task_type TEXT NOT NULL, status TEXT DEFAULT 'pending', request_payload TEXT DEFAULT '{}', response_payload TEXT, result_file_path TEXT, result_mime_type TEXT, priority INTEGER DEFAULT 0, retry_count INTEGER DEFAULT 0, max_retries INTEGER DEFAULT 3, callback_url TEXT, created_at TEXT NOT NULL, started_at TEXT, completed_at TEXT ); CREATE INDEX IF NOT EXISTS idx_quota_rules_plan ON quota_rules(plan_id); CREATE INDEX IF NOT EXISTS idx_model_routes_model ON model_routes(model_name); CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); """ def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() def new_id() -> str: return uuid.uuid4().hex[:16] async def get_db() -> aiosqlite.Connection: global _db if _db is None: db_path = Path(settings.database.path) db_path.parent.mkdir(parents=True, exist_ok=True) _db = await aiosqlite.connect(str(db_path)) _db.row_factory = aiosqlite.Row await _db.execute("PRAGMA journal_mode=WAL") await _db.execute("PRAGMA foreign_keys=ON") await _db.executescript(SQL_CREATE_TABLES) await _db.commit() return _db async def close_db(): global _db if _db: await _db.close() _db = None # ── 通用辅助 ────────────────────────────────────────── def row_to_dict(row: aiosqlite.Row) -> dict[str, Any]: return dict(row) def _parse_json(val: str | None, default: Any = None) -> Any: if val is None: return default try: return json.loads(val) except (json.JSONDecodeError, TypeError): return default # ── Plan CRUD ───────────────────────────────────────── async def create_plan( name: str, provider_name: str, api_key: str = "", api_base: str = "", plan_type: str = "coding", supported_models: list[str] | None = None, extra_headers: dict | None = None, extra_config: dict | None = None, enabled: bool = True, ) -> dict: db = await get_db() pid = new_id() now = _now_iso() await db.execute( """INSERT INTO plans (id, name, provider_name, api_key, api_base, plan_type, supported_models, extra_headers, extra_config, enabled, created_at, updated_at) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""", ( pid, name, provider_name, api_key, api_base, plan_type, json.dumps(supported_models or []), json.dumps(extra_headers or {}), json.dumps(extra_config or {}), int(enabled), now, now, ), ) await db.commit() return {"id": pid, "name": name, "provider_name": provider_name} async def get_plan(plan_id: str) -> dict | None: db = await get_db() cur = await db.execute("SELECT * FROM plans WHERE id=?", (plan_id,)) row = await cur.fetchone() if not row: return None d = row_to_dict(row) d["supported_models"] = _parse_json(d["supported_models"], []) d["extra_headers"] = _parse_json(d["extra_headers"], {}) d["extra_config"] = _parse_json(d["extra_config"], {}) d["enabled"] = bool(d["enabled"]) return d async def list_plans(enabled_only: bool = False) -> list[dict]: db = await get_db() sql = "SELECT * FROM plans" if enabled_only: sql += " WHERE enabled=1" cur = await db.execute(sql) rows = await cur.fetchall() result = [] for row in rows: d = row_to_dict(row) d["supported_models"] = _parse_json(d["supported_models"], []) d["extra_headers"] = _parse_json(d["extra_headers"], {}) d["extra_config"] = _parse_json(d["extra_config"], {}) d["enabled"] = bool(d["enabled"]) result.append(d) return result async def update_plan(plan_id: str, **fields) -> bool: db = await get_db() json_fields = ("supported_models", "extra_headers", "extra_config") sets, vals = [], [] for k, v in fields.items(): if v is None: continue if k in json_fields: v = json.dumps(v) if k == "enabled": v = int(v) sets.append(f"{k}=?") vals.append(v) if not sets: return False sets.append("updated_at=?") vals.append(_now_iso()) vals.append(plan_id) await db.execute(f"UPDATE plans SET {', '.join(sets)} WHERE id=?", vals) await db.commit() return True async def delete_plan(plan_id: str) -> bool: db = await get_db() cur = await db.execute("DELETE FROM plans WHERE id=?", (plan_id,)) await db.commit() return cur.rowcount > 0 # ── QuotaRule CRUD ──────────────────────────────────── async def create_quota_rule( plan_id: str, rule_name: str, quota_total: int, quota_unit: str = "requests", refresh_type: str = "calendar_cycle", interval_hours: float | None = None, calendar_unit: str | None = None, calendar_anchor: dict | None = None, enabled: bool = True, ) -> dict: db = await get_db() rid = new_id() now = _now_iso() await db.execute( """INSERT INTO quota_rules (id, plan_id, rule_name, quota_total, quota_used, quota_unit, refresh_type, interval_hours, calendar_unit, calendar_anchor, last_refresh_at, next_refresh_at, enabled) VALUES (?,?,?,?,0,?,?,?,?,?,?,?,?)""", ( rid, plan_id, rule_name, quota_total, quota_unit, refresh_type, interval_hours, calendar_unit, json.dumps(calendar_anchor) if calendar_anchor else None, now, None, int(enabled), ), ) await db.commit() return {"id": rid, "plan_id": plan_id, "rule_name": rule_name} async def list_quota_rules(plan_id: str) -> list[dict]: db = await get_db() cur = await db.execute("SELECT * FROM quota_rules WHERE plan_id=?", (plan_id,)) rows = await cur.fetchall() result = [] for row in rows: d = row_to_dict(row) d["calendar_anchor"] = _parse_json(d.get("calendar_anchor"), {}) d["enabled"] = bool(d["enabled"]) result.append(d) return result async def get_all_quota_rules() -> list[dict]: """获取全部 QuotaRule,供调度器使用""" db = await get_db() cur = await db.execute("SELECT * FROM quota_rules WHERE enabled=1") rows = await cur.fetchall() result = [] for row in rows: d = row_to_dict(row) d["calendar_anchor"] = _parse_json(d.get("calendar_anchor"), {}) d["enabled"] = bool(d["enabled"]) result.append(d) return result async def update_quota_rule(rule_id: str, **fields) -> bool: db = await get_db() json_fields = ("calendar_anchor",) sets, vals = [], [] for k, v in fields.items(): if v is None: continue if k in json_fields: v = json.dumps(v) if k == "enabled": v = int(v) sets.append(f"{k}=?") vals.append(v) if not sets: return False vals.append(rule_id) await db.execute(f"UPDATE quota_rules SET {', '.join(sets)} WHERE id=?", vals) await db.commit() return True async def delete_quota_rule(rule_id: str) -> bool: db = await get_db() cur = await db.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,)) await db.commit() return cur.rowcount > 0 async def increment_quota_used(plan_id: str, token_count: int = 0): """请求完成后增加该 Plan 所有 Rule 的 quota_used""" db = await get_db() rules = await list_quota_rules(plan_id) for r in rules: if not r["enabled"]: continue inc = token_count if r["quota_unit"] == "tokens" else 1 await db.execute( "UPDATE quota_rules SET quota_used = quota_used + ? WHERE id=?", (inc, r["id"]), ) await db.commit() async def check_plan_available(plan_id: str) -> bool: """判断 Plan 所有 Rule 是否都有余量""" rules = await list_quota_rules(plan_id) for r in rules: if not r["enabled"]: continue if r["quota_used"] >= r["quota_total"]: return False return True # ── Model Route ─────────────────────────────────────── async def set_model_route(model_name: str, plan_id: str, priority: int = 0) -> dict: db = await get_db() mid = new_id() await db.execute( "DELETE FROM model_routes WHERE model_name=? AND plan_id=?", (model_name, plan_id), ) await db.execute( "INSERT INTO model_routes (id, model_name, plan_id, priority) VALUES (?,?,?,?)", (mid, model_name, plan_id, priority), ) await db.commit() return {"id": mid, "model_name": model_name, "plan_id": plan_id} async def resolve_model(model_name: str) -> str | None: """按 priority 降序找到可用的 plan_id(fallback 逻辑)""" db = await get_db() cur = await db.execute( "SELECT plan_id FROM model_routes WHERE model_name=? ORDER BY priority DESC", (model_name,), ) rows = await cur.fetchall() for row in rows: pid = row["plan_id"] if await check_plan_available(pid): return pid return rows[0]["plan_id"] if rows else None async def list_model_routes() -> list[dict]: db = await get_db() cur = await db.execute("SELECT * FROM model_routes ORDER BY model_name, priority DESC") return [row_to_dict(r) for r in await cur.fetchall()] async def delete_model_route(route_id: str) -> bool: db = await get_db() cur = await db.execute("DELETE FROM model_routes WHERE id=?", (route_id,)) await db.commit() return cur.rowcount > 0 # ── Task Queue ──────────────────────────────────────── async def create_task( task_type: str, request_payload: dict, plan_id: str | None = None, priority: int = 0, max_retries: int = 3, callback_url: str | None = None, ) -> dict: db = await get_db() tid = new_id() now = _now_iso() await db.execute( """INSERT INTO tasks (id, plan_id, task_type, status, request_payload, priority, max_retries, callback_url, created_at) VALUES (?,?,?,?,?,?,?,?,?)""", (tid, plan_id, task_type, "pending", json.dumps(request_payload), priority, max_retries, callback_url, now), ) await db.commit() return {"id": tid, "status": "pending"} async def list_tasks(status: str | None = None, limit: int = 50) -> list[dict]: db = await get_db() sql = "SELECT * FROM tasks" params: list = [] if status: sql += " WHERE status=?" params.append(status) sql += " ORDER BY priority DESC, created_at ASC LIMIT ?" params.append(limit) cur = await db.execute(sql, params) rows = await cur.fetchall() result = [] for row in rows: d = row_to_dict(row) d["request_payload"] = _parse_json(d["request_payload"], {}) d["response_payload"] = _parse_json(d.get("response_payload")) result.append(d) return result async def update_task(task_id: str, **fields) -> bool: db = await get_db() json_fields = ("request_payload", "response_payload") sets, vals = [], [] for k, v in fields.items(): if v is None: continue if k in json_fields and isinstance(v, dict): v = json.dumps(v) sets.append(f"{k}=?") vals.append(v) if not sets: return False vals.append(task_id) await db.execute(f"UPDATE tasks SET {', '.join(sets)} WHERE id=?", vals) await db.commit() return True async def get_task(task_id: str) -> dict | None: db = await get_db() cur = await db.execute("SELECT * FROM tasks WHERE id=?", (task_id,)) row = await cur.fetchone() if not row: return None t = row_to_dict(row) t["request_payload"] = _parse_json(t["request_payload"], {}) t["response_payload"] = _parse_json(t.get("response_payload")) return t # ── 种子数据导入 ────────────────────────────────────── async def seed_from_config(): """首次启动时从 config.yaml 导入 Plan + QuotaRule + ModelRoute""" from app.config import settings as cfg db = await get_db() cur = await db.execute("SELECT COUNT(*) as cnt FROM plans") row = await cur.fetchone() if row["cnt"] > 0: return for ps in cfg.plans: plan = await create_plan( name=ps.name, provider_name=ps.provider, api_key=ps.api_key, api_base=ps.api_base, plan_type=ps.plan_type, supported_models=ps.supported_models, extra_headers=ps.extra_headers, extra_config=ps.extra_config, ) for qr in ps.quota_rules: await create_quota_rule( plan_id=plan["id"], rule_name=qr.rule_name, quota_total=qr.quota_total, quota_unit=qr.quota_unit, refresh_type=qr.refresh_type, interval_hours=qr.interval_hours, calendar_unit=qr.calendar_unit, calendar_anchor=qr.calendar_anchor, ) for model in ps.supported_models: await set_model_route(model, plan["id"], priority=0)