"""Regression tests for database update helpers, the scheduler and the proxy router."""

import asyncio
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path
import sys

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

# Make the project root importable before pulling in the ``app`` package.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app import database as db  # noqa: E402
from app.config import settings  # noqa: E402
from app.routers import proxy  # noqa: E402
from app.services import scheduler  # noqa: E402


@pytest.fixture(autouse=True)
def isolated_db(tmp_path):
    """Point the database settings at a per-test file under ``tmp_path``.

    ``db.close_db()`` is called before and after the test body, and the
    original ``settings.database.path`` is always restored, even when one
    of the ``close_db`` calls raises, so a failure here cannot leak the
    temporary path into later tests.
    """
    old_path = settings.database.path
    settings.database.path = str(tmp_path / "test.db")
    try:
        asyncio.run(db.close_db())
        yield
        asyncio.run(db.close_db())
    finally:
        settings.database.path = old_path


def test_update_functions_report_missing_rows():
    """Updating a row that does not exist returns ``False``."""
    # update_plan now returns ``cur.rowcount > 0``; a non-existent row yields False.
    assert asyncio.run(db.update_plan("missing", name="x")) is False
    assert asyncio.run(db.update_quota_rule("missing", rule_name="x")) is False
    assert asyncio.run(db.update_task("missing", status="cancelled")) is False


def test_update_functions_report_success_for_existing_rows():
    """Updating a row that was just created returns ``True``."""
    plan = asyncio.run(db.create_plan(name="a", provider_name="openai"))
    rule = asyncio.run(
        db.create_quota_rule(plan_id=plan["id"], rule_name="r", quota_total=10)
    )
    task = asyncio.run(
        db.create_task(task_type="image", request_payload={"prompt": "p"})
    )

    assert asyncio.run(db.update_plan(plan["id"], name="b")) is True
    assert asyncio.run(db.update_quota_rule(rule["id"], quota_total=20)) is True
    assert asyncio.run(db.update_task(task["id"], status="running")) is True


def test_execute_task_avoids_duplicate_prompt_argument():
    """``prompt`` is passed once; the remaining payload keys arrive as kwargs."""

    class Provider:
        async def generate_image(self, prompt: str, plan: dict, **kwargs):
            return {"prompt": prompt, "kwargs": kwargs}

    task = {
        "id": "t1",
        "task_type": "image",
        "request_payload": {"prompt": "hello", "model": "m1"},
    }

    result = asyncio.run(scheduler._execute_task(Provider(), {}, task))

    assert result["prompt"] == "hello"
    # "prompt" must not be repeated inside kwargs.
    assert result["kwargs"] == {"model": "m1"}


def test_compute_next_calendar_monthly_handles_large_day_anchor():
    """A day-31 monthly anchor yields Feb 28 in February 2026, then Mar 31."""
    after = datetime(2026, 2, 15, 12, 0, tzinfo=timezone.utc)
    next_at = scheduler._compute_next_calendar(
        "monthly", {"day": 31, "hour": 0}, after
    )
    assert next_at == datetime(2026, 2, 28, 0, 0, tzinfo=timezone.utc)

    # Once the February slot has passed, the next slot is the real 31st of March.
    after2 = datetime(2026, 2, 28, 1, 0, tzinfo=timezone.utc)
    next_at2 = scheduler._compute_next_calendar(
        "monthly", {"day": 31, "hour": 0}, after2
    )
    assert next_at2 == datetime(2026, 3, 31, 0, 0, tzinfo=timezone.utc)


