# Files
# planManage/tests/test_providers.py
# congsh 37d282c0a2 test: 添加测试框架和全面的单元测试
# - 添加 pytest 配置和测试依赖到 requirements.txt
# - 创建测试包结构和 fixtures (conftest.py)
# - 添加数据库模块的 CRUD 操作测试 (test_database.py)
# - 添加 Provider 插件系统测试 (test_providers.py)
# - 添加调度器模块测试 (test_scheduler.py)
# - 添加 API 路由测试 (test_api.py)
# - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py)
# - 添加健康检查端点用于容器监控
# - 修复调度器中的日历计算逻辑和任务执行参数处理
# - 更新数据库函数以返回操作结果状态
# 2026-03-31 22:36:18 +08:00
#
# 244 lines
# 7.3 KiB
# Python

"""Provider 模块测试"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.providers import ProviderRegistry
from app.providers.base import BaseProvider, Capability, QuotaInfo
class MockProvider(BaseProvider):
    """Deterministic provider test double.

    Advertises chat and image capabilities and answers every call with a
    fixed, canned payload, so tests can exercise ``BaseProvider`` helpers
    and the registry without any network access.
    """

    name = "mock"
    display_name = "Mock Provider"
    capabilities = [Capability.CHAT, Capability.IMAGE]

    async def chat(self, messages, model, plan, stream=True, **kwargs):
        """Yield exactly one canned completion chunk.

        When ``stream`` is truthy the chunk is an SSE ``data:`` line carrying a
        ``delta``; otherwise it is a bare JSON body carrying a ``message``.
        ``messages``, ``model``, ``plan`` and ``kwargs`` are ignored.
        """
        body = (
            'data: {"choices": [{"delta": {"content": "hello"}}]}'
            if stream
            else '{"choices": [{"message": {"content": "hello"}}]}'
        )
        yield body

    async def generate_image(self, prompt, plan, **kwargs):
        """Return a fake image result whose URL embeds ``prompt``."""
        image_url = f"https://example.com/{prompt}.png"
        return {"url": image_url}

    async def query_quota(self, plan):
        """Return a fixed quota snapshot: 50 of 100 tokens used."""
        quota = QuotaInfo(
            quota_used=50,
            quota_total=100,
            quota_remaining=50,
            unit="tokens",
        )
        return quota
class TestProviderRegistry:
    """Tests for ``ProviderRegistry`` lookup helpers.

    Every test in this class runs against a private copy of the registry's
    ``_providers`` dict. Without this isolation the tests below would leak a
    "mock" entry into, or completely replace, the process-wide registry, and
    any later test that relies on the real registered providers would fail.
    """

    @pytest.fixture(autouse=True)
    def _isolated_registry(self, monkeypatch):
        """Swap in a copy of ``_providers`` for the duration of each test.

        ``monkeypatch`` records the original dict object and restores it on
        teardown, even if the test later rebinds ``_providers`` directly.
        """
        monkeypatch.setattr(
            ProviderRegistry, "_providers", dict(ProviderRegistry._providers)
        )

    def test_register_and_get(self):
        """A provider placed in the registry is returned by ``get``."""
        mock = MockProvider()
        ProviderRegistry._providers["mock"] = mock
        retrieved = ProviderRegistry.get("mock")
        assert retrieved is mock
        assert retrieved.name == "mock"

    def test_get_nonexistent(self):
        """Looking up an unknown name returns ``None``."""
        retrieved = ProviderRegistry.get("nonexistent")
        assert retrieved is None

    def test_all_providers(self):
        """``all`` exposes every registered provider keyed by name."""
        mock = MockProvider()
        # Safe to rebind: the autouse fixture restores the original dict.
        ProviderRegistry._providers = {"mock": mock}
        all_providers = ProviderRegistry.all()
        assert "mock" in all_providers
        assert all_providers["mock"] is mock

    def test_by_capability(self):
        """``by_capability`` includes only providers declaring that capability."""
        mock = MockProvider()
        # Safe to rebind: the autouse fixture restores the original dict.
        ProviderRegistry._providers = {"mock": mock}
        chat_providers = ProviderRegistry.by_capability(Capability.CHAT)
        assert mock in chat_providers
        video_providers = ProviderRegistry.by_capability(Capability.VIDEO)
        assert mock not in video_providers
class TestBaseProvider:
    """Tests for the shared request helpers inherited from ``BaseProvider``."""

    def test_build_headers(self):
        """An API key becomes a Bearer token and ``extra_headers`` are merged in."""
        built = MockProvider()._build_headers(
            {
                "api_key": "sk-test-key",
                "extra_headers": {"X-Custom": "value"},
            }
        )
        assert built["Authorization"] == "Bearer sk-test-key"
        assert built["X-Custom"] == "value"

    def test_build_headers_no_key(self):
        """Without an API key no Authorization header is sent, but JSON content type is."""
        built = MockProvider()._build_headers({})
        assert "Authorization" not in built
        assert built["Content-Type"] == "application/json"

    def test_base_url(self):
        """The trailing slash of ``api_base`` is stripped."""
        resolved = MockProvider()._base_url({"api_base": "https://api.example.com/v1/"})
        assert resolved == "https://api.example.com/v1"

    def test_base_url_empty(self):
        """A plan without ``api_base`` resolves to an empty string."""
        resolved = MockProvider()._base_url({})
        assert resolved == ""
class TestMockProvider:
    """Sanity checks for the ``MockProvider`` test double itself."""

    @pytest.mark.asyncio
    async def test_chat_stream(self):
        """Streaming chat yields one SSE ``data:`` chunk whose delta is the reply."""
        import json

        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        chunks = [c async for c in provider.chat([], "gpt-4", plan, stream=True)]
        assert len(chunks) == 1
        # Streaming output must be SSE-framed, unlike the non-stream body.
        assert chunks[0].startswith("data: ")
        payload = json.loads(chunks[0][len("data: "):])
        assert payload["choices"][0]["delta"]["content"] == "hello"

    @pytest.mark.asyncio
    async def test_chat_non_stream(self):
        """Non-streaming chat yields one plain JSON body whose message is the reply."""
        import json

        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        chunks = [c async for c in provider.chat([], "gpt-4", plan, stream=False)]
        assert len(chunks) == 1
        # A non-stream response must not carry SSE framing.
        assert not chunks[0].startswith("data:")
        payload = json.loads(chunks[0])
        assert payload["choices"][0]["message"]["content"] == "hello"

    @pytest.mark.asyncio
    async def test_generate_image(self):
        """Image generation returns a URL derived from the prompt."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        result = await provider.generate_image("a cat", plan)
        assert "url" in result
        assert "cat" in result["url"]

    @pytest.mark.asyncio
    async def test_query_quota(self):
        """Quota query returns the fixed 50/100 token snapshot."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}
        info = await provider.query_quota(plan)
        assert info.quota_used == 50
        assert info.quota_total == 100
        assert info.quota_remaining == 50
        assert info.unit == "tokens"
class TestOpenAIProvider:
    """Tests for ``OpenAIProvider`` request-header construction.

    These tests never await anything, so they are plain synchronous tests
    rather than asyncio-marked coroutines.
    """

    def test_chat_requires_api_key(self):
        """Building headers for a plan without ``api_key`` must not raise.

        Despite the test name, chat() is not called here. The test checks that
        the resulting headers carry either no ``Authorization`` entry or an
        empty Bearer token, i.e. no real credential is invented.
        """
        from app.providers.openai_provider import OpenAIProvider

        provider = OpenAIProvider()
        plan = {"api_base": "https://api.openai.com/v1"}
        headers = provider._build_headers(plan)
        assert "Authorization" not in headers or headers["Authorization"] == "Bearer "

    def test_build_headers_with_extra(self):
        """The API key and ``extra_headers`` (e.g. an organization id) both appear."""
        from app.providers.openai_provider import OpenAIProvider

        provider = OpenAIProvider()
        plan = {
            "api_key": "sk-test",
            "extra_headers": {"OpenAI-Organization": "org-123"},
        }
        headers = provider._build_headers(plan)
        assert headers["Authorization"] == "Bearer sk-test"
        assert headers["OpenAI-Organization"] == "org-123"
class TestKimiProvider:
    """Tests for ``KimiProvider``."""

    def test_provider_metadata(self):
        """Name, display name and chat capability match the expected values.

        Synchronous on purpose: nothing here is awaited.
        """
        from app.providers.kimi import KimiProvider

        provider = KimiProvider()
        assert provider.name == "kimi"
        assert provider.display_name == "Kimi (Moonshot)"
        assert Capability.CHAT in provider.capabilities
class TestMiniMaxProvider:
    """Tests for ``MiniMaxProvider``."""

    def test_provider_metadata(self):
        """Name, display name and chat capability match the expected values.

        Synchronous on purpose: nothing here is awaited.
        """
        from app.providers.minimax import MiniMaxProvider

        provider = MiniMaxProvider()
        assert provider.name == "minimax"
        assert provider.display_name == "MiniMax"
        assert Capability.CHAT in provider.capabilities
class TestGoogleProvider:
    """Tests for ``GoogleProvider``."""

    def test_provider_metadata(self):
        """Name, display name and chat capability match the expected values.

        Synchronous on purpose: nothing here is awaited.
        """
        from app.providers.google import GoogleProvider

        provider = GoogleProvider()
        assert provider.name == "google"
        assert provider.display_name == "Google Gemini"
        assert Capability.CHAT in provider.capabilities
class TestZhipuProvider:
    """Tests for ``ZhipuProvider`` (Zhipu AI)."""

    def test_provider_metadata(self):
        """Name, display name and chat capability match the expected values.

        The display name may use either the model family ("GLM") or the
        Chinese company name, so either is accepted. Synchronous on purpose:
        nothing here is awaited.
        """
        from app.providers.zhipu import ZhipuProvider

        provider = ZhipuProvider()
        assert provider.name == "zhipu"
        assert "GLM" in provider.display_name or "智谱" in provider.display_name
        assert Capability.CHAT in provider.capabilities