- 支持 MiniMax/OpenAI/Google Gemini/智谱/Kimi 五个平台 - 插件化 Provider 架构,自动发现注册 - 多维度 QuotaRule 额度追踪(固定间隔/自然周期/API同步/手动) - OpenAI + Anthropic 兼容 API 代理,SSE 流式转发 - Model 路由表 + 额度耗尽自动 fallback - 多媒体任务队列(图片/语音/视频) - Vue3 + Tailwind 单文件 Web 仪表盘 - Docker 一键部署 Made-with: Cursor
480 lines
15 KiB
Python
480 lines
15 KiB
Python
"""SQLite 数据库管理 -- 异步连接 + 自动建表"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import aiosqlite
|
||
|
||
from app.config import settings
|
||
|
||
# Module-wide connection singleton: None until get_db() opens it, reset to None by close_db().
_db: aiosqlite.Connection | None = None
|
||
|
||
# Schema bootstrap script. Every statement uses IF NOT EXISTS so the script can
# be executed on each startup. Child-key columns that take part in
# ON DELETE CASCADE / SET NULL are indexed so that deleting a parent row does
# not have to scan the whole child table.
SQL_CREATE_TABLES = """
CREATE TABLE IF NOT EXISTS plans (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    provider_name TEXT NOT NULL,
    api_key TEXT DEFAULT '',
    api_base TEXT DEFAULT '',
    plan_type TEXT DEFAULT 'coding',
    supported_models TEXT DEFAULT '[]',
    extra_headers TEXT DEFAULT '{}',
    extra_config TEXT DEFAULT '{}',
    enabled INTEGER DEFAULT 1,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS quota_rules (
    id TEXT PRIMARY KEY,
    plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
    rule_name TEXT NOT NULL,
    quota_total INTEGER NOT NULL DEFAULT 0,
    quota_used INTEGER NOT NULL DEFAULT 0,
    quota_unit TEXT DEFAULT 'requests',
    refresh_type TEXT DEFAULT 'calendar_cycle',
    interval_hours REAL,
    calendar_unit TEXT,
    calendar_anchor TEXT,
    last_refresh_at TEXT,
    next_refresh_at TEXT,
    enabled INTEGER DEFAULT 1
);

CREATE TABLE IF NOT EXISTS quota_snapshots (
    id TEXT PRIMARY KEY,
    rule_id TEXT NOT NULL REFERENCES quota_rules(id) ON DELETE CASCADE,
    quota_used INTEGER NOT NULL,
    quota_remaining INTEGER NOT NULL,
    checked_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS model_routes (
    id TEXT PRIMARY KEY,
    model_name TEXT NOT NULL,
    plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
    priority INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS tasks (
    id TEXT PRIMARY KEY,
    plan_id TEXT REFERENCES plans(id) ON DELETE SET NULL,
    task_type TEXT NOT NULL,
    status TEXT DEFAULT 'pending',
    request_payload TEXT DEFAULT '{}',
    response_payload TEXT,
    result_file_path TEXT,
    result_mime_type TEXT,
    priority INTEGER DEFAULT 0,
    retry_count INTEGER DEFAULT 0,
    max_retries INTEGER DEFAULT 3,
    callback_url TEXT,
    created_at TEXT NOT NULL,
    started_at TEXT,
    completed_at TEXT
);

CREATE INDEX IF NOT EXISTS idx_quota_rules_plan ON quota_rules(plan_id);
CREATE INDEX IF NOT EXISTS idx_quota_snapshots_rule ON quota_snapshots(rule_id);
CREATE INDEX IF NOT EXISTS idx_model_routes_model ON model_routes(model_name);
CREATE INDEX IF NOT EXISTS idx_model_routes_plan ON model_routes(plan_id);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_tasks_plan ON tasks(plan_id);
"""
|
||
|
||
|
||
def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with an explicit offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
||
|
||
|
||
def new_id() -> str:
    """Generate a 16-character lowercase hex identifier from a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:16]
|
||
|
||
|
||
async def get_db() -> aiosqlite.Connection:
    """Return the shared SQLite connection, opening and initialising it on first use.

    On first use this creates the parent directory of ``settings.database.path``,
    opens the database, enables WAL journaling and foreign-key enforcement, and
    runs ``SQL_CREATE_TABLES``.

    The connection is published to the module global only after setup has
    fully succeeded, so a failed initialisation never leaves a half-configured
    connection cached for later callers.

    Returns:
        The module-wide ``aiosqlite.Connection``.
    """
    global _db
    if _db is not None:
        return _db

    db_path = Path(settings.database.path)
    db_path.parent.mkdir(parents=True, exist_ok=True)

    conn = await aiosqlite.connect(str(db_path))
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.executescript(SQL_CREATE_TABLES)
        await conn.commit()
    except BaseException:
        # Do not leak the handle (or cache it) when setup fails or is cancelled.
        await conn.close()
        raise

    if _db is not None:
        # Another coroutine finished initialising while this one was awaiting;
        # keep the already published connection and discard the duplicate.
        await conn.close()
        return _db

    _db = conn
    return _db
|
||
|
||
|
||
async def close_db() -> None:
    """Close the shared connection, if one is open, and forget it.

    The module global is cleared *before* awaiting ``close()`` so that a
    coroutine scheduled during the close cannot obtain the closing connection,
    and so a failing ``close()`` does not leave a dead handle cached.
    """
    global _db
    conn, _db = _db, None
    if conn is not None:
        await conn.close()
|
||
|
||
|
||
# ── 通用辅助 ──────────────────────────────────────────
|
||
|
||
def row_to_dict(row: aiosqlite.Row) -> dict[str, Any]:
    """Copy a result row into a plain ``dict`` keyed by column name."""
    as_mapping: dict[str, Any] = {}
    as_mapping.update(row)
    return as_mapping
|
||
|
||
|
||
def _parse_json(val: str | None, default: Any = None) -> Any:
    """Decode a JSON string, falling back to ``default`` when that is not possible.

    Args:
        val: JSON text, or ``None``.
        default: Value returned when ``val`` is ``None`` or cannot be decoded.

    Returns:
        The decoded value, or ``default``.
    """
    if val is not None:
        try:
            return json.loads(val)
        except (json.JSONDecodeError, TypeError):
            pass
    return default
|
||
|
||
|
||
# ── Plan CRUD ─────────────────────────────────────────
|
||
|
||
async def create_plan(
    name: str,
    provider_name: str,
    api_key: str = "",
    api_base: str = "",
    plan_type: str = "coding",
    supported_models: list[str] | None = None,
    extra_headers: dict | None = None,
    extra_config: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a new row into ``plans`` and commit.

    ``supported_models``, ``extra_headers`` and ``extra_config`` are stored as
    JSON text; a falsy value is stored as an empty list / object. ``created_at``
    and ``updated_at`` are both set to the current UTC timestamp.

    Returns:
        A dict with the generated ``id`` plus the given ``name`` and
        ``provider_name``.
    """
    conn = await get_db()
    plan_id = new_id()
    timestamp = _now_iso()
    insert_sql = """INSERT INTO plans
        (id, name, provider_name, api_key, api_base, plan_type,
         supported_models, extra_headers, extra_config, enabled, created_at, updated_at)
        VALUES (?,?,?,?,?,?,?,?,?,?,?,?)"""
    row_values = (
        plan_id,
        name,
        provider_name,
        api_key,
        api_base,
        plan_type,
        json.dumps(supported_models or []),
        json.dumps(extra_headers or {}),
        json.dumps(extra_config or {}),
        int(enabled),
        timestamp,
        timestamp,
    )
    await conn.execute(insert_sql, row_values)
    await conn.commit()
    return {"id": plan_id, "name": name, "provider_name": provider_name}
|
||
|
||
|
||
async def get_plan(plan_id: str) -> dict | None:
    """Fetch one plan by id.

    Returns:
        The plan as a dict with its JSON columns decoded and ``enabled``
        converted to ``bool``, or ``None`` when no row has that id.
    """
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM plans WHERE id=?", (plan_id,))
    record = await cursor.fetchone()
    if not record:
        return None

    plan = row_to_dict(record)
    # JSON columns and the fallback used when the stored text cannot be decoded.
    json_columns = (
        ("supported_models", []),
        ("extra_headers", {}),
        ("extra_config", {}),
    )
    for column, fallback in json_columns:
        plan[column] = _parse_json(plan[column], fallback)
    plan["enabled"] = bool(plan["enabled"])
    return plan
|
||
|
||
|
||
async def list_plans(enabled_only: bool = False) -> list[dict]:
    """Return all plans, optionally restricted to rows with ``enabled=1``.

    Each plan is a dict with its JSON columns decoded and ``enabled`` converted
    to ``bool``.
    """
    conn = await get_db()
    query = "SELECT * FROM plans WHERE enabled=1" if enabled_only else "SELECT * FROM plans"
    cursor = await conn.execute(query)

    plans = []
    for record in await cursor.fetchall():
        plan = row_to_dict(record)
        plan.update(
            supported_models=_parse_json(plan["supported_models"], []),
            extra_headers=_parse_json(plan["extra_headers"], {}),
            extra_config=_parse_json(plan["extra_config"], {}),
            enabled=bool(plan["enabled"]),
        )
        plans.append(plan)
    return plans
|
||
|
||
|
||
async def update_plan(plan_id: str, **fields) -> bool:
    """Update selected columns of a plan and refresh its ``updated_at``.

    Args:
        plan_id: Id of the plan to update.
        **fields: Column name -> new value. Entries whose value is ``None`` are
            ignored, so a column cannot be set to NULL through this function.
            JSON columns are serialised and ``enabled`` is stored as an integer.

    Returns:
        ``True`` when a row was updated; ``False`` when there was nothing to
        update or no plan has that id.

    Raises:
        ValueError: If a non-``None`` field is not a column of ``plans``.
    """
    # Column names are spliced into the SQL text, so only known identifiers
    # may pass; values themselves are always bound as parameters.
    allowed_columns = frozenset({
        "id", "name", "provider_name", "api_key", "api_base", "plan_type",
        "supported_models", "extra_headers", "extra_config", "enabled",
        "created_at", "updated_at",
    })
    json_fields = ("supported_models", "extra_headers", "extra_config")

    db = await get_db()
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in allowed_columns:
            raise ValueError(f"unknown plans column: {k!r}")
        if k in json_fields:
            v = json.dumps(v)
        if k == "enabled":
            v = int(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False

    sets.append("updated_at=?")
    vals.append(_now_iso())
    vals.append(plan_id)
    cur = await db.execute(f"UPDATE plans SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    # Report whether a row actually matched, consistent with delete_plan().
    return cur.rowcount > 0
|
||
|
||
|
||
async def delete_plan(plan_id: str) -> bool:
    """Delete the plan with the given id and commit.

    Returns:
        ``True`` if a row was deleted, ``False`` if no plan had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM plans WHERE id=?", (plan_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
|
||
|
||
|
||
# ── QuotaRule CRUD ────────────────────────────────────
|
||
|
||
async def create_quota_rule(
    plan_id: str,
    rule_name: str,
    quota_total: int,
    quota_unit: str = "requests",
    refresh_type: str = "calendar_cycle",
    interval_hours: float | None = None,
    calendar_unit: str | None = None,
    calendar_anchor: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a quota rule for a plan and commit.

    The rule starts with ``quota_used`` 0, ``last_refresh_at`` set to the
    current UTC timestamp and ``next_refresh_at`` NULL. ``calendar_anchor`` is
    stored as JSON text, or NULL when it is falsy.

    Returns:
        A dict with the generated ``id`` plus the given ``plan_id`` and
        ``rule_name``.
    """
    conn = await get_db()
    rule_id = new_id()
    created = _now_iso()

    anchor_json = None
    if calendar_anchor:
        anchor_json = json.dumps(calendar_anchor)

    insert_sql = """INSERT INTO quota_rules
        (id, plan_id, rule_name, quota_total, quota_used, quota_unit,
         refresh_type, interval_hours, calendar_unit, calendar_anchor,
         last_refresh_at, next_refresh_at, enabled)
        VALUES (?,?,?,?,0,?,?,?,?,?,?,?,?)"""
    row_values = (
        rule_id,
        plan_id,
        rule_name,
        quota_total,
        quota_unit,
        refresh_type,
        interval_hours,
        calendar_unit,
        anchor_json,
        created,
        None,
        int(enabled),
    )
    await conn.execute(insert_sql, row_values)
    await conn.commit()
    return {"id": rule_id, "plan_id": plan_id, "rule_name": rule_name}
|
||
|
||
|
||
async def list_quota_rules(plan_id: str) -> list[dict]:
    """Return every quota rule of a plan, enabled or not.

    ``calendar_anchor`` is decoded from JSON (``{}`` when missing or invalid)
    and ``enabled`` is converted to ``bool``.
    """
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM quota_rules WHERE plan_id=?", (plan_id,))
    records = await cursor.fetchall()
    return [
        {
            **rule,
            "calendar_anchor": _parse_json(rule.get("calendar_anchor"), {}),
            "enabled": bool(rule["enabled"]),
        }
        for rule in map(row_to_dict, records)
    ]
|
||
|
||
|
||
async def get_all_quota_rules() -> list[dict]:
    """Return all enabled quota rules across every plan (used by the scheduler).

    ``calendar_anchor`` is decoded from JSON (``{}`` when missing or invalid)
    and ``enabled`` is converted to ``bool``.
    """
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM quota_rules WHERE enabled=1")

    rules = []
    for record in await cursor.fetchall():
        rule = row_to_dict(record)
        rule.update(
            calendar_anchor=_parse_json(rule.get("calendar_anchor"), {}),
            enabled=bool(rule["enabled"]),
        )
        rules.append(rule)
    return rules
|
||
|
||
|
||
async def update_quota_rule(rule_id: str, **fields) -> bool:
    """Update selected columns of a quota rule and commit.

    Args:
        rule_id: Id of the rule to update.
        **fields: Column name -> new value. Entries whose value is ``None`` are
            ignored, so a column cannot be set to NULL through this function.
            ``calendar_anchor`` is serialised to JSON and ``enabled`` is stored
            as an integer.

    Returns:
        ``True`` when a row was updated; ``False`` when there was nothing to
        update or no rule has that id.

    Raises:
        ValueError: If a non-``None`` field is not a column of ``quota_rules``.
    """
    # Column names are spliced into the SQL text, so only known identifiers
    # may pass; values themselves are always bound as parameters.
    allowed_columns = frozenset({
        "id", "plan_id", "rule_name", "quota_total", "quota_used", "quota_unit",
        "refresh_type", "interval_hours", "calendar_unit", "calendar_anchor",
        "last_refresh_at", "next_refresh_at", "enabled",
    })
    json_fields = ("calendar_anchor",)

    db = await get_db()
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in allowed_columns:
            raise ValueError(f"unknown quota_rules column: {k!r}")
        if k in json_fields:
            v = json.dumps(v)
        if k == "enabled":
            v = int(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False

    vals.append(rule_id)
    cur = await db.execute(f"UPDATE quota_rules SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    # Report whether a row actually matched instead of always claiming success.
    return cur.rowcount > 0
|
||
|
||
|
||
async def increment_quota_used(plan_id: str, token_count: int = 0) -> None:
    """Charge one finished request against every enabled quota rule of a plan.

    Rules whose ``quota_unit`` is ``'tokens'`` grow by ``token_count``; all
    other rules grow by 1. Disabled rules are left untouched.

    The work is done in a single UPDATE statement rather than one statement
    per rule, so all rules of the plan are charged together and no other
    coroutine can interleave between the individual increments.

    Args:
        plan_id: Plan whose rules are charged.
        token_count: Amount added to token-based rules.
    """
    db = await get_db()
    await db.execute(
        """UPDATE quota_rules
           SET quota_used = quota_used
               + CASE WHEN quota_unit = 'tokens' THEN ? ELSE 1 END
           WHERE plan_id = ? AND enabled != 0""",
        (token_count, plan_id),
    )
    await db.commit()
|
||
|
||
|
||
async def check_plan_available(plan_id: str) -> bool:
    """Tell whether every enabled quota rule of a plan still has headroom.

    Returns:
        ``False`` as soon as one enabled rule has ``quota_used >= quota_total``;
        ``True`` otherwise, including when the plan has no enabled rules.
    """
    rules = await list_quota_rules(plan_id)
    exhausted = any(
        rule["enabled"] and rule["quota_used"] >= rule["quota_total"]
        for rule in rules
    )
    return not exhausted
|
||
|
||
|
||
# ── Model Route ───────────────────────────────────────
|
||
|
||
async def set_model_route(model_name: str, plan_id: str, priority: int = 0) -> dict:
    """Create or replace the route that maps a model name to a plan.

    Any existing route for the same ``(model_name, plan_id)`` pair is removed
    first, then a new row with a fresh id is inserted and committed.

    Returns:
        A dict with the new route ``id``, ``model_name`` and ``plan_id``.
    """
    conn = await get_db()
    route_id = new_id()

    pair = (model_name, plan_id)
    await conn.execute(
        "DELETE FROM model_routes WHERE model_name=? AND plan_id=?",
        pair,
    )
    await conn.execute(
        "INSERT INTO model_routes (id, model_name, plan_id, priority) VALUES (?,?,?,?)",
        (route_id, *pair, priority),
    )
    await conn.commit()
    return {"id": route_id, "model_name": model_name, "plan_id": plan_id}
|
||
|
||
|
||
async def resolve_model(model_name: str) -> str | None:
    """Pick the plan that should serve a model, with quota-based fallback.

    Routes for the model are tried in descending priority and the first plan
    that ``check_plan_available`` accepts is returned. When none is available
    the highest-priority route is returned anyway.

    Returns:
        The chosen ``plan_id``, or ``None`` when the model has no routes.
    """
    conn = await get_db()
    cursor = await conn.execute(
        "SELECT plan_id FROM model_routes WHERE model_name=? ORDER BY priority DESC",
        (model_name,),
    )
    candidates = [record["plan_id"] for record in await cursor.fetchall()]
    if not candidates:
        return None

    for candidate in candidates:
        if await check_plan_available(candidate):
            return candidate
    # Every routed plan is exhausted: fall back to the top-priority one.
    return candidates[0]
|
||
|
||
|
||
async def list_model_routes() -> list[dict]:
    """Return all model routes ordered by model name, then descending priority."""
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM model_routes ORDER BY model_name, priority DESC")
    records = await cursor.fetchall()
    return list(map(row_to_dict, records))
|
||
|
||
|
||
async def delete_model_route(route_id: str) -> bool:
    """Delete one model route by id and commit.

    Returns:
        ``True`` if a row was deleted, ``False`` if no route had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM model_routes WHERE id=?", (route_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
|
||
|
||
|
||
# ── Task Queue ────────────────────────────────────────
|
||
|
||
async def create_task(
    task_type: str,
    request_payload: dict,
    plan_id: str | None = None,
    priority: int = 0,
    max_retries: int = 3,
    callback_url: str | None = None,
) -> dict:
    """Queue a new task in ``pending`` state and commit.

    ``request_payload`` is stored as JSON text and ``created_at`` is set to the
    current UTC timestamp.

    Returns:
        A dict with the generated task ``id`` and ``status`` ``"pending"``.
    """
    conn = await get_db()
    task_id = new_id()
    created = _now_iso()
    insert_sql = """INSERT INTO tasks
        (id, plan_id, task_type, status, request_payload, priority,
         max_retries, callback_url, created_at)
        VALUES (?,?,?,?,?,?,?,?,?)"""
    row_values = (
        task_id,
        plan_id,
        task_type,
        "pending",
        json.dumps(request_payload),
        priority,
        max_retries,
        callback_url,
        created,
    )
    await conn.execute(insert_sql, row_values)
    await conn.commit()
    return {"id": task_id, "status": "pending"}
|
||
|
||
|
||
async def list_tasks(status: str | None = None, limit: int = 50) -> list[dict]:
    """Return up to ``limit`` tasks, highest priority first, oldest first within a priority.

    Args:
        status: When truthy, only tasks with exactly this status are returned.
        limit: Maximum number of rows.

    Returns:
        Task dicts with ``request_payload`` decoded from JSON (``{}`` on
        failure) and ``response_payload`` decoded from JSON (``None`` when
        absent or invalid).
    """
    conn = await get_db()
    where = " WHERE status=?" if status else ""
    query = f"SELECT * FROM tasks{where} ORDER BY priority DESC, created_at ASC LIMIT ?"
    params: list = [status, limit] if status else [limit]

    cursor = await conn.execute(query, params)
    tasks = []
    for record in await cursor.fetchall():
        task = row_to_dict(record)
        task["request_payload"] = _parse_json(task["request_payload"], {})
        task["response_payload"] = _parse_json(task.get("response_payload"))
        tasks.append(task)
    return tasks
|
||
|
||
|
||
async def update_task(task_id: str, **fields) -> bool:
    """Update selected columns of a task and commit.

    Args:
        task_id: Id of the task to update.
        **fields: Column name -> new value. Entries whose value is ``None`` are
            ignored, so a column cannot be set to NULL through this function.
            ``request_payload`` / ``response_payload`` given as a dict or list
            are serialised to JSON; other values (such as an already
            serialised string) are stored as passed.

    Returns:
        ``True`` when a row was updated; ``False`` when there was nothing to
        update or no task has that id.

    Raises:
        ValueError: If a non-``None`` field is not a column of ``tasks``.
    """
    # Column names are spliced into the SQL text, so only known identifiers
    # may pass; values themselves are always bound as parameters.
    allowed_columns = frozenset({
        "id", "plan_id", "task_type", "status", "request_payload",
        "response_payload", "result_file_path", "result_mime_type", "priority",
        "retry_count", "max_retries", "callback_url", "created_at",
        "started_at", "completed_at",
    })
    json_fields = ("request_payload", "response_payload")

    db = await get_db()
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in allowed_columns:
            raise ValueError(f"unknown tasks column: {k!r}")
        # Lists are valid JSON payloads too and cannot be bound to SQLite as-is.
        if k in json_fields and isinstance(v, (dict, list)):
            v = json.dumps(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False

    vals.append(task_id)
    cur = await db.execute(f"UPDATE tasks SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    # Report whether a row actually matched instead of always claiming success.
    return cur.rowcount > 0
|
||
|
||
|
||
# ── 种子数据导入 ──────────────────────────────────────
|
||
|
||
async def seed_from_config():
    """Populate an empty database from the application settings.

    Does nothing when the ``plans`` table already contains at least one row.
    Otherwise, for every plan in ``settings.plans`` it creates the plan, its
    quota rules, and one priority-0 model route per supported model.
    """
    from app.config import settings as cfg

    conn = await get_db()
    cursor = await conn.execute("SELECT COUNT(*) as cnt FROM plans")
    count_row = await cursor.fetchone()
    already_seeded = count_row["cnt"] > 0
    if already_seeded:
        return

    for plan_cfg in cfg.plans:
        created = await create_plan(
            name=plan_cfg.name,
            provider_name=plan_cfg.provider,
            api_key=plan_cfg.api_key,
            api_base=plan_cfg.api_base,
            plan_type=plan_cfg.plan_type,
            supported_models=plan_cfg.supported_models,
            extra_headers=plan_cfg.extra_headers,
            extra_config=plan_cfg.extra_config,
        )

        for rule_cfg in plan_cfg.quota_rules:
            await create_quota_rule(
                plan_id=created["id"],
                rule_name=rule_cfg.rule_name,
                quota_total=rule_cfg.quota_total,
                quota_unit=rule_cfg.quota_unit,
                refresh_type=rule_cfg.refresh_type,
                interval_hours=rule_cfg.interval_hours,
                calendar_unit=rule_cfg.calendar_unit,
                calendar_anchor=rule_cfg.calendar_anchor,
            )

        for model_name in plan_cfg.supported_models:
            await set_model_route(model_name, created["id"], priority=0)
|