"""Tests for the scheduler module."""

from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.scheduler import (
    _now,
    _parse_dt,
    _compute_next_calendar,
    _refresh_quota_rules,
    _process_task_queue,
    _execute_task,
)


class TestUtilityFunctions:
    """Tests for the small date/time helper functions."""

    def test_now_returns_utc(self):
        """_now() returns a datetime whose tzinfo is UTC."""
        now = _now()
        assert now.tzinfo == timezone.utc

    def test_parse_dt_valid(self):
        """A valid ISO-8601 string is parsed into its date components."""
        s = "2026-03-31T12:00:00+00:00"
        result = _parse_dt(s)
        assert result is not None
        assert result.year == 2026
        assert result.month == 3
        assert result.day == 31

    def test_parse_dt_none(self):
        """Parsing None yields None."""
        result = _parse_dt(None)
        assert result is None

    def test_parse_dt_invalid(self):
        """Parsing a non-date string yields None."""
        result = _parse_dt("not-a-date")
        assert result is None


class TestComputeNextCalendar:
    """Tests for computing the next calendar-cycle refresh time."""

    def test_daily_anchor(self):
        """Daily anchor whose hour has already passed rolls to the next day."""
        # "Now" is March 31, 15:00 UTC.
        after = datetime(2026, 3, 31, 15, 0, tzinfo=timezone.utc)
        anchor = {"hour": 0}

        next_time = _compute_next_calendar("daily", anchor, after)

        # Expected: 00:00 on the following day, which crosses into April.
        assert next_time.day == 1
        assert next_time.hour == 0
        assert next_time.month == 4

    def test_daily_anchor_before_hour(self):
        """Daily anchor later the same day resolves to that same day."""
        after = datetime(2026, 3, 31, 8, 0, tzinfo=timezone.utc)
        anchor = {"hour": 10}

        next_time = _compute_next_calendar("daily", anchor, after)

        # Expected: 10:00 on the same day.
        assert next_time.day == 31
        assert next_time.hour == 10

    def test_weekly_anchor(self):
        """Weekly anchor (Monday 00:00) resolves to the following Monday."""
        # 2026-03-31 is a Tuesday, so 2026-04-01 is a Wednesday.
        after = datetime(2026, 4, 1, 10, 0, tzinfo=timezone.utc)
        anchor = {"weekday": 1, "hour": 0}  # Monday

        next_time = _compute_next_calendar("weekly", anchor, after)

        # The next Monday after Wednesday April 1 is April 6.
        assert next_time.hour == 0
        assert next_time.day == 6

    def test_weekly_anchor_same_day(self):
        """Weekly anchor on today's weekday, hour already passed, rolls a week."""
        after = datetime(2026, 3, 31, 10, 0, tzinfo=timezone.utc)  # Tuesday
        anchor = {"weekday": 2, "hour": 0}  # Tuesday

        next_time = _compute_next_calendar("weekly", anchor, after)

        # March 31 is a Tuesday; the following Tuesday is April 7.
        assert next_time.day == 7

    def test_monthly_anchor_first_day(self):
        """Monthly anchor on the 1st resolves to the 1st of next month."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 1, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        # Expected: April 1, 00:00.
        assert next_time.day == 1
        assert next_time.month == 4
        assert next_time.hour == 0

    def test_monthly_anchor_last_day(self):
        """Monthly anchor on the 31st resolves to March 31 from mid-March."""
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 31, "hour": 0}

        next_time = _compute_next_calendar("monthly", anchor, after)

        # Expected: March 31.
        assert next_time.day == 31
        assert next_time.month == 3

    def test_monthly_anchor_invalid_day(self):
        """Monthly anchor on the 30th resolves to March 30 from mid-March.

        NOTE(review): despite its name, this test starts in March, where the
        30th exists, so it never exercises a month that lacks the anchor day
        (February has no 30th). A case starting in late January or February
        would be needed to cover that; the expected result for such a case is
        not visible from this file, so no assertion for it is added here.
        """
        after = datetime(2026, 3, 15, 10, 0, tzinfo=timezone.utc)
        anchor = {"day": 30, "hour": 0}  # February has no 30th

        next_time = _compute_next_calendar("monthly", anchor, after)

        # Expected: March 30.
        assert next_time.day == 30
        assert next_time.month == 3


class TestRefreshQuotaRules:
    """Tests for the quota refresh pass."""

    @pytest.mark.asyncio
    async def test_manual_rule_no_refresh(self, temp_db):
        """The refresh pass completes without raising.

        NOTE(review): no quota rule is created in this test, so it only
        checks that the refresh pass runs cleanly against the state provided
        by the ``temp_db`` fixture; it does not by itself prove that a
        manual rule is skipped.
        """
        await _refresh_quota_rules()

    @pytest.mark.asyncio
    async def test_fixed_interval_refresh(self, temp_db):
        """An overdue fixed-interval rule has its used quota reset to 0."""
        from app import database as db

        # Create a plan and an hourly rule.
        plan = await db.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Hourly",
            quota_total=100,
            refresh_type="fixed_interval",
            interval_hours=1,
        )

        # Push both refresh timestamps into the past so the rule is due.
        past = datetime.now(timezone.utc) - timedelta(hours=2)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        # The used quota must have been reset.
        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_calendar_cycle_refresh(self, temp_db):
        """An overdue calendar-cycle rule has its used quota reset to 0."""
        from app import database as db

        plan = await db.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        rule = await db.create_quota_rule(
            plan_id=plan["id"],
            rule_name="Daily",
            quota_total=100,
            refresh_type="calendar_cycle",
            calendar_unit="daily",
            calendar_anchor={"hour": 0},
        )

        # Set both refresh timestamps to yesterday so the rule is due.
        past = datetime.now(timezone.utc) - timedelta(days=1)
        await db.update_quota_rule(
            rule["id"],
            last_refresh_at=past.isoformat(),
            next_refresh_at=past.isoformat(),
        )
        await db.update_quota_rule(rule["id"], quota_used=99)

        await _refresh_quota_rules()

        rules = await db.list_quota_rules(plan["id"])
        assert rules[0]["quota_used"] == 0

    @pytest.mark.asyncio
    async def test_api_sync(self, temp_db):
        """An api_sync rule survives a refresh pass without errors."""
        from app import database as db
        from app.providers import ProviderRegistry
        from app.providers.base import QuotaInfo

        class MockSyncProvider:
            name = "mock_sync"
            capabilities = []

            async def query_quota(self, plan):
                return QuotaInfo(
                    quota_used=42,
                    quota_total=100,
                    quota_remaining=58,
                    unit="tokens",
                )

        # patch.dict restores the global registry on exit, so the mock
        # provider cannot leak into other tests.
        with patch.dict(
            ProviderRegistry._providers, {"mock_sync": MockSyncProvider()}
        ):
            plan = await db.create_plan(
                name="Test", provider_name="mock_sync", api_key="sk-test"
            )
            rule = await db.create_quota_rule(
                plan_id=plan["id"],
                rule_name="API Sync",
                quota_total=100,
                refresh_type="api_sync",
            )

            # Set last_refresh_at to just over 10 minutes ago (601 seconds).
            past = datetime.now(timezone.utc) - timedelta(seconds=601)
            await db.update_quota_rule(
                rule["id"], last_refresh_at=past.isoformat()
            )

            await _refresh_quota_rules()

            # Only verifies that nothing raised and the rule still exists;
            # whether the quota values change depends on scheduler logic.
            rules = await db.list_quota_rules(plan["id"])
            assert rules[0]["rule_name"] == "API Sync"


class TestProcessTaskQueue:
    """Tests for task queue processing."""

    @pytest.mark.asyncio
    async def test_process_pending_tasks(self, temp_db):
        """Processing the queue with a pending task does not raise."""
        from app import database as db
        from app.providers import ProviderRegistry

        plan = await db.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Rule", quota_total=100
        )

        task = await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        class MockTaskProvider:
            name = "mock_task"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": "https://example.com/test.png"}

        # Temporarily replace the "openai" provider; patch.dict puts the
        # original registry contents back once the block exits.
        with patch.dict(
            ProviderRegistry._providers, {"openai": MockTaskProvider()}
        ):
            await _process_task_queue()

        # The mock provider does not implement the full provider contract,
        # so the task may remain pending or fail; only check it still exists.
        updated_task = await db.get_task(task["id"])
        assert updated_task is not None

    @pytest.mark.asyncio
    async def test_skip_task_no_quota(self, temp_db):
        """Processing the queue with an exhausted quota does not raise."""
        from app import database as db

        plan = await db.create_plan(
            name="Test", provider_name="openai", api_key="sk-test"
        )
        rule = await db.create_quota_rule(
            plan_id=plan["id"], rule_name="Rule", quota_total=100
        )
        # Use up the entire quota.
        await db.update_quota_rule(rule["id"], quota_used=100)

        await db.create_task(
            task_type="image",
            request_payload={"prompt": "test"},
            plan_id=plan["id"],
        )

        # The task should be skipped without raising.
        await _process_task_queue()


class TestExecuteTask:
    """Tests for executing a single task against a provider."""

    @pytest.mark.asyncio
    async def test_execute_image_task(self):
        """An image task returns the provider's generated URL."""
        from app.providers import ProviderRegistry

        class MockImageProvider:
            name = "mock_img"
            capabilities = []

            async def generate_image(self, prompt, plan, **kwargs):
                return {"url": f"https://example.com/{prompt}.png"}

        provider = MockImageProvider()
        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "image",
            "request_payload": {"prompt": "a cat", "size": "512x512"},
        }

        # Register the mock only for the duration of the call so the global
        # registry is left untouched afterwards.
        with patch.dict(ProviderRegistry._providers, {"mock_img": provider}):
            result = await _execute_task(provider, plan, task)

        assert "url" in result
        assert "cat" in result["url"]

    @pytest.mark.asyncio
    async def test_execute_unknown_task(self):
        """An unknown task type yields an error result."""
        from app.providers import ProviderRegistry

        class MockProvider:
            name = "mock"
            capabilities = []

        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        task = {
            "id": "task123",
            "task_type": "unknown_type",
            "request_payload": {},
        }

        # Register the mock only for the duration of the call so the global
        # registry is left untouched afterwards.
        with patch.dict(ProviderRegistry._providers, {"mock": provider}):
            result = await _execute_task(provider, plan, task)

        assert "error" in result
        assert "Unknown task type" in result["error"]