Files
planManage/tests/test_scheduler.py
congsh 37d282c0a2 test: 添加测试框架和全面的单元测试
- 添加 pytest 配置和测试依赖到 requirements.txt
- 创建测试包结构和 fixtures (conftest.py)
- 添加数据库模块的 CRUD 操作测试 (test_database.py)
- 添加 Provider 插件系统测试 (test_providers.py)
- 添加调度器模块测试 (test_scheduler.py)
- 添加 API 路由测试 (test_api.py)
- 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
- 添加健康检查端点用于容器监控
- 修复调度器中的日历计算逻辑和任务执行参数处理
- 更新数据库函数以返回操作结果状态
2026-03-31 22:36:18 +08:00

348 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""调度器模块测试"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.scheduler import (
_now,
_parse_dt,
_compute_next_calendar,
_refresh_quota_rules,
_process_task_queue,
_execute_task,
)
class TestUtilityFunctions:
    """Tests for the small scheduler helpers ``_now`` and ``_parse_dt``."""

    def test_now_returns_utc(self):
        """``_now`` must return a datetime whose tzinfo is UTC."""
        assert _now().tzinfo == timezone.utc

    def test_parse_dt_valid(self):
        """A well-formed ISO-8601 string is parsed into a datetime."""
        parsed = _parse_dt("2026-03-31T12:00:00+00:00")
        assert parsed is not None
        assert (parsed.year, parsed.month, parsed.day) == (2026, 3, 31)

    def test_parse_dt_none(self):
        """``None`` input yields ``None`` rather than raising."""
        assert _parse_dt(None) is None

    def test_parse_dt_invalid(self):
        """A string that is not a date yields ``None`` rather than raising."""
        assert _parse_dt("not-a-date") is None
class TestComputeNextCalendar:
    """Tests for ``_compute_next_calendar`` (next natural-cycle refresh point)."""

    @staticmethod
    def _ymdh(dt):
        """Return ``(year, month, day, hour)`` of *dt* for compact comparisons."""
        return (dt.year, dt.month, dt.day, dt.hour)

    def test_daily_anchor(self):
        """Daily anchor already passed today -> next day at the anchor hour."""
        # 2026-03-31 15:00, anchor hour 0: also crosses a month boundary.
        after = datetime(2026, 3, 31, 15, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("daily", {"hour": 0}, after)
        assert self._ymdh(next_time) == (2026, 4, 1, 0)

    def test_daily_anchor_before_hour(self):
        """Daily anchor not yet reached today -> later the same day."""
        after = datetime(2026, 3, 31, 8, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("daily", {"hour": 10}, after)
        assert self._ymdh(next_time) == (2026, 3, 31, 10)

    def test_weekly_anchor(self):
        """Weekly anchor (Monday 00:00) from a Wednesday -> the next Monday."""
        # 2026-03-31 is a Tuesday, so 2026-04-01 is a Wednesday.
        after = datetime(2026, 4, 1, 10, 0, tzinfo=timezone.utc)
        # In this test weekday 1 means Monday (ISO-style numbering, presumably).
        anchor = {"weekday": 1, "hour": 0}
        next_time = _compute_next_calendar("weekly", anchor, after)
        # The Monday after Wednesday 2026-04-01 is 2026-04-06.
        assert self._ymdh(next_time) == (2026, 4, 6, 0)

    def test_weekly_anchor_same_day(self):
        """Weekly anchor on today's weekday but already past -> one week later."""
        after = datetime(2026, 3, 31, 10, 0, tzinfo=timezone.utc)  # Tuesday
        anchor = {"weekday": 2, "hour": 0}  # Tuesday
        next_time = _compute_next_calendar("weekly", anchor, after)
        # Tuesday 2026-03-31 -> following Tuesday 2026-04-07 (month changes too).
        assert self._ymdh(next_time) == (2026, 4, 7, 0)

    def test_monthly_anchor_first_day(self):
        """Monthly anchor on the 1st, mid-month -> the 1st of next month."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("monthly", {"day": 1, "hour": 0}, after)
        assert self._ymdh(next_time) == (2026, 4, 1, 0)

    def test_monthly_anchor_year_rollover(self):
        """Monthly anchor in December -> January of the following year."""
        # Guards against a naive ``month + 1`` producing month 13.
        after = datetime(2026, 12, 15, 10, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("monthly", {"day": 1, "hour": 0}, after)
        assert self._ymdh(next_time) == (2027, 1, 1, 0)

    def test_monthly_anchor_last_day(self):
        """Monthly anchor on the 31st, in a 31-day month -> later this month."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("monthly", {"day": 31, "hour": 0}, after)
        assert self._ymdh(next_time) == (2026, 3, 31, 0)

    def test_monthly_anchor_day_30(self):
        """Monthly anchor on the 30th, in a month that has one -> this month."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("monthly", {"day": 30, "hour": 0}, after)
        assert self._ymdh(next_time) == (2026, 3, 30, 0)

    def test_monthly_anchor_invalid_day(self):
        """Monthly anchor on a day the current month lacks (Feb 30) must not break."""
        # February 2026 has only 28 days, so day 30 does not exist this month.
        after = datetime(2026, 2, 15, 10, 0, tzinfo=timezone.utc)
        next_time = _compute_next_calendar("monthly", {"day": 30, "hour": 0}, after)
        # NOTE(review): the clamping policy is not visible from this file, so
        # accept either clamping to the month's last day or skipping to the
        # next month that has day 30. Overflowing into early March, or raising,
        # would both be bugs.
        assert self._ymdh(next_time) in {(2026, 2, 28, 0), (2026, 3, 30, 0)}
class TestRefreshQuotaRules:
    """Tests for ``_refresh_quota_rules`` across the supported refresh types."""

    @pytest.mark.asyncio
    async def test_manual_rule_no_refresh(self, temp_db):
        """A manual rule keeps its usage even when a refresh time has passed."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        # NOTE(review): "manual" is assumed to be the refresh_type literal for
        # manual rules; confirm against the database/scheduler module.
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Manual",
            quota_total=100,
            refresh_type="manual",
        )
        # A past next_refresh_at resets time-based rules (see the tests below);
        # a manual rule must ignore it.
        past = datetime.now(timezone.utc) - timedelta(hours=2)
        await db.update_quota_rule(rule["id"], next_refresh_at=past.isoformat())
        await db.update_quota_rule(rule["id"], quota_used=50)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 50

    @pytest.mark.asyncio
    async def test_fixed_interval_refresh(self, temp_db):
        """A fixed-interval rule that is overdue has its usage reset to 0."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Hourly",
            quota_total=100,
            refresh_type="fixed_interval",
            interval_hours=1,
        )
        # Put both refresh timestamps two hours in the past so the rule is due.
        past = datetime.now(timezone.utc) - timedelta(hours=2)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_calendar_cycle_refresh(self, temp_db):
        """A daily calendar-cycle rule last refreshed yesterday is reset to 0."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily",
            quota_total=100,
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )
        # Put both refresh timestamps one day in the past.
        past = datetime.now(timezone.utc) - timedelta(days=1)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_api_sync(self, temp_db):
        """An api_sync rule is processed through its provider without raising."""
        from app import database as db
        from app.providers import ProviderRegistry
        from app.providers.base import QuotaInfo

        class MockSyncProvider:
            name = "mock_sync"
            capabilities = []

            async def query_quota(self, plan):
                return QuotaInfo(quota_used=42, quota_total=100, quota_remaining=58, unit="tokens")

        # patch.dict restores the global registry on exit, so the mock does not
        # leak into other tests.
        with patch.dict(ProviderRegistry._providers, {"mock_sync": MockSyncProvider()}):
            plan = await db.create_plan(name="Test", provider_name="mock_sync", api_key="sk-test")
            rule = await db.create_quota_rule(
                plan_id=plan["id"],
                rule_name="API Sync",
                quota_total=100,
                refresh_type="api_sync",
            )
            # Put the last sync just over 10 minutes (601 s) in the past, with
            # the aim of making the rule due for a sync.
            past = datetime.now(timezone.utc) - timedelta(seconds=601)
            await db.update_quota_rule(rule["id"], last_refresh_at=past.isoformat())

            await _refresh_quota_rules()

            rules = await db.list_quota_rules(plan["id"])

        # Only check that the rule survived the refresh; whether usage is
        # updated depends on the scheduler's sync logic.
        assert rules[0]["rule_name"] == "API Sync"
