test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
This commit is contained in:
@@ -27,6 +27,12 @@ class StorageConfig(BaseModel):
|
||||
path: str = "./data/files"
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
|
||||
"""调度器配置"""
|
||||
task_processing_limit: int = 5 # 每次处理的最大任务数
|
||||
loop_interval_seconds: int = 30 # 主循环间隔(秒)
|
||||
|
||||
|
||||
class QuotaRuleSeed(BaseModel):
|
||||
"""config.yaml 中单条 QuotaRule 种子"""
|
||||
rule_name: str
|
||||
@@ -55,6 +61,7 @@ class AppConfig(BaseModel):
|
||||
server: ServerConfig = Field(default_factory=ServerConfig)
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||
storage: StorageConfig = Field(default_factory=StorageConfig)
|
||||
scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
|
||||
plans: list[PlanSeed] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@@ -306,6 +306,13 @@ async def update_quota_rule(rule_id: str, **fields) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def delete_quota_rule(rule_id: str) -> bool:
|
||||
db = await get_db()
|
||||
cur = await db.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
|
||||
await db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
async def increment_quota_used(plan_id: str, token_count: int = 0):
|
||||
"""请求完成后增加该 Plan 所有 Rule 的 quota_used"""
|
||||
db = await get_db()
|
||||
@@ -441,6 +448,18 @@ async def update_task(task_id: str, **fields) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> dict | None:
|
||||
db = await get_db()
|
||||
cur = await db.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
t = row_to_dict(row)
|
||||
t["request_payload"] = _parse_json(t["request_payload"], {})
|
||||
t["response_payload"] = _parse_json(t.get("response_payload"))
|
||||
return t
|
||||
|
||||
|
||||
# ── 种子数据导入 ──────────────────────────────────────
|
||||
|
||||
async def seed_from_config():
|
||||
|
||||
@@ -50,3 +50,9 @@ _static_dir = Path(__file__).parent / "static"
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def serve_index():
|
||||
return FileResponse(_static_dir / "index.html")
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
async def health_check():
|
||||
"""健康检查端点,用于容器健康检查"""
|
||||
return {"status": "healthy", "service": "plan-manager"}
|
||||
|
||||
@@ -105,10 +105,8 @@ async def update_rule(rule_id: str, body: QuotaRuleUpdate):
|
||||
|
||||
@router.delete("/rules/{rule_id}")
|
||||
async def delete_rule(rule_id: str):
|
||||
d = await db.get_db()
|
||||
cur = await d.execute("DELETE FROM quota_rules WHERE id=?", (rule_id,))
|
||||
await d.commit()
|
||||
if cur.rowcount == 0:
|
||||
ok = await db.delete_quota_rule(rule_id)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Rule not found")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Request
|
||||
@@ -18,8 +17,8 @@ router = APIRouter()
|
||||
|
||||
def _verify_key(authorization: str | None):
|
||||
expected = settings.server.proxy_api_key
|
||||
if not expected or expected == "sk-plan-manage-change-me":
|
||||
return # 未配置则跳过鉴权
|
||||
if not expected:
|
||||
return
|
||||
if not authorization:
|
||||
raise HTTPException(401, "Missing Authorization header")
|
||||
token = authorization.removeprefix("Bearer ").strip()
|
||||
@@ -150,13 +149,17 @@ async def anthropic_messages(
|
||||
if not provider:
|
||||
raise HTTPException(500, f"Provider '{plan['provider_name']}' not registered")
|
||||
|
||||
extra_kwargs = {k: v for k, v in body.items() if k not in ("model", "messages", "stream", "system")}
|
||||
|
||||
if stream:
|
||||
async def anthropic_stream():
|
||||
"""将 OpenAI SSE 格式转换为 Anthropic SSE 格式"""
|
||||
yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': 'msg_proxy', 'type': 'message', 'role': 'assistant', 'model': model, 'content': []}})}\n\n"
|
||||
yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
|
||||
|
||||
async for chunk_data in _stream_and_count(provider, oai_messages, model, plan, True):
|
||||
async for chunk_data in _stream_and_count(
|
||||
provider, oai_messages, model, plan, True, **extra_kwargs
|
||||
):
|
||||
if chunk_data.startswith("data: [DONE]"):
|
||||
break
|
||||
if chunk_data.startswith("data: "):
|
||||
@@ -179,7 +182,9 @@ async def anthropic_messages(
|
||||
)
|
||||
else:
|
||||
chunks = []
|
||||
async for c in _stream_and_count(provider, oai_messages, model, plan, False):
|
||||
async for c in _stream_and_count(
|
||||
provider, oai_messages, model, plan, False, **extra_kwargs
|
||||
):
|
||||
chunks.append(c)
|
||||
oai_resp = json.loads(chunks[0]) if chunks else {}
|
||||
# OpenAI 响应 -> Anthropic 响应
|
||||
|
||||
@@ -27,14 +27,9 @@ async def create_task(body: TaskCreate):
|
||||
|
||||
@router.get("/{task_id}", response_model=TaskOut)
|
||||
async def get_task(task_id: str):
|
||||
d = await db.get_db()
|
||||
cur = await d.execute("SELECT * FROM tasks WHERE id=?", (task_id,))
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
t = await db.get_task(task_id)
|
||||
if not t:
|
||||
raise HTTPException(404, "Task not found")
|
||||
t = db.row_to_dict(row)
|
||||
t["request_payload"] = db._parse_json(t["request_payload"], {})
|
||||
t["response_payload"] = db._parse_json(t.get("response_payload"))
|
||||
return t
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from calendar import monthrange
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app import database as db
|
||||
@@ -55,28 +55,24 @@ def _compute_next_calendar(calendar_unit: str, anchor: dict, after: datetime) ->
|
||||
return candidate
|
||||
|
||||
if calendar_unit == "monthly":
|
||||
day = anchor.get("day", 1)
|
||||
day = int(anchor.get("day", 1))
|
||||
if day < 1:
|
||||
day = 1
|
||||
if day > 31:
|
||||
day = 31
|
||||
year, month = after.year, after.month
|
||||
try:
|
||||
candidate = after.replace(day=day, hour=hour, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
# 日期不存在时(如 2 月 30 号),跳到下月
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = after.replace(year=year, month=month, day=day,
|
||||
hour=hour, minute=0, second=0, microsecond=0)
|
||||
month_last_day = monthrange(year, month)[1]
|
||||
candidate_day = min(day, month_last_day)
|
||||
candidate = after.replace(
|
||||
day=candidate_day, hour=hour, minute=0, second=0, microsecond=0
|
||||
)
|
||||
if candidate <= after:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
try:
|
||||
candidate = candidate.replace(year=year, month=month)
|
||||
except ValueError:
|
||||
month += 1
|
||||
if month > 12:
|
||||
month, year = 1, year + 1
|
||||
candidate = candidate.replace(year=year, month=month, day=1)
|
||||
month_last_day = monthrange(year, month)[1]
|
||||
candidate_day = min(day, month_last_day)
|
||||
candidate = candidate.replace(year=year, month=month, day=candidate_day)
|
||||
return candidate
|
||||
|
||||
return after + timedelta(days=1)
|
||||
@@ -131,9 +127,9 @@ async def _refresh_quota_rules():
|
||||
logger.info("Refreshed rule %s (calendar %s)", rule["rule_name"], cal_unit)
|
||||
|
||||
elif rt == "api_sync":
|
||||
# 每 10 分钟同步一次
|
||||
last_at = _parse_dt(rule.get("last_refresh_at"))
|
||||
if last_at and (now - last_at).total_seconds() < 600:
|
||||
sync_interval = rule.get("sync_interval_seconds") or 600 # 默认 10 分钟
|
||||
if last_at and (now - last_at).total_seconds() < sync_interval:
|
||||
continue
|
||||
plan = await db.get_plan(rule["plan_id"])
|
||||
if not plan:
|
||||
@@ -141,19 +137,24 @@ async def _refresh_quota_rules():
|
||||
from app.providers import ProviderRegistry
|
||||
provider = ProviderRegistry.get(plan["provider_name"])
|
||||
if provider:
|
||||
info = await provider.query_quota(plan)
|
||||
if info:
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=info.quota_used,
|
||||
last_refresh_at=now.isoformat(),
|
||||
)
|
||||
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
|
||||
try:
|
||||
info = await provider.query_quota(plan)
|
||||
if info:
|
||||
await db.update_quota_rule(
|
||||
rule["id"],
|
||||
quota_used=info.quota_used,
|
||||
last_refresh_at=now.isoformat(),
|
||||
)
|
||||
logger.info("API synced rule %s: used=%d", rule["rule_name"], info.quota_used)
|
||||
except Exception as e:
|
||||
logger.error("API sync failed for rule %s: %s", rule["rule_name"], e)
|
||||
|
||||
|
||||
async def _process_task_queue():
|
||||
"""消费待处理任务"""
|
||||
tasks = await db.list_tasks(status="pending", limit=5)
|
||||
from app.config import settings
|
||||
limit = settings.scheduler.task_processing_limit
|
||||
tasks = await db.list_tasks(status="pending", limit=limit)
|
||||
for task in tasks:
|
||||
plan_id = task.get("plan_id")
|
||||
if plan_id and not await db.check_plan_available(plan_id):
|
||||
@@ -204,10 +205,13 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
|
||||
payload = task.get("request_payload", {})
|
||||
|
||||
if tt == "image":
|
||||
return await provider.generate_image(payload.get("prompt", ""), plan, **payload)
|
||||
kwargs = dict(payload)
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
return await provider.generate_image(prompt, plan, **kwargs)
|
||||
elif tt == "voice":
|
||||
audio = await provider.generate_voice(payload.get("text", ""), plan, **payload)
|
||||
# 保存到文件
|
||||
kwargs = dict(payload)
|
||||
text = kwargs.pop("text", "")
|
||||
audio = await provider.generate_voice(text, plan, **kwargs)
|
||||
from pathlib import Path
|
||||
from app.config import settings
|
||||
fpath = Path(settings.storage.path) / f"{task['id']}.mp3"
|
||||
@@ -215,7 +219,9 @@ async def _execute_task(provider, plan: dict, task: dict) -> dict:
|
||||
await db.update_task(task["id"], result_file_path=str(fpath), result_mime_type="audio/mp3")
|
||||
return {"file": str(fpath)}
|
||||
elif tt == "video":
|
||||
return await provider.generate_video(payload.get("prompt", ""), plan, **payload)
|
||||
kwargs = dict(payload)
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
return await provider.generate_video(prompt, plan, **kwargs)
|
||||
else:
|
||||
return {"error": f"Unknown task type: {tt}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user