Files
planManage/app/database.py
congsh 37d282c0a2 test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数以返回操作结果状态
2026-03-31 22:36:18 +08:00

499 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""SQLite 数据库管理 -- 异步连接 + 自动建表"""
from __future__ import annotations
import json
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiosqlite
from app.config import settings
# Module-wide shared connection. It stays None until get_db() opens the
# database and is reset to None again by close_db().
_db: aiosqlite.Connection | None = None
# Schema script that get_db() runs via executescript() when it opens the
# connection. Every statement uses IF NOT EXISTS, so re-running it against an
# existing database leaves existing tables and indexes in place.
#
# Tables: plans, quota_rules, quota_snapshots, model_routes, tasks.
# Columns that hold structured values (supported_models, extra_headers,
# extra_config, calendar_anchor, request_payload, response_payload) are TEXT
# and are written as JSON by the CRUD helpers in this module.
# The ON DELETE CASCADE / SET NULL clauses rely on get_db() issuing
# "PRAGMA foreign_keys=ON" on the connection.
SQL_CREATE_TABLES = """
CREATE TABLE IF NOT EXISTS plans (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
provider_name TEXT NOT NULL,
api_key TEXT DEFAULT '',
api_base TEXT DEFAULT '',
plan_type TEXT DEFAULT 'coding',
supported_models TEXT DEFAULT '[]',
extra_headers TEXT DEFAULT '{}',
extra_config TEXT DEFAULT '{}',
enabled INTEGER DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS quota_rules (
id TEXT PRIMARY KEY,
plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
rule_name TEXT NOT NULL,
quota_total INTEGER NOT NULL DEFAULT 0,
quota_used INTEGER NOT NULL DEFAULT 0,
quota_unit TEXT DEFAULT 'requests',
refresh_type TEXT DEFAULT 'calendar_cycle',
interval_hours REAL,
calendar_unit TEXT,
calendar_anchor TEXT,
last_refresh_at TEXT,
next_refresh_at TEXT,
enabled INTEGER DEFAULT 1
);
CREATE TABLE IF NOT EXISTS quota_snapshots (
id TEXT PRIMARY KEY,
rule_id TEXT NOT NULL REFERENCES quota_rules(id) ON DELETE CASCADE,
quota_used INTEGER NOT NULL,
quota_remaining INTEGER NOT NULL,
checked_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS model_routes (
id TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
plan_id TEXT NOT NULL REFERENCES plans(id) ON DELETE CASCADE,
priority INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
plan_id TEXT REFERENCES plans(id) ON DELETE SET NULL,
task_type TEXT NOT NULL,
status TEXT DEFAULT 'pending',
request_payload TEXT DEFAULT '{}',
response_payload TEXT,
result_file_path TEXT,
result_mime_type TEXT,
priority INTEGER DEFAULT 0,
retry_count INTEGER DEFAULT 0,
max_retries INTEGER DEFAULT 3,
callback_url TEXT,
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_quota_rules_plan ON quota_rules(plan_id);
CREATE INDEX IF NOT EXISTS idx_model_routes_model ON model_routes(model_name);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status);
"""
def _now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def new_id() -> str:
    """Return a 16-character lowercase hex identifier taken from a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:16]
async def get_db() -> aiosqlite.Connection:
    """Return the shared connection, opening and initialising it on first use.

    On the first call the parent directory of ``settings.database.path`` is
    created if needed, the database is opened, rows are configured to come
    back as ``aiosqlite.Row``, WAL journaling and foreign-key enforcement are
    switched on, and ``SQL_CREATE_TABLES`` is executed.

    The module-level handle is published only after all of that succeeded.
    If any setup step raises, the new connection is closed and the error
    propagates, so a later call retries from scratch instead of handing out a
    connection whose schema was never created.

    Returns:
        The process-wide ``aiosqlite.Connection``.
    """
    global _db
    if _db is not None:
        return _db

    db_path = Path(settings.database.path)
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = await aiosqlite.connect(str(db_path))
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.executescript(SQL_CREATE_TABLES)
        await conn.commit()
    except BaseException:
        # Do not leak (or cache) a connection whose setup did not finish.
        await conn.close()
        raise

    if _db is not None:
        # Another coroutine finished initialising while this one was awaiting;
        # keep the connection that was published first and drop this one.
        await conn.close()
        return _db

    _db = conn
    return _db
async def close_db():
    """Close the shared connection, if one is open, and clear the module handle."""
    global _db
    if not _db:
        return
    await _db.close()
    _db = None
# ── 通用辅助 ──────────────────────────────────────────
def row_to_dict(row: aiosqlite.Row) -> dict[str, Any]:
    """Copy a result row into a plain ``dict`` keyed by column name."""
    return {column: row[column] for column in row.keys()}
def _parse_json(val: str | None, default: Any = None) -> Any:
    """Decode a JSON text column, falling back to ``default``.

    ``default`` is returned when ``val`` is ``None``, is not valid JSON, or is
    of a type ``json.loads`` rejects. Otherwise the decoded value is returned.
    """
    if val is not None:
        try:
            return json.loads(val)
        except (json.JSONDecodeError, TypeError):
            pass
    return default
# ── Plan CRUD ─────────────────────────────────────────
async def create_plan(
    name: str,
    provider_name: str,
    api_key: str = "",
    api_base: str = "",
    plan_type: str = "coding",
    supported_models: list[str] | None = None,
    extra_headers: dict | None = None,
    extra_config: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a plan row and commit.

    ``supported_models``, ``extra_headers`` and ``extra_config`` are stored as
    JSON text; a missing or empty value is stored as ``[]`` / ``{}``.
    ``created_at`` and ``updated_at`` are both set to the current UTC time.

    Returns:
        A dict with the new row's ``id`` plus the given ``name`` and
        ``provider_name``.
    """
    conn = await get_db()
    plan_id = new_id()
    timestamp = _now_iso()
    models_json = json.dumps(supported_models or [])
    headers_json = json.dumps(extra_headers or {})
    config_json = json.dumps(extra_config or {})
    params = (
        plan_id,
        name,
        provider_name,
        api_key,
        api_base,
        plan_type,
        models_json,
        headers_json,
        config_json,
        int(enabled),
        timestamp,
        timestamp,
    )
    await conn.execute(
        """INSERT INTO plans
(id, name, provider_name, api_key, api_base, plan_type,
supported_models, extra_headers, extra_config, enabled, created_at, updated_at)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""",
        params,
    )
    await conn.commit()
    return {"id": plan_id, "name": name, "provider_name": provider_name}
async def get_plan(plan_id: str) -> dict | None:
    """Fetch one plan by id.

    Returns:
        The row as a dict with the JSON columns decoded and ``enabled``
        converted to ``bool``, or ``None`` when no row has that id.
    """
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM plans WHERE id=?", (plan_id,))
    record = await cursor.fetchone()
    if record:
        plan = row_to_dict(record)
        # The fallbacks are fresh literals on every call, so callers never
        # share a default container.
        for column, fallback in (
            ("supported_models", []),
            ("extra_headers", {}),
            ("extra_config", {}),
        ):
            plan[column] = _parse_json(plan[column], fallback)
        plan["enabled"] = bool(plan["enabled"])
        return plan
    return None
async def list_plans(enabled_only: bool = False) -> list[dict]:
    """Return every plan, optionally restricted to rows with ``enabled=1``.

    Each row is returned as a dict with the JSON columns decoded and
    ``enabled`` converted to ``bool``.
    """

    def _decode(record) -> dict:
        # Literal fallbacks are created per row, so rows never share a
        # default list or dict.
        plan = row_to_dict(record)
        plan["supported_models"] = _parse_json(plan["supported_models"], [])
        plan["extra_headers"] = _parse_json(plan["extra_headers"], {})
        plan["extra_config"] = _parse_json(plan["extra_config"], {})
        plan["enabled"] = bool(plan["enabled"])
        return plan

    conn = await get_db()
    query = "SELECT * FROM plans" + (" WHERE enabled=1" if enabled_only else "")
    cursor = await conn.execute(query)
    return [_decode(record) for record in await cursor.fetchall()]
# Columns of the ``plans`` table that may appear in an UPDATE's SET clause.
# Keyword names are checked against this set before they are placed into the
# SQL text, because identifiers cannot be bound as ``?`` parameters.
_PLAN_COLUMNS = frozenset({
    "id",
    "name",
    "provider_name",
    "api_key",
    "api_base",
    "plan_type",
    "supported_models",
    "extra_headers",
    "extra_config",
    "enabled",
    "created_at",
    "updated_at",
})


async def update_plan(plan_id: str, **fields) -> bool:
    """Update the given columns of one plan and refresh ``updated_at``.

    Args:
        plan_id: Id of the plan row to update.
        **fields: Column name -> new value. Entries whose value is ``None``
            are ignored, so a column cannot be set to NULL through this
            function. ``supported_models``, ``extra_headers`` and
            ``extra_config`` are JSON-encoded; ``enabled`` is stored as an
            integer.

    Returns:
        ``True`` if a row was updated; ``False`` if there was nothing to
        update or no row has ``plan_id``.

    Raises:
        ValueError: If a non-``None`` field names something that is not a
            column of ``plans``. Nothing is written in that case.
    """
    json_fields = ("supported_models", "extra_headers", "extra_config")
    assignments: list[str] = []
    params: list = []
    for column, value in fields.items():
        if value is None:
            continue
        if column not in _PLAN_COLUMNS:
            # The name is interpolated into the statement below, so anything
            # that is not a known column must be rejected first.
            raise ValueError(f"unknown plans column: {column!r}")
        if column in json_fields:
            value = json.dumps(value)
        if column == "enabled":
            value = int(value)
        assignments.append(f"{column}=?")
        params.append(value)
    if not assignments:
        return False

    assignments.append("updated_at=?")
    params.append(_now_iso())
    params.append(plan_id)

    db = await get_db()
    cur = await db.execute(
        f"UPDATE plans SET {', '.join(assignments)} WHERE id=?", params
    )
    await db.commit()
    # Report whether a row actually matched, consistent with delete_plan().
    return cur.rowcount > 0
async def delete_plan(plan_id: str) -> bool:
    """Delete one plan by id and commit.

    Returns:
        ``True`` if a row was deleted, ``False`` if no row had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM plans WHERE id=?", (plan_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
# ── QuotaRule CRUD ────────────────────────────────────
async def create_quota_rule(
    plan_id: str,
    rule_name: str,
    quota_total: int,
    quota_unit: str = "requests",
    refresh_type: str = "calendar_cycle",
    interval_hours: float | None = None,
    calendar_unit: str | None = None,
    calendar_anchor: dict | None = None,
    enabled: bool = True,
) -> dict:
    """Insert a quota rule for a plan and commit.

    The rule starts with ``quota_used`` 0, ``last_refresh_at`` set to the
    current UTC time and ``next_refresh_at`` NULL. ``calendar_anchor`` is
    stored as JSON text, or NULL when it is ``None`` or empty.

    Returns:
        A dict with the new rule's ``id`` plus the given ``plan_id`` and
        ``rule_name``.
    """
    conn = await get_db()
    rule_id = new_id()
    created = _now_iso()
    if calendar_anchor:
        anchor_json = json.dumps(calendar_anchor)
    else:
        anchor_json = None
    params = (
        rule_id,
        plan_id,
        rule_name,
        quota_total,
        quota_unit,
        refresh_type,
        interval_hours,
        calendar_unit,
        anchor_json,
        created,
        None,
        int(enabled),
    )
    await conn.execute(
        """INSERT INTO quota_rules
(id, plan_id, rule_name, quota_total, quota_used, quota_unit,
refresh_type, interval_hours, calendar_unit, calendar_anchor,
last_refresh_at, next_refresh_at, enabled)
VALUES (?,?,?,?,0,?,?,?,?,?,?,?,?)""",
        params,
    )
    await conn.commit()
    return {"id": rule_id, "plan_id": plan_id, "rule_name": rule_name}
async def list_quota_rules(plan_id: str) -> list[dict]:
    """Return every quota rule of a plan, enabled or not.

    ``calendar_anchor`` is decoded from JSON (``{}`` when absent or invalid)
    and ``enabled`` is converted to ``bool``.
    """

    def _decode(record) -> dict:
        rule = row_to_dict(record)
        rule["calendar_anchor"] = _parse_json(rule.get("calendar_anchor"), {})
        rule["enabled"] = bool(rule["enabled"])
        return rule

    conn = await get_db()
    cursor = await conn.execute(
        "SELECT * FROM quota_rules WHERE plan_id=?", (plan_id,)
    )
    return [_decode(record) for record in await cursor.fetchall()]
async def get_all_quota_rules() -> list[dict]:
    """Return all enabled quota rules across every plan (used by the scheduler).

    Only rows with ``enabled=1`` are selected. ``calendar_anchor`` is decoded
    from JSON (``{}`` when absent or invalid) and ``enabled`` is converted to
    ``bool``.
    """

    def _decode(record) -> dict:
        rule = row_to_dict(record)
        rule["calendar_anchor"] = _parse_json(rule.get("calendar_anchor"), {})
        rule["enabled"] = bool(rule["enabled"])
        return rule

    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM quota_rules WHERE enabled=1")
    return [_decode(record) for record in await cursor.fetchall()]
# Columns of the ``quota_rules`` table that may appear in an UPDATE's SET
# clause. Keyword names are checked against this set before they are placed
# into the SQL text, because identifiers cannot be bound as ``?`` parameters.
_QUOTA_RULE_COLUMNS = frozenset({
    "id",
    "plan_id",
    "rule_name",
    "quota_total",
    "quota_used",
    "quota_unit",
    "refresh_type",
    "interval_hours",
    "calendar_unit",
    "calendar_anchor",
    "last_refresh_at",
    "next_refresh_at",
    "enabled",
})


async def update_quota_rule(rule_id: str, **fields) -> bool:
    """Update the given columns of one quota rule and commit.

    Args:
        rule_id: Id of the quota rule row to update.
        **fields: Column name -> new value. Entries whose value is ``None``
            are ignored, so a column cannot be set to NULL through this
            function. ``calendar_anchor`` is JSON-encoded; ``enabled`` is
            stored as an integer.

    Returns:
        ``True`` if a row was updated; ``False`` if there was nothing to
        update or no row has ``rule_id``.

    Raises:
        ValueError: If a non-``None`` field names something that is not a
            column of ``quota_rules``. Nothing is written in that case.
    """
    json_fields = ("calendar_anchor",)
    assignments: list[str] = []
    params: list = []
    for column, value in fields.items():
        if value is None:
            continue
        if column not in _QUOTA_RULE_COLUMNS:
            # The name is interpolated into the statement below, so anything
            # that is not a known column must be rejected first.
            raise ValueError(f"unknown quota_rules column: {column!r}")
        if column in json_fields:
            value = json.dumps(value)
        if column == "enabled":
            value = int(value)
        assignments.append(f"{column}=?")
        params.append(value)
    if not assignments:
        return False

    params.append(rule_id)
    db = await get_db()
    cur = await db.execute(
        f"UPDATE quota_rules SET {', '.join(assignments)} WHERE id=?", params
    )
    await db.commit()
    # Report whether a row actually matched, consistent with delete_quota_rule().
    return cur.rowcount > 0
async def delete_quota_rule(rule_id: str) -> bool:
    """Delete one quota rule by id and commit.

    Returns:
        ``True`` if a row was deleted, ``False`` if no row had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
async def increment_quota_used(plan_id: str, token_count: int = 0):
    """Add one request's usage to every enabled quota rule of a plan.

    Rules whose ``quota_unit`` is ``'tokens'`` grow by ``token_count``; all
    other rules grow by 1. Disabled rules are left untouched.

    The increment is a single UPDATE statement rather than a read followed by
    one UPDATE per rule, so it costs one round trip and cannot interleave
    with other coroutines between reading the rules and writing them.

    Args:
        plan_id: Plan whose rules are charged.
        token_count: Amount added to token-based rules.
    """
    db = await get_db()
    await db.execute(
        "UPDATE quota_rules "
        "SET quota_used = quota_used + "
        "CASE WHEN quota_unit = 'tokens' THEN ? ELSE 1 END "
        "WHERE plan_id = ? AND enabled != 0",
        (token_count, plan_id),
    )
    await db.commit()
async def check_plan_available(plan_id: str) -> bool:
    """Report whether every enabled rule of the plan still has quota left.

    A plan with no enabled rules counts as available.
    """
    rules = await list_quota_rules(plan_id)
    exhausted = any(
        rule["quota_used"] >= rule["quota_total"]
        for rule in rules
        if rule["enabled"]
    )
    return not exhausted
# ── Model Route ───────────────────────────────────────
async def set_model_route(model_name: str, plan_id: str, priority: int = 0) -> dict:
    """Create or replace the route from a model name to a plan.

    Any existing route for the same ``(model_name, plan_id)`` pair is deleted
    first, then a new row with a fresh id is inserted and both changes are
    committed together.

    Returns:
        A dict with the new route's ``id`` plus ``model_name`` and ``plan_id``.
    """
    conn = await get_db()
    route_id = new_id()
    pair = (model_name, plan_id)
    await conn.execute(
        "DELETE FROM model_routes WHERE model_name=? AND plan_id=?",
        pair,
    )
    await conn.execute(
        "INSERT INTO model_routes (id, model_name, plan_id, priority) VALUES (?,?,?,?)",
        (route_id, *pair, priority),
    )
    await conn.commit()
    return {"id": route_id, "model_name": model_name, "plan_id": plan_id}
async def resolve_model(model_name: str) -> str | None:
    """Pick the plan that should serve ``model_name``.

    Routes are tried in descending ``priority`` order and the first plan with
    quota left wins. If no routed plan has quota left, the highest-priority
    route is returned anyway as a fallback.

    Returns:
        The chosen ``plan_id``, or ``None`` when the model has no routes.
    """
    conn = await get_db()
    cursor = await conn.execute(
        "SELECT plan_id FROM model_routes WHERE model_name=? ORDER BY priority DESC",
        (model_name,),
    )
    candidates = await cursor.fetchall()
    if not candidates:
        return None
    for candidate in candidates:
        candidate_plan = candidate["plan_id"]
        if await check_plan_available(candidate_plan):
            return candidate_plan
    # Every routed plan is out of quota: fall back to the top-priority one.
    return candidates[0]["plan_id"]
async def list_model_routes() -> list[dict]:
    """Return all model routes, ordered by model name and then by descending priority."""
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM model_routes ORDER BY model_name, priority DESC")
    records = await cursor.fetchall()
    routes = []
    for record in records:
        routes.append(row_to_dict(record))
    return routes
async def delete_model_route(route_id: str) -> bool:
    """Delete one model route by id and commit.

    Returns:
        ``True`` if a row was deleted, ``False`` if no row had that id.
    """
    conn = await get_db()
    cursor = await conn.execute("DELETE FROM model_routes WHERE id=?", (route_id,))
    await conn.commit()
    removed = cursor.rowcount
    return removed > 0
# ── Task Queue ────────────────────────────────────────
async def create_task(
    task_type: str,
    request_payload: dict,
    plan_id: str | None = None,
    priority: int = 0,
    max_retries: int = 3,
    callback_url: str | None = None,
) -> dict:
    """Insert a task in ``pending`` state and commit.

    ``request_payload`` is stored as JSON text and ``created_at`` is set to
    the current UTC time.

    Returns:
        A dict with the new task's ``id`` and its ``status`` (``"pending"``).
    """
    conn = await get_db()
    task_id = new_id()
    created = _now_iso()
    payload_json = json.dumps(request_payload)
    params = (
        task_id,
        plan_id,
        task_type,
        "pending",
        payload_json,
        priority,
        max_retries,
        callback_url,
        created,
    )
    await conn.execute(
        """INSERT INTO tasks
(id, plan_id, task_type, status, request_payload, priority,
max_retries, callback_url, created_at)
VALUES (?,?,?,?,?,?,?,?,?)""",
        params,
    )
    await conn.commit()
    return {"id": task_id, "status": "pending"}
async def list_tasks(status: str | None = None, limit: int = 50) -> list[dict]:
    """Return up to ``limit`` tasks, highest priority first, then oldest first.

    When ``status`` is a non-empty string only tasks with that status are
    returned. ``request_payload`` is decoded from JSON (``{}`` on failure) and
    ``response_payload`` is decoded from JSON (``None`` on failure).
    """
    conn = await get_db()
    pieces = ["SELECT * FROM tasks"]
    args: list = []
    if status:
        pieces.append(" WHERE status=?")
        args.append(status)
    pieces.append(" ORDER BY priority DESC, created_at ASC LIMIT ?")
    args.append(limit)
    cursor = await conn.execute("".join(pieces), args)

    tasks = []
    for record in await cursor.fetchall():
        task = row_to_dict(record)
        task["request_payload"] = _parse_json(task["request_payload"], {})
        task["response_payload"] = _parse_json(task.get("response_payload"))
        tasks.append(task)
    return tasks
# Columns of the ``tasks`` table that may appear in an UPDATE's SET clause.
# Keyword names are checked against this set before they are placed into the
# SQL text, because identifiers cannot be bound as ``?`` parameters.
_TASK_COLUMNS = frozenset({
    "id",
    "plan_id",
    "task_type",
    "status",
    "request_payload",
    "response_payload",
    "result_file_path",
    "result_mime_type",
    "priority",
    "retry_count",
    "max_retries",
    "callback_url",
    "created_at",
    "started_at",
    "completed_at",
})


async def update_task(task_id: str, **fields) -> bool:
    """Update the given columns of one task and commit.

    Args:
        task_id: Id of the task row to update.
        **fields: Column name -> new value. Entries whose value is ``None``
            are ignored, so a column cannot be set to NULL through this
            function. For ``request_payload`` and ``response_payload`` a dict
            or list is JSON-encoded; any other value (for example an already
            serialised string) is stored as given.

    Returns:
        ``True`` if a row was updated; ``False`` if there was nothing to
        update or no row has ``task_id``.

    Raises:
        ValueError: If a non-``None`` field names something that is not a
            column of ``tasks``. Nothing is written in that case.
    """
    json_fields = ("request_payload", "response_payload")
    assignments: list[str] = []
    params: list = []
    for column, value in fields.items():
        if value is None:
            continue
        if column not in _TASK_COLUMNS:
            # The name is interpolated into the statement below, so anything
            # that is not a known column must be rejected first.
            raise ValueError(f"unknown tasks column: {column!r}")
        if column in json_fields and isinstance(value, (dict, list)):
            # SQLite cannot bind a list any more than a dict, so both JSON
            # container types are serialised.
            value = json.dumps(value)
        assignments.append(f"{column}=?")
        params.append(value)
    if not assignments:
        return False

    params.append(task_id)
    db = await get_db()
    cur = await db.execute(
        f"UPDATE tasks SET {', '.join(assignments)} WHERE id=?", params
    )
    await db.commit()
    # Report whether a row actually matched instead of always returning True.
    return cur.rowcount > 0
async def get_task(task_id: str) -> dict | None:
    """Fetch one task by id.

    Returns:
        The row as a dict with ``request_payload`` decoded from JSON (``{}``
        on failure) and ``response_payload`` decoded from JSON (``None`` on
        failure), or ``None`` when no row has that id.
    """
    conn = await get_db()
    cursor = await conn.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
    record = await cursor.fetchone()
    if record:
        task = row_to_dict(record)
        task["request_payload"] = _parse_json(task["request_payload"], {})
        task["response_payload"] = _parse_json(task.get("response_payload"))
        return task
    return None
# ── 种子数据导入 ──────────────────────────────────────
async def seed_from_config():
    """Import plans, quota rules and model routes from the settings on first start.

    Does nothing when the ``plans`` table already holds at least one row.
    Otherwise each configured plan is created, followed by its quota rules
    and one priority-0 model route per supported model.
    """
    # Imported here so the settings object is looked up when seeding runs.
    from app.config import settings as cfg

    conn = await get_db()
    cursor = await conn.execute("SELECT COUNT(*) as cnt FROM plans")
    counted = await cursor.fetchone()
    if counted["cnt"] > 0:
        return

    for plan_cfg in cfg.plans:
        created = await create_plan(
            name=plan_cfg.name,
            provider_name=plan_cfg.provider,
            api_key=plan_cfg.api_key,
            api_base=plan_cfg.api_base,
            plan_type=plan_cfg.plan_type,
            supported_models=plan_cfg.supported_models,
            extra_headers=plan_cfg.extra_headers,
            extra_config=plan_cfg.extra_config,
        )
        for rule_cfg in plan_cfg.quota_rules:
            await create_quota_rule(
                plan_id=created["id"],
                rule_name=rule_cfg.rule_name,
                quota_total=rule_cfg.quota_total,
                quota_unit=rule_cfg.quota_unit,
                refresh_type=rule_cfg.refresh_type,
                interval_hours=rule_cfg.interval_hours,
                calendar_unit=rule_cfg.calendar_unit,
                calendar_anchor=rule_cfg.calendar_anchor,
            )
        for model_name in plan_cfg.supported_models:
            await set_model_route(model_name, created["id"], priority=0)