- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
348 lines
11 KiB
Python
348 lines
11 KiB
Python
"""调度器模块测试"""
|
||
|
||
import pytest
|
||
from datetime import datetime, timedelta, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from app.services.scheduler import (
|
||
_now,
|
||
_parse_dt,
|
||
_compute_next_calendar,
|
||
_refresh_quota_rules,
|
||
_process_task_queue,
|
||
_execute_task,
|
||
)
|
||
|
||
|
||
class TestUtilityFunctions:
    """Tests for the scheduler's small helper functions (_now, _parse_dt)."""

    def test_now_returns_utc(self):
        """_now() must return a datetime whose tzinfo is UTC."""
        assert _now().tzinfo == timezone.utc

    def test_parse_dt_valid(self):
        """A well-formed ISO-8601 string parses to the matching calendar date."""
        parsed = _parse_dt("2026-03-31T12:00:00+00:00")
        assert parsed is not None
        assert (parsed.year, parsed.month, parsed.day) == (2026, 3, 31)

    def test_parse_dt_none(self):
        """Passing None yields None."""
        assert _parse_dt(None) is None

    def test_parse_dt_invalid(self):
        """An unparsable string yields None rather than raising."""
        assert _parse_dt("not-a-date") is None
|
||
|
||
|
||
class TestComputeNextCalendar:
    """Tests for computing the next natural-calendar refresh point."""

    def test_daily_anchor(self):
        """Once today's anchor hour has passed, the next point is tomorrow's anchor."""
        # 2026-03-31 15:00 is past today's 00:00 anchor, so the result rolls into April.
        after = datetime(2026, 3, 31, 15, 0, tzinfo=timezone.utc)
        anchor = {"hour": 0}

        next_time = _compute_next_calendar("daily", anchor, after)

        assert next_time.month == 4
        assert next_time.day == 1
        assert next_time.hour == 0

    def test_daily_anchor_before_hour(self):
        """When the anchor hour is still ahead today, the next point is today."""
        after = datetime(2026, 3, 31, 8, 0, tzinfo=timezone.utc)
        anchor = {"hour": 10}

        next_time = _compute_next_calendar("daily", anchor, after)

        assert next_time.month == 3
        assert next_time.day == 31
        assert next_time.hour == 10

    def test_weekly_anchor(self):
        """Weekly anchor on Monday 00:00, starting from a Wednesday."""
        # The anchor uses weekday 1 = Monday, 2 = Tuesday (ISO numbering).
        # 2026-04-01 is a Wednesday.
        after = datetime(2026, 4, 1, 10, 0, tzinfo=timezone.utc)
        anchor = {"weekday": 1, "hour": 0}

        next_time = _compute_next_calendar("weekly", anchor, after)

        # The first Monday after Wednesday 2026-04-01 is 2026-04-06.
        assert (next_time.month, next_time.day) == (4, 6)
        assert next_time.isoweekday() == 1
        assert next_time.hour == 0

    def test_weekly_anchor_same_day(self):
        """On the anchor weekday after the anchor hour, the next point is one week later."""
        after = datetime(2026, 3, 31, 10, 0, tzinfo=timezone.utc)  # Tuesday
        anchor = {"weekday": 2, "hour": 0}  # Tuesday 00:00

        next_time = _compute_next_calendar("weekly", anchor, after)

        # Today's 00:00 has passed, so the result is next Tuesday, 2026-04-07.
        assert (next_time.month, next_time.day) == (4, 7)
        assert next_time.isoweekday() == 2
        assert next_time.hour == 0

    def test_monthly_anchor_first_day(self):
        """Monthly anchor on the 1st: mid-March rolls to April 1st."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 1, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        assert next_time.month == 4
        assert next_time.day == 1
        assert next_time.hour == 0

    def test_monthly_anchor_last_day(self):
        """Monthly anchor on the 31st, in a month that has a 31st still ahead."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 31, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        assert next_time.month == 3
        assert next_time.day == 31
        assert next_time.hour == 0

    def test_monthly_anchor_day_in_current_month(self):
        """Monthly anchor on the 30th while the current month's 30th is still ahead."""
        # March has a 30th, so no month-length edge case is involved here.
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 30, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        assert next_time.month == 3
        assert next_time.day == 30

    def test_monthly_anchor_invalid_day(self):
        """Monthly anchor whose next occurrence falls on a nonexistent date (Feb 31)."""
        # 2026-01-31 00:00 has already passed, so the naive next candidate would be
        # 2026-02-31, which does not exist.
        # NOTE(review): the exact policy is not pinned here. Both clamping to the
        # last day of February and skipping to 2026-03-31 are accepted. Tighten
        # this once the intended behaviour is decided.
        after = datetime(2026, 1, 31, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 31, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        assert next_time.year == 2026
        assert (next_time.month, next_time.day) in {(2, 28), (3, 31)}
        assert next_time.hour == 0
|
||
|
||
|
||
class TestRefreshQuotaRules:
    """Tests for the scheduler's quota-rule refresh pass (_refresh_quota_rules)."""

    @pytest.mark.asyncio
    async def test_manual_rule_no_refresh(self, temp_db):
        """A manual rule must not have its usage reset by the scheduler."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        # NOTE(review): assumes "manual" is the refresh_type value for manual
        # rules -- confirm against the quota-rule schema.
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Manual",
            quota_total=100,
            refresh_type="manual",
        )

        # Make the rule look overdue, so that only its refresh_type can prevent a reset.
        past = datetime.now(timezone.utc) - timedelta(days=1)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=50)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 50

    @pytest.mark.asyncio
    async def test_fixed_interval_refresh(self, temp_db):
        """A fixed-interval rule whose next refresh is in the past gets its usage reset."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Hourly",
            quota_total=100,
            refresh_type="fixed_interval",
            interval_hours=1,
        )

        # Put both the last and next refresh two hours in the past so the rule is due.
        past = datetime.now(timezone.utc) - timedelta(hours=2)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_calendar_cycle_refresh(self, temp_db):
        """A daily calendar rule that was last refreshed yesterday gets its usage reset."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily",
            quota_total=100,
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )

        past = datetime.now(timezone.utc) - timedelta(days=1)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_api_sync(self, temp_db, monkeypatch):
        """An api_sync rule last synced over 10 minutes ago is processed without errors."""
        from app import database as db
        from app.providers import ProviderRegistry
        from app.providers.base import QuotaInfo

        class MockSyncProvider:
            name = "mock_sync"
            capabilities = []

            async def query_quota(self, plan):
                return QuotaInfo(quota_used=42, quota_total=100, quota_remaining=58, unit="tokens")

        # Register through monkeypatch so the mock is removed after this test
        # instead of leaking into the registry shared by every other test.
        monkeypatch.setitem(ProviderRegistry._providers, "mock_sync", MockSyncProvider())

        plan = await db.create_plan(name="Test", provider_name="mock_sync", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="API Sync",
            quota_total=100,
            refresh_type="api_sync",
        )

        # 601 seconds: just past the 10-minute mark.
        past = datetime.now(timezone.utc) - timedelta(seconds=601)
        await db.update_quota_rule(rule["id"], last_refresh_at=past.isoformat())

        await _refresh_quota_rules()

        # Only verify that the pass completes and the rule is intact; whether the
        # synced usage is written back depends on the scheduler's policy.
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["rule_name"] == "API Sync"
|
||
|
||
|
||
class TestProcessTaskQueue:
    """Tests for the scheduler's pending-task queue processing (_process_task_queue)."""

    @pytest.mark.asyncio
    async def test_process_pending_tasks(self, temp_db, monkeypatch):
        """Processing a pending image task with quota available does not raise."""
        from app import database as db
        from app.providers import ProviderRegistry

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)

        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        class MockTaskProvider:
            name = "mock_task"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": "https://example.com/test.png"}

        # Temporarily replace the real "openai" provider; monkeypatch restores it
        # afterwards so later tests do not see this mock.
        monkeypatch.setitem(ProviderRegistry._providers, "openai", MockTaskProvider())

        await _process_task_queue()

        # The mock does not implement the provider's full contract, so the task may
        # remain pending or end up failed; only check that processing did not raise.
        updated_task = await db.get_task(task["id"])
        assert updated_task is not None

    @pytest.mark.asyncio
    async def test_skip_task_no_quota(self, temp_db, monkeypatch):
        """A task whose plan has exhausted its quota is skipped, not executed."""
        from app import database as db
        from app.providers import ProviderRegistry

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        await db.update_quota_rule(rule["id"], quota_used=100)

        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        calls = []

        class RecordingProvider:
            name = "recording"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                calls.append(prompt)
                return {"url": "https://example.com/should-not-happen.png"}

        # Record any execution attempt, and never reach the real provider.
        monkeypatch.setitem(ProviderRegistry._providers, "openai", RecordingProvider())

        await _process_task_queue()

        assert calls == []
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 100
        assert await db.get_task(task["id"]) is not None
|
||
|
||
|
||
class TestExecuteTask:
    """Tests for executing a single task against a provider (_execute_task)."""

    @pytest.mark.asyncio
    async def test_execute_image_task(self):
        """An image task is dispatched to the provider's generate_image."""

        class MockImageProvider:
            name = "mock_img"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": f"https://example.com/{prompt}.png"}

        # _execute_task receives the provider directly, so nothing needs to be
        # registered in the global ProviderRegistry (which would leak across tests).
        provider = MockImageProvider()
        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "image",
            "request_payload": {"prompt": "a cat", "size": "512x512"},
        }

        result = await _execute_task(provider, plan, task)

        assert "url" in result
        assert "cat" in result["url"]

    @pytest.mark.asyncio
    async def test_execute_unknown_task(self):
        """An unrecognised task_type produces an error result instead of raising."""

        class MockProvider:
            name = "mock"
            capabilities = []

        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "unknown_type",
            "request_payload": {},
        }

        result = await _execute_task(provider, plan, task)

        assert "error" in result
        assert "Unknown task type" in result["error"]
|