Files
planManage/app/database.py

499 lines
16 KiB
Python
Raw Permalink Normal View History

"""SQLite 数据库管理 -- 异步连接 + 自动建表"""
from __future__ import annotations
import json
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiosqlite
from app.config import settings
# Process-wide shared connection: opened lazily by get_db(), reset to None by close_db().
_db: aiosqlite.Connection | None = None
# Schema for every table this module uses; get_db() runs it on each new
# connection, so every statement is idempotent (IF NOT EXISTS).
#
# Child columns of foreign keys with ON DELETE CASCADE / SET NULL are indexed:
# on every parent-row DELETE SQLite must find the matching child rows, and
# without an index on the child column that lookup is a full table scan.
SQL_CREATE_TABLES = """
CREATE TABLE IF NOT EXISTS plans (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
provider_name TEXT NOT NULL,
api_key TEXT DEFAULT '',
api_base TEXT DEFAULT '',
plan_type TEXT DEFAULT 'coding',
supported_models TEXT DEFAULT '[]',
extra_headers TEXT DEFAULT '{}',
extra_config TEXT DEFAULT '{}',
enabled INTEGER DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS quota_rules (
id TEXT PRIMARY KEY,
plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
rule_name TEXT NOT NULL,
quota_total INTEGER NOT NULL DEFAULT 0,
quota_used INTEGER NOT NULL DEFAULT 0,
quota_unit TEXT DEFAULT 'requests',
refresh_type TEXT DEFAULT 'calendar_cycle',
interval_hours REAL,
calendar_unit TEXT,
calendar_anchor TEXT,
last_refresh_at TEXT,
next_refresh_at TEXT,
enabled INTEGER DEFAULT 1
);
CREATE TABLE IF NOT EXISTS quota_snapshots (
id TEXT PRIMARY KEY,
rule_id TEXT NOT NULL REFERENCES quota_rules(id) ON DELETE CASCADE,
quota_used INTEGER NOT NULL,
quota_remaining INTEGER NOT NULL,
checked_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS model_routes (
id TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
priority INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
plan_id TEXT REFERENCES plans(id) ON DELETE SET NULL,
task_type TEXT NOT NULL,
status TEXT DEFAULT 'pending',
request_payload TEXT DEFAULT '{}',
response_payload TEXT,
result_file_path TEXT,
result_mime_type TEXT,
priority INTEGER DEFAULT 0,
retry_count INTEGER DEFAULT 0,
max_retries INTEGER DEFAULT 3,
callback_url TEXT,
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_quota_rules_plan ON quota_rules(plan_id);
CREATE INDEX IF NOT EXISTS idx_model_routes_model ON model_routes(model_name);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
CREATE INDEX IF NOT EXISTS idx_quota_snapshots_rule ON quota_snapshots(rule_id);
CREATE INDEX IF NOT EXISTS idx_model_routes_plan ON model_routes(plan_id);
CREATE INDEX IF NOT EXISTS idx_tasks_plan ON tasks(plan_id);
"""
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def new_id() -> str:
    """Return a fresh random identifier: the first 16 hex digits of a UUID4."""
    token = uuid.uuid4().hex
    return token[:16]
async def get_db() -> aiosqlite.Connection:
    """Return the shared aiosqlite connection, opening it on first use.

    On first use this creates the parent directory of
    ``settings.database.path``, opens the connection, installs
    ``aiosqlite.Row`` as the row factory, enables WAL journaling and
    foreign-key enforcement, and runs SQL_CREATE_TABLES.

    The connection is stored in the module-level ``_db`` only after that
    setup succeeds. Concurrent callers therefore never receive a
    half-initialised connection, and a failed setup closes the connection
    instead of caching it. If another coroutine finished initialising while
    this one was awaiting, the surplus connection is closed and the
    already-published one is returned.
    """
    global _db
    if _db is not None:
        return _db

    db_path = Path(settings.database.path)
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = await aiosqlite.connect(str(db_path))
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        # Foreign keys are enforced per connection; the schema's ON DELETE
        # CASCADE / SET NULL clauses only take effect with this enabled.
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.executescript(SQL_CREATE_TABLES)
        await conn.commit()
    except BaseException:
        # Includes cancellation: never leak or cache a half-set-up connection.
        await conn.close()
        raise

    if _db is not None:
        # Lost the initialisation race to a concurrent get_db(); keep theirs.
        await conn.close()
        return _db
    _db = conn
    return conn