class TestProcessTaskQueue:
    """Tests for ``_process_task_queue``."""

    @pytest.mark.asyncio
    async def test_process_pending_tasks(self, temp_db):
        """Processing a pending task with a mocked provider does not raise."""
        from app import database as db
        from app.providers import ProviderRegistry

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        class MockTaskProvider:
            name = "mock_task"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": "https://example.com/test.png"}

        # Shadow the real "openai" provider only for this test; patch.dict puts
        # the original registry entry back afterwards instead of leaking the mock.
        with patch.dict(ProviderRegistry._providers, {"openai": MockTaskProvider()}):
            await _process_task_queue()

        updated_task = await db.get_task(task["id"])
        # The mock does not implement the full provider contract, so the task
        # may stay pending or fail; only check that it is still retrievable.
        assert updated_task is not None

    @pytest.mark.asyncio
    async def test_skip_task_no_quota(self, temp_db):
        """A task whose plan has exhausted its quota is not sent to the provider."""
        from app import database as db
        from app.providers import ProviderRegistry

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        # Use up the whole quota (100 of 100).
        await db.update_quota_rule(rule["id"], quota_used=100)
        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        class RecordingProvider:
            name = "mock_task"
            capabilities = []

            def __init__(self):
                self.calls = []

            async def generate_image(self, prompt, plan, **kwargs):
                self.calls.append(prompt)
                return {"url": "https://example.com/test.png"}

        # A recording mock makes the skip observable and ensures an exhausted
        # plan can never reach the real provider; the registry is restored on exit.
        provider = RecordingProvider()
        with patch.dict(ProviderRegistry._providers, {"openai": provider}):
            await _process_task_queue()

        assert provider.calls == []
        assert await db.get_task(task["id"]) is not None
class TestExecuteTask:
    """Tests for ``_execute_task`` dispatching on a task's ``task_type``."""

    @pytest.mark.asyncio
    async def test_execute_image_task(self):
        """An ``image`` task returns the result of ``provider.generate_image``."""

        class MockImageProvider:
            name = "mock_img"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": f"https://example.com/{prompt}.png"}

        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "image",
            "request_payload": {"prompt": "a cat", "size": "512x512"},
        }
        # The provider is passed in directly, so no ProviderRegistry entry is
        # needed; registering one would only leak global state into other tests.
        result = await _execute_task(MockImageProvider(), plan, task)
        assert "url" in result
        assert "cat" in result["url"]

    @pytest.mark.asyncio
    async def test_execute_unknown_task(self):
        """An unrecognised ``task_type`` yields an error dict instead of raising."""

        class MockProvider:
            name = "mock"
            capabilities = []

        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "unknown_type",
            "request_payload": {},
        }
        result = await _execute_task(MockProvider(), plan, task)
        assert "error" in result
        assert "Unknown task type" in result["error"]