Files
planManage/tests/test_scheduler.py

348 lines
11 KiB
Python
Raw Normal View History

"""调度器模块测试"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.scheduler import (
_now,
_parse_dt,
_compute_next_calendar,
_refresh_quota_rules,
_process_task_queue,
_execute_task,
)
class TestUtilityFunctions:
    """Tests for the scheduler's small date/time helper functions."""

    def test_now_returns_utc(self):
        """The datetime returned by _now() carries the UTC tzinfo."""
        current = _now()
        assert current.tzinfo == timezone.utc

    def test_parse_dt_valid(self):
        """A well-formed ISO-8601 string parses to the matching calendar date."""
        parsed = _parse_dt("2026-03-31T12:00:00+00:00")
        assert parsed is not None
        assert (parsed.year, parsed.month, parsed.day) == (2026, 3, 31)

    def test_parse_dt_none(self):
        """Passing None to _parse_dt yields None."""
        assert _parse_dt(None) is None

    def test_parse_dt_invalid(self):
        """A string that is not a date yields None."""
        assert _parse_dt("not-a-date") is None
class TestComputeNextCalendar:
    """Tests for _compute_next_calendar with daily, weekly and monthly units."""

    @staticmethod
    def _utc(year, month, day, hour):
        """Build the UTC-aware, on-the-hour datetime used as the reference instant."""
        return datetime(year, month, day, hour, 0, tzinfo=timezone.utc)

    def test_daily_anchor(self):
        """Daily anchor at hour 0, reference March 31 15:00: expect April 1, hour 0."""
        upcoming = _compute_next_calendar(
            "daily", {"hour": 0}, self._utc(2026, 3, 31, 15)
        )
        assert (upcoming.month, upcoming.day, upcoming.hour) == (4, 1, 0)

    def test_daily_anchor_before_hour(self):
        """Daily anchor at hour 10, reference 08:00 the same day: expect day 31, hour 10."""
        upcoming = _compute_next_calendar(
            "daily", {"hour": 10}, self._utc(2026, 3, 31, 8)
        )
        assert (upcoming.day, upcoming.hour) == (31, 10)

    def test_weekly_anchor(self):
        """Weekly anchor (weekday 1, hour 0), reference Wednesday April 1: expect day 6.

        2026-04-01 is a Wednesday and 2026-04-06 the following Monday.
        """
        upcoming = _compute_next_calendar(
            "weekly", {"weekday": 1, "hour": 0}, self._utc(2026, 4, 1, 10)
        )
        assert (upcoming.hour, upcoming.day) == (0, 6)

    def test_weekly_anchor_same_day(self):
        """Weekly anchor on the reference's own weekday, hour already passed.

        2026-03-31 is a Tuesday (weekday 2 in the anchor); the expected result
        falls on day 7, i.e. the next Tuesday, April 7.
        """
        upcoming = _compute_next_calendar(
            "weekly", {"weekday": 2, "hour": 0}, self._utc(2026, 3, 31, 10)
        )
        assert upcoming.day == 7

    def test_monthly_anchor_first_day(self):
        """Monthly anchor on day 1, reference March 15: expect April 1, hour 0."""
        upcoming = _compute_next_calendar(
            "monthly", {"day": 1, "hour": 0}, self._utc(2026, 3, 15, 10)
        )
        assert (upcoming.day, upcoming.month, upcoming.hour) == (1, 4, 0)

    def test_monthly_anchor_last_day(self):
        """Monthly anchor on day 31, reference March 15: expect March 31."""
        upcoming = _compute_next_calendar(
            "monthly", {"day": 31, "hour": 0}, self._utc(2026, 3, 15, 10)
        )
        assert (upcoming.day, upcoming.month) == (31, 3)

    def test_monthly_anchor_invalid_day(self):
        """Monthly anchor on day 30, reference March 15: expect March 30.

        NOTE(review): the test name suggests a non-existent date (e.g. a
        30th/31st in February), but the reference instant is in March, where
        day 30 exists -- the missing-day path does not appear to be exercised
        here. Confirm the intended scenario.
        """
        upcoming = _compute_next_calendar(
            "monthly", {"day": 30, "hour": 0}, self._utc(2026, 3, 15, 10)
        )
        assert (upcoming.day, upcoming.month) == (30, 3)