def test_refresh_quota_rules_isolates_api_sync_errors(monkeypatch):
    """The fixed-interval rule is still reset when the provider's quota query raises."""
    fixed_now = datetime(2026, 3, 1, 12, 0, tzinfo=timezone.utc)
    updates = []
    rules = [
        {
            "id": "r1",
            "rule_name": "api",
            "refresh_type": "api_sync",
            "plan_id": "p1",
            "last_refresh_at": None,
            "next_refresh_at": None,
        },
        {
            "id": "r2",
            "rule_name": "fixed",
            "refresh_type": "fixed_interval",
            "interval_hours": 1,
            "last_refresh_at": (fixed_now - timedelta(hours=2)).isoformat(),
            # Already due relative to ``fixed_now``.
            "next_refresh_at": (fixed_now - timedelta(minutes=1)).isoformat(),
        },
    ]

    async def fake_get_all_quota_rules():
        return rules

    async def fake_get_plan(plan_id: str):
        return {"id": plan_id, "provider_name": "broken"}

    async def fake_update_quota_rule(rule_id: str, **fields):
        # Record every update so the assertion below can inspect them.
        updates.append((rule_id, fields))
        return True

    class BadProvider:
        async def query_quota(self, plan: dict):
            raise RuntimeError("boom")

    import app.providers as providers

    monkeypatch.setattr(scheduler, "_now", lambda: fixed_now)
    monkeypatch.setattr(scheduler.db, "get_all_quota_rules", fake_get_all_quota_rules)
    monkeypatch.setattr(scheduler.db, "get_plan", fake_get_plan)
    monkeypatch.setattr(scheduler.db, "update_quota_rule", fake_update_quota_rule)
    monkeypatch.setattr(
        providers.ProviderRegistry,
        "get",
        classmethod(lambda cls, name: BadProvider()),
    )

    asyncio.run(scheduler._refresh_quota_rules())

    assert any(
        rule_id == "r2" and fields.get("quota_used") == 0
        for rule_id, fields in updates
    )


def test_anthropic_route_forwards_extra_kwargs(monkeypatch):
    """``POST /v1/messages`` passes ``temperature`` and ``max_tokens`` to ``chat``."""
    captured = {}

    class Provider:
        async def chat(self, messages, model, plan, stream=True, **kwargs):
            # Capture whatever extra keyword arguments the route forwarded.
            captured["kwargs"] = kwargs
            payload = {
                "choices": [{"message": {"content": "ok"}}],
                "usage": {
                    "prompt_tokens": 1,
                    "completion_tokens": 2,
                    "total_tokens": 3,
                },
            }
            yield json.dumps(payload)

    async def fake_resolve_model(model: str):
        return "p1"

    async def fake_get_plan(plan_id: str):
        return {"id": plan_id, "name": "p", "provider_name": "x"}

    async def fake_check_plan_available(plan_id: str):
        return True

    async def fake_increment_quota_used(plan_id: str, token_count: int = 0):
        return None

    provider = Provider()
    monkeypatch.setattr(proxy.db, "resolve_model", fake_resolve_model)
    monkeypatch.setattr(proxy.db, "get_plan", fake_get_plan)
    monkeypatch.setattr(proxy.db, "check_plan_available", fake_check_plan_available)
    monkeypatch.setattr(proxy.db, "increment_quota_used", fake_increment_quota_used)
    monkeypatch.setattr(
        proxy.ProviderRegistry,
        "get",
        classmethod(lambda cls, name: provider),
    )

    app = FastAPI()
    app.include_router(proxy.router)

    # Run the request with an empty proxy key (no auth header is sent) and
    # always restore the previous value.
    old_key = settings.server.proxy_api_key
    settings.server.proxy_api_key = ""
    try:
        client = TestClient(app)
        resp = client.post(
            "/v1/messages",
            json={
                "model": "m1",
                "messages": [{"role": "user", "content": "hello"}],
                "temperature": 0.2,
                "max_tokens": 88,
            },
        )
    finally:
        settings.server.proxy_api_key = old_key

    assert resp.status_code == 200
    assert captured["kwargs"] == {"temperature": 0.2, "max_tokens": 88}


def test_verify_key_requires_auth_when_using_default_value():
    """``_verify_key(None)`` raises a 401 while the key is ``sk-plan-manage-change-me``."""
    old_key = settings.server.proxy_api_key
    settings.server.proxy_api_key = "sk-plan-manage-change-me"
    try:
        with pytest.raises(HTTPException) as exc:
            proxy._verify_key(None)
    finally:
        settings.server.proxy_api_key = old_key

    assert exc.value.status_code == 401