async def close_db() -> None:
    """Close the shared connection opened by get_db(), if there is one.

    The module-level reference is cleared *before* awaiting ``close()``.
    A concurrent get_db() therefore opens a new connection rather than
    returning one that is being closed, and a ``close()`` that raises does
    not leave a dead connection cached. Calling this with no open connection
    does nothing.
    """
    global _db
    conn, _db = _db, None
    if conn is not None:
        await conn.close()
# ── 通用辅助 ──────────────────────────────────────────
def row_to_dict(row: aiosqlite.Row) -> dict[str, Any]:
    """Copy a result row into a plain ``{column_name: value}`` dict (column order kept)."""
    return {column: row[column] for column in row.keys()}
def _parse_json(val: str | None, default: Any = None) -> Any:
if val is None:
return default
try:
return json.loads(val)
except (json.JSONDecodeError, TypeError):
return default
# ── Plan CRUD ─────────────────────────────────────────
async def create_plan(
    name: str,
    provider_name: str,
    api_key: str = "",
    api_base: str = "",
    plan_type: str = "coding",
    supported_models: list[str] | None = None,
    extra_headers: dict | None = None,
    extra_config: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a new plan row and return a short summary of it.

    ``supported_models``, ``extra_headers`` and ``extra_config`` are stored as
    JSON text; ``None`` or an empty value is stored as ``[]`` / ``{}``.
    ``enabled`` is stored as 0/1, and ``created_at`` and ``updated_at`` both
    get the same UTC timestamp.

    Returns:
        ``{"id": ..., "name": ..., "provider_name": ...}`` for the new plan.
    """
    db = await get_db()
    plan_id = new_id()
    timestamp = _now_iso()
    insert_sql = (
        "INSERT INTO plans\n"
        "(id, name, provider_name, api_key, api_base, plan_type,\n"
        "supported_models, extra_headers, extra_config, enabled, created_at, updated_at)\n"
        "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)"
    )
    params = (
        plan_id,
        name,
        provider_name,
        api_key,
        api_base,
        plan_type,
        json.dumps(supported_models or []),
        json.dumps(extra_headers or {}),
        json.dumps(extra_config or {}),
        int(enabled),
        timestamp,
        timestamp,
    )
    await db.execute(insert_sql, params)
    await db.commit()
    return {"id": plan_id, "name": name, "provider_name": provider_name}
def _decode_plan_row(row: aiosqlite.Row) -> dict:
    """Turn a ``plans`` row into a dict with its JSON and boolean columns decoded.

    ``supported_models`` becomes a list and ``extra_headers``/``extra_config``
    become dicts; NULL or malformed JSON falls back to ``[]`` / ``{}``.
    ``enabled`` becomes a bool. get_plan and list_plans both use this helper
    so they return the same shape.
    """
    plan = row_to_dict(row)
    plan["supported_models"] = _parse_json(plan["supported_models"], [])
    plan["extra_headers"] = _parse_json(plan["extra_headers"], {})
    plan["extra_config"] = _parse_json(plan["extra_config"], {})
    plan["enabled"] = bool(plan["enabled"])
    return plan


async def get_plan(plan_id: str) -> dict | None:
    """Return the decoded plan with id ``plan_id``, or ``None`` if there is none."""
    db = await get_db()
    cur = await db.execute("SELECT * FROM plans WHERE id=?", (plan_id,))
    row = await cur.fetchone()
    if not row:
        return None
    return _decode_plan_row(row)


async def list_plans(enabled_only: bool = False) -> list[dict]:
    """Return all plans, decoded as in get_plan.

    With ``enabled_only`` only rows with ``enabled=1`` are returned. There is
    no ORDER BY, so row order is whatever SQLite produces.
    """
    db = await get_db()
    sql = "SELECT * FROM plans"
    if enabled_only:
        sql += " WHERE enabled=1"
    cur = await db.execute(sql)
    return [_decode_plan_row(row) for row in await cur.fetchall()]
# Every column of the ``plans`` table (keep in sync with SQL_CREATE_TABLES).
# Column names cannot be bound as SQL parameters, so update_plan checks keys
# against this set before formatting them into the UPDATE statement.
_PLAN_COLUMNS = frozenset({
    "id", "name", "provider_name", "api_key", "api_base", "plan_type",
    "supported_models", "extra_headers", "extra_config", "enabled",
    "created_at", "updated_at",
})


async def update_plan(plan_id: str, **fields) -> bool:
    """Update the given columns of one plan and refresh its ``updated_at``.

    Keyword arguments whose value is ``None`` are skipped, so a column cannot
    be cleared to NULL here. ``supported_models``/``extra_headers``/
    ``extra_config`` are JSON-encoded and ``enabled`` is stored as an int.

    Returns:
        False when no non-``None`` field was given (no query is run);
        otherwise True. True does not imply a plan with ``plan_id`` exists.

    Raises:
        ValueError: a non-``None`` keyword is not a column of ``plans``.
            Rejecting it prevents SQL injection through the field name.
    """
    json_fields = ("supported_models", "extra_headers", "extra_config")
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in _PLAN_COLUMNS:
            raise ValueError(f"unknown plans column: {k!r}")
        if k in json_fields:
            v = json.dumps(v)
        if k == "enabled":
            v = int(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False
    sets.append("updated_at=?")
    vals.append(_now_iso())
    vals.append(plan_id)
    db = await get_db()
    await db.execute(f"UPDATE plans SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    return True
async def delete_plan(plan_id: str) -> bool:
    """Delete the plan with id ``plan_id``.

    The schema's foreign keys remove its quota rules and model routes
    (ON DELETE CASCADE) and set ``tasks.plan_id`` to NULL (ON DELETE SET
    NULL). This relies on ``PRAGMA foreign_keys=ON`` from get_db().

    Returns:
        True if a plan row was deleted, False if no plan had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM plans WHERE id=?", (plan_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
# ── QuotaRule CRUD ────────────────────────────────────
async def create_quota_rule(
    plan_id: str,
    rule_name: str,
    quota_total: int,
    quota_unit: str = "requests",
    refresh_type: str = "calendar_cycle",
    interval_hours: float | None = None,
    calendar_unit: str | None = None,
    calendar_anchor: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a quota rule for ``plan_id`` and return a short summary of it.

    The rule starts with ``quota_used = 0``. ``last_refresh_at`` is set to
    the creation time and ``next_refresh_at`` is left NULL. A falsy
    ``calendar_anchor`` (``None`` or ``{}``) is stored as NULL; otherwise it
    is stored as JSON. ``enabled`` is stored as 0/1.

    Returns:
        ``{"id": ..., "plan_id": ..., "rule_name": ...}`` for the new rule.
    """
    db = await get_db()
    rule_id = new_id()
    created = _now_iso()
    anchor_json = json.dumps(calendar_anchor) if calendar_anchor else None
    insert_sql = (
        "INSERT INTO quota_rules\n"
        "(id, plan_id, rule_name, quota_total, quota_used, quota_unit,\n"
        "refresh_type, interval_hours, calendar_unit, calendar_anchor,\n"
        "last_refresh_at, next_refresh_at, enabled)\n"
        "VALUES (?,?,?,?,0,?,?,?,?,?,?,?,?)"
    )
    params = (
        rule_id, plan_id, rule_name, quota_total, quota_unit,
        refresh_type, interval_hours, calendar_unit, anchor_json,
        created,  # last_refresh_at
        None,     # next_refresh_at
        int(enabled),
    )
    await db.execute(insert_sql, params)
    await db.commit()
    return {"id": rule_id, "plan_id": plan_id, "rule_name": rule_name}
def _decode_quota_rule_row(row: aiosqlite.Row) -> dict:
    """Turn a ``quota_rules`` row into a dict with decoded JSON/bool columns.

    ``calendar_anchor`` is decoded from JSON (NULL or malformed -> ``{}``)
    and ``enabled`` becomes a bool. list_quota_rules and get_all_quota_rules
    both use this helper so they return the same shape.
    """
    rule = row_to_dict(row)
    rule["calendar_anchor"] = _parse_json(rule.get("calendar_anchor"), {})
    rule["enabled"] = bool(rule["enabled"])
    return rule


async def list_quota_rules(plan_id: str) -> list[dict]:
    """Return every quota rule of ``plan_id``, enabled or not, decoded."""
    db = await get_db()
    cur = await db.execute("SELECT * FROM quota_rules WHERE plan_id=?", (plan_id,))
    return [_decode_quota_rule_row(row) for row in await cur.fetchall()]


async def get_all_quota_rules() -> list[dict]:
    """Return the quota rules of all plans for the scheduler, decoded.

    Despite the name, only rules with ``enabled=1`` are returned.
    """
    db = await get_db()
    cur = await db.execute("SELECT * FROM quota_rules WHERE enabled=1")
    return [_decode_quota_rule_row(row) for row in await cur.fetchall()]
# Every column of the ``quota_rules`` table (keep in sync with
# SQL_CREATE_TABLES). Column names cannot be bound as SQL parameters, so
# update_quota_rule checks keys against this set before formatting them.
_QUOTA_RULE_COLUMNS = frozenset({
    "id", "plan_id", "rule_name", "quota_total", "quota_used", "quota_unit",
    "refresh_type", "interval_hours", "calendar_unit", "calendar_anchor",
    "last_refresh_at", "next_refresh_at", "enabled",
})


async def update_quota_rule(rule_id: str, **fields) -> bool:
    """Update the given columns of one quota rule.

    Keyword arguments whose value is ``None`` are skipped, so a column cannot
    be reset to NULL here. ``calendar_anchor`` is JSON-encoded and
    ``enabled`` is stored as an int.

    Returns:
        False when no non-``None`` field was given (no query is run);
        otherwise True. True does not imply a rule with ``rule_id`` exists.

    Raises:
        ValueError: a non-``None`` keyword is not a ``quota_rules`` column.
            Rejecting it prevents SQL injection through the field name.
    """
    json_fields = ("calendar_anchor",)
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in _QUOTA_RULE_COLUMNS:
            raise ValueError(f"unknown quota_rules column: {k!r}")
        if k in json_fields:
            v = json.dumps(v)
        if k == "enabled":
            v = int(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False
    vals.append(rule_id)
    db = await get_db()
    await db.execute(f"UPDATE quota_rules SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    return True
async def delete_quota_rule(rule_id: str) -> bool:
    """Delete the quota rule with id ``rule_id``.

    Its ``quota_snapshots`` rows go with it through ON DELETE CASCADE, which
    relies on ``PRAGMA foreign_keys=ON`` from get_db().

    Returns:
        True if a rule was deleted, False if no rule had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
async def increment_quota_used(plan_id: str, token_count: int = 0):
    """Charge one completed request against every enabled quota rule of a plan.

    Rules whose ``quota_unit`` is ``'tokens'`` grow by ``token_count``; every
    other enabled rule grows by 1.

    This runs as one UPDATE statement rather than a read of all rules plus
    one UPDATE per rule. That costs a single round-trip, and rules are
    filtered by their ``enabled`` value at write time rather than from an
    earlier read.
    """
    db = await get_db()
    # ``enabled != 0`` matches Python truthiness of the stored value: NULL
    # and 0 are skipped, any other number counts as enabled.
    await db.execute(
        "UPDATE quota_rules "
        "SET quota_used = quota_used + "
        "CASE WHEN quota_unit = 'tokens' THEN ? ELSE 1 END "
        "WHERE plan_id=? AND enabled != 0",
        (token_count, plan_id),
    )
    await db.commit()
async def check_plan_available(plan_id: str) -> bool:
    """Return True unless some enabled quota rule of the plan is exhausted.

    A rule is exhausted when ``quota_used >= quota_total``. Disabled rules
    are ignored, so a plan with no (enabled) rules is always available.
    """
    rules = await list_quota_rules(plan_id)
    exhausted = any(
        rule["enabled"] and rule["quota_used"] >= rule["quota_total"]
        for rule in rules
    )
    return not exhausted
# ── Model Route ───────────────────────────────────────
async def set_model_route(model_name: str, plan_id: str, priority: int = 0) -> dict:
    """Route ``model_name`` to ``plan_id`` with the given ``priority``.

    Any existing route(s) for the same (model, plan) pair are deleted first,
    then one new route with a fresh id is inserted and committed.

    Returns:
        ``{"id": ..., "model_name": ..., "plan_id": ...}`` of the new route.
    """
    db = await get_db()
    route_id = new_id()
    pair = (model_name, plan_id)
    await db.execute("DELETE FROM model_routes WHERE model_name=? AND plan_id=?", pair)
    await db.execute(
        "INSERT INTO model_routes (id, model_name, plan_id, priority) VALUES (?,?,?,?)",
        (route_id, *pair, priority),
    )
    await db.commit()
    return {"id": route_id, "model_name": model_name, "plan_id": plan_id}
async def resolve_model(model_name: str) -> str | None:
    """Pick the plan id that should serve ``model_name`` (with fallback).

    Routes are tried in descending ``priority`` (routes with equal priority
    come back in an unspecified order). The first plan for which
    check_plan_available() is true wins. If none has quota left, the first
    (highest-priority) route's plan is returned anyway as a fallback.
    ``None`` is returned only when the model has no routes.

    NOTE(review): only quota rules are consulted; ``plans.enabled`` is not
    checked here, so a disabled plan can still be returned. Confirm callers
    handle that.
    """
    db = await get_db()
    cur = await db.execute(
        "SELECT plan_id FROM model_routes WHERE model_name=? ORDER BY priority DESC",
        (model_name,),
    )
    candidates = [row["plan_id"] for row in await cur.fetchall()]
    for candidate in candidates:
        if await check_plan_available(candidate):
            return candidate
    return candidates[0] if candidates else None
async def list_model_routes() -> list[dict]:
    """Return every model route as a dict, ordered by model name, then priority descending."""
    db = await get_db()
    cur = await db.execute("SELECT * FROM model_routes ORDER BY model_name, priority DESC")
    routes = await cur.fetchall()
    return list(map(row_to_dict, routes))
async def delete_model_route(route_id: str) -> bool:
    """Delete the model route with id ``route_id``.

    Returns:
        True if a route was deleted, False if no route had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM model_routes WHERE id=?", (route_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
# ── Task Queue ────────────────────────────────────────
async def create_task(
    task_type: str,
    request_payload: dict,
    plan_id: str | None = None,
    priority: int = 0,
    max_retries: int = 3,
    callback_url: str | None = None,
) -> dict:
    """Queue a new task with status ``"pending"``.

    ``request_payload`` is stored as JSON text and ``created_at`` is set to
    the current UTC time.

    Returns:
        ``{"id": ..., "status": "pending"}`` for the new task.
    """
    db = await get_db()
    task_id = new_id()
    created = _now_iso()
    initial_status = "pending"
    insert_sql = (
        "INSERT INTO tasks\n"
        "(id, plan_id, task_type, status, request_payload, priority,\n"
        "max_retries, callback_url, created_at)\n"
        "VALUES (?,?,?,?,?,?,?,?,?)"
    )
    params = (
        task_id, plan_id, task_type, initial_status,
        json.dumps(request_payload),
        priority, max_retries, callback_url, created,
    )
    await db.execute(insert_sql, params)
    await db.commit()
    return {"id": task_id, "status": initial_status}
async def list_tasks(status: str | None = None, limit: int = 50) -> list[dict]:
    """Return up to ``limit`` tasks, ordered by priority (desc) then created_at (asc).

    A truthy ``status`` restricts the result to tasks with that status (an
    empty string means no filter). ``request_payload`` is decoded from JSON
    (``{}`` if missing or invalid); ``response_payload`` is decoded likewise
    (``None`` if missing or invalid).
    """
    db = await get_db()
    if status:
        query = "SELECT * FROM tasks WHERE status=? ORDER BY priority DESC, created_at ASC LIMIT ?"
        params: list = [status, limit]
    else:
        query = "SELECT * FROM tasks ORDER BY priority DESC, created_at ASC LIMIT ?"
        params = [limit]
    cur = await db.execute(query, params)

    def _decode(record):
        task = row_to_dict(record)
        task["request_payload"] = _parse_json(task["request_payload"], {})
        task["response_payload"] = _parse_json(task.get("response_payload"))
        return task

    return [_decode(record) for record in await cur.fetchall()]
# Every column of the ``tasks`` table (keep in sync with SQL_CREATE_TABLES).
# Column names cannot be bound as SQL parameters, so update_task checks keys
# against this set before formatting them into the UPDATE statement.
_TASK_COLUMNS = frozenset({
    "id", "plan_id", "task_type", "status", "request_payload",
    "response_payload", "result_file_path", "result_mime_type", "priority",
    "retry_count", "max_retries", "callback_url", "created_at",
    "started_at", "completed_at",
})


async def update_task(task_id: str, **fields) -> bool:
    """Update the given columns of one task.

    Keyword arguments whose value is ``None`` are skipped. A dict or list
    given for ``request_payload``/``response_payload`` is JSON-encoded. Any
    other value for those fields (e.g. an already-serialised string) is
    stored as-is.

    Returns:
        False when no non-``None`` field was given (no query is run);
        otherwise True. True does not imply a task with ``task_id`` exists.

    Raises:
        ValueError: a non-``None`` keyword is not a column of ``tasks``.
            Rejecting it prevents SQL injection through the field name.
    """
    json_fields = ("request_payload", "response_payload")
    sets, vals = [], []
    for k, v in fields.items():
        if v is None:
            continue
        if k not in _TASK_COLUMNS:
            raise ValueError(f"unknown tasks column: {k!r}")
        # Lists are valid JSON payloads too, and sqlite cannot bind a Python
        # list as a parameter, so encode both container types.
        if k in json_fields and isinstance(v, (dict, list)):
            v = json.dumps(v)
        sets.append(f"{k}=?")
        vals.append(v)
    if not sets:
        return False
    vals.append(task_id)
    db = await get_db()
    await db.execute(f"UPDATE tasks SET {', '.join(sets)} WHERE id=?", vals)
    await db.commit()
    return True
async def get_task(task_id: str) -> dict | None:
    """Return the task with id ``task_id`` with its payloads decoded, or ``None``.

    ``request_payload`` falls back to ``{}`` and ``response_payload`` to
    ``None`` when missing or not valid JSON.
    """
    db = await get_db()
    cur = await db.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
    row = await cur.fetchone()
    if not row:
        return None
    task = row_to_dict(row)
    task.update(
        request_payload=_parse_json(task["request_payload"], {}),
        response_payload=_parse_json(task.get("response_payload")),
    )
    return task
# ── 种子数据导入 ──────────────────────────────────────
async def seed_from_config():
    """On first startup, import Plans + QuotaRules + ModelRoutes from config.yaml.

    Does nothing if the ``plans`` table already has at least one row. For
    each configured plan it creates the plan, then its quota rules, then one
    priority-0 model route per entry in the plan's ``supported_models``.

    NOTE(review): create_plan / create_quota_rule / set_model_route each
    commit on their own. An exception part-way through therefore leaves the
    earlier plans stored, and the emptiness check below will then skip
    seeding on the next start. Confirm that this is acceptable.
    """
    # Re-imports app.config.settings at call time under a local alias.
    from app.config import settings as cfg
    db = await get_db()
    cur = await db.execute("SELECT COUNT(*) as cnt FROM plans")
    row = await cur.fetchone()
    # Seed only an empty database; existing plans are never touched.
    if row["cnt"] > 0:
        return
    for ps in cfg.plans:
        plan = await create_plan(
            name=ps.name,
            provider_name=ps.provider,
            api_key=ps.api_key,
            api_base=ps.api_base,
            plan_type=ps.plan_type,
            supported_models=ps.supported_models,
            extra_headers=ps.extra_headers,
            extra_config=ps.extra_config,
        )
        for qr in ps.quota_rules:
            await create_quota_rule(
                plan_id=plan["id"],
                rule_name=qr.rule_name,
                quota_total=qr.quota_total,
                quota_unit=qr.quota_unit,
                refresh_type=qr.refresh_type,
                interval_hours=qr.interval_hours,
                calendar_unit=qr.calendar_unit,
                calendar_anchor=qr.calendar_anchor,
            )
        # Every model the plan supports gets a priority-0 route to it.
        for model in ps.supported_models:
            await set_model_route(model, plan["id"], priority=0)