class TestRefreshQuotaRules:
    """Tests for _refresh_quota_rules, run against the ``temp_db`` fixture."""

    @pytest.mark.asyncio
    async def test_manual_rule_no_refresh(self, temp_db):
        """Running the refresh must complete without raising.

        NOTE(review): no plan or rule is created here, so this only runs the
        refresh against whatever ``temp_db`` provides -- confirm whether a
        manual rule should be inserted first to match the test's name.
        """
        await _refresh_quota_rules()

    @pytest.mark.asyncio
    async def test_fixed_interval_refresh(self, temp_db):
        """A fixed-interval rule whose next refresh is overdue gets its usage reset."""
        from app import database as db

        # Create the plan and an hourly fixed-interval rule.
        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Hourly",
            quota_total=100,
            refresh_type="fixed_interval",
            interval_hours=1,
        )

        # Push both refresh timestamps two hours into the past so the rule is due.
        past = datetime.now(timezone.utc) - timedelta(hours=2)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        # The used quota is expected to be back at zero after the refresh.
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_calendar_cycle_refresh(self, temp_db):
        """A daily calendar-cycle rule whose next refresh is overdue gets its usage reset."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily",
            quota_total=100,
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )

        # Set both refresh timestamps to one day ago.
        past = datetime.now(timezone.utc) - timedelta(days=1)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_api_sync(self, temp_db):
        """An ``api_sync`` rule backed by a mock provider refreshes without raising."""
        from app import database as db
        from app.providers import ProviderRegistry
        from app.providers.base import QuotaInfo

        class MockSyncProvider:
            """Minimal provider stub that reports a fixed quota."""

            name = "mock_sync"
            capabilities = []

            async def query_quota(self, plan):
                return QuotaInfo(quota_used=42, quota_total=100, quota_remaining=58, unit="tokens")

        # Register the mock only for the duration of this test; patch.dict
        # restores the registry on exit so the entry cannot leak into other tests.
        with patch.dict(ProviderRegistry._providers, {"mock_sync": MockSyncProvider()}):
            plan = await db.create_plan(name="Test", provider_name="mock_sync", api_key="sk-test")
            rule = await db.create_quota_rule(
                plan_id=plan["id"],
                rule_name="API Sync",
                quota_total=100,
                refresh_type="api_sync",
            )

            # Last refresh was 601 seconds ago, i.e. just over ten minutes.
            past = datetime.now(timezone.utc) - timedelta(seconds=601)
            await db.update_quota_rule(rule["id"], last_refresh_at=past.isoformat())

            await _refresh_quota_rules()

            # Only checks that nothing raised and the rule still exists; whether
            # the quota values are updated depends on the scheduler's logic.
            rules = await db.list_quota_rules(plan["id"])
            assert rules[0]["rule_name"] == "API Sync"
class TestProcessTaskQueue:
    """Tests for _process_task_queue, run against the ``temp_db`` fixture."""

    @pytest.mark.asyncio
    async def test_process_pending_tasks(self, temp_db):
        """Processing a queue holding one pending image task must not raise."""
        from app import database as db
        from app.providers import ProviderRegistry

        # Create a plan with one quota rule and a single image task.
        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        class MockTaskProvider:
            """Provider stub that returns a fixed image URL."""

            name = "mock_task"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": "https://example.com/test.png"}

        # Swap in the mock under the "openai" key for this test only; patch.dict
        # puts the previous registry contents back on exit, so later tests do
        # not see the stub in place of whatever was registered before.
        with patch.dict(ProviderRegistry._providers, {"openai": MockTaskProvider()}):
            await _process_task_queue()

            # The stub does not implement the full provider contract, so the
            # task may stay pending or fail; only its continued existence is
            # asserted here.
            updated_task = await db.get_task(task["id"])
            assert updated_task is not None

    @pytest.mark.asyncio
    async def test_skip_task_no_quota(self, temp_db):
        """Processing a task whose plan has used its full quota must not raise."""
        from app import database as db

        plan = await db.create_plan(name="Test", provider_name="openai", api_key="sk-test")
        rule = await db.create_quota_rule(plan_id=plan["id"], rule_name="Rule", quota_total=100)
        # Mark the whole quota as used (100 of 100).
        await db.update_quota_rule(rule["id"], quota_used=100)
        await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        # Expected to skip the task rather than raise.
        await _process_task_queue()
class TestExecuteTask:
    """Tests for _execute_task using in-test provider stubs."""

    @pytest.mark.asyncio
    async def test_execute_image_task(self):
        """An ``image`` task returns the provider's result, including its URL."""
        from app.providers import ProviderRegistry

        class MockImageProvider:
            """Provider stub that builds an image URL from the prompt."""

            name = "mock_img"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": f"https://example.com/{prompt}.png"}

        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "image",
            "request_payload": {"prompt": "a cat", "size": "512x512"},
        }

        # Keep the registry entry scoped to this test; patch.dict restores the
        # registry on exit so "mock_img" does not leak into other tests.
        with patch.dict(ProviderRegistry._providers, {"mock_img": MockImageProvider()}):
            provider = MockImageProvider()
            result = await _execute_task(provider, plan, task)

        assert "url" in result
        assert "cat" in result["url"]

    @pytest.mark.asyncio
    async def test_execute_unknown_task(self):
        """An unrecognised task type produces an ``error`` entry in the result."""
        from app.providers import ProviderRegistry

        class MockProvider:
            """Provider stub with no task methods."""

            name = "mock"
            capabilities = []

        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "unknown_type",
            "request_payload": {},
        }

        # Scoped registration, restored on exit so "mock" does not persist.
        with patch.dict(ProviderRegistry._providers, {"mock": MockProvider()}):
            provider = MockProvider()
            result = await _execute_task(provider, plan, task)

        assert "error" in result
        assert "Unknown task type" in result["error"]