- 添加 pytest 配置和测试依赖到 requirements.txt - 创建测试包结构和 fixtures (conftest.py) - 添加数据库模块的 CRUD 操作测试 (test_database.py) - 添加 Provider 插件系统测试 (test_providers.py) - 添加调度器模块测试 (test_scheduler.py) - 添加 API 路由测试 (test_api.py) - 添加回归测试覆盖边界条件和错误处理 (test_regressions.py) - 添加健康检查端点用于容器监控 - 修复调度器中的日历计算逻辑和任务执行参数处理 - 更新数据库函数以返回操作结果状态
244 lines
7.3 KiB
Python
244 lines
7.3 KiB
Python
"""Provider 模块测试"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.providers import ProviderRegistry
|
|
from app.providers.base import BaseProvider, Capability, QuotaInfo
|
|
|
|
|
|
class MockProvider(BaseProvider):
    """In-memory provider stub used throughout these tests.

    Declares chat and image capabilities and answers every call with a
    canned value; nothing here performs network I/O.
    """

    name = "mock"
    display_name = "Mock Provider"
    capabilities = [Capability.CHAT, Capability.IMAGE]

    async def chat(self, messages, model, plan, stream=True, **kwargs):
        """Yield exactly one canned chunk.

        When ``stream`` is truthy the chunk is a ``data:``-prefixed delta line;
        otherwise it is a full JSON message body. ``messages``, ``model`` and
        ``plan`` are accepted for interface compatibility but ignored.
        """
        payload = (
            'data: {"choices": [{"delta": {"content": "hello"}}]}'
            if stream
            else '{"choices": [{"message": {"content": "hello"}}]}'
        )
        yield payload

    async def generate_image(self, prompt, plan, **kwargs):
        """Return a fake image result whose URL embeds ``prompt``."""
        image_url = f"https://example.com/{prompt}.png"
        return {"url": image_url}

    async def query_quota(self, plan):
        """Return a fixed quota snapshot: 50 of 100 tokens used."""
        return QuotaInfo(
            unit="tokens",
            quota_total=100,
            quota_used=50,
            quota_remaining=50,
        )
|
|
|
|
|
|
class TestProviderRegistry:
    """Tests for the ProviderRegistry lookup helpers.

    ``ProviderRegistry._providers`` is shared, process-wide state. Every test
    that needs a particular registry content installs it with ``patch.dict``.
    The original contents are then restored when the ``with`` block exits.
    Mutating or reassigning the dict directly would leak the mock into later
    tests, or wipe the real providers for the rest of the session.
    """

    def test_register_and_get(self):
        """A registered provider can be fetched back by its name."""
        mock = MockProvider()

        with patch.dict(ProviderRegistry._providers, {"mock": mock}):
            retrieved = ProviderRegistry.get("mock")
            assert retrieved is mock
            assert retrieved.name == "mock"

    def test_get_nonexistent(self):
        """Looking up an unregistered name returns None."""
        retrieved = ProviderRegistry.get("nonexistent")
        assert retrieved is None

    def test_all_providers(self):
        """all() exposes every registered provider keyed by name."""
        mock = MockProvider()

        # clear=True mirrors the original "registry contains only mock" setup,
        # but is undone automatically on exit.
        with patch.dict(ProviderRegistry._providers, {"mock": mock}, clear=True):
            # Assert inside the block: all() may return the live dict, which
            # is restored (and thus changes) once the patch is undone.
            all_providers = ProviderRegistry.all()
            assert "mock" in all_providers
            assert all_providers["mock"] is mock

    def test_by_capability(self):
        """by_capability() includes providers declaring the capability only."""
        mock = MockProvider()

        with patch.dict(ProviderRegistry._providers, {"mock": mock}, clear=True):
            chat_providers = ProviderRegistry.by_capability(Capability.CHAT)
            assert mock in chat_providers

            # MockProvider declares CHAT and IMAGE, but not VIDEO.
            video_providers = ProviderRegistry.by_capability(Capability.VIDEO)
            assert mock not in video_providers
|
|
|
|
|
|
class TestBaseProvider:
    """Tests for the header and base-URL helpers inherited from BaseProvider."""

    def test_build_headers(self):
        """An api_key becomes a Bearer token and extra_headers are included."""
        headers = MockProvider()._build_headers(
            {
                "api_key": "sk-test-key",
                "extra_headers": {"X-Custom": "value"},
            }
        )

        assert headers["Authorization"] == "Bearer sk-test-key"
        assert headers["X-Custom"] == "value"

    def test_build_headers_no_key(self):
        """Without an api_key no Authorization header is emitted."""
        headers = MockProvider()._build_headers({})

        assert "Authorization" not in headers
        assert headers["Content-Type"] == "application/json"

    def test_base_url(self):
        """A trailing slash on api_base is removed."""
        base = MockProvider()._base_url({"api_base": "https://api.example.com/v1/"})

        assert base == "https://api.example.com/v1"

    def test_base_url_empty(self):
        """A plan without api_base yields an empty base URL."""
        base = MockProvider()._base_url({})

        assert base == ""
|
|
|
|
|
|
class TestMockProvider:
    """Behavioural tests for the canned responses of MockProvider."""

    @pytest.mark.asyncio
    async def test_chat_stream(self):
        """Streaming chat yields a single chunk containing the reply text."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}

        received = [
            piece async for piece in provider.chat([], "gpt-4", plan, stream=True)
        ]

        assert len(received) == 1
        assert "hello" in received[0]

    @pytest.mark.asyncio
    async def test_chat_non_stream(self):
        """Non-streaming chat also yields exactly one chunk."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}

        received = [
            piece async for piece in provider.chat([], "gpt-4", plan, stream=False)
        ]

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_generate_image(self):
        """Image generation returns a URL that embeds the prompt."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}

        outcome = await provider.generate_image("a cat", plan)

        assert "url" in outcome
        assert "cat" in outcome["url"]

    @pytest.mark.asyncio
    async def test_query_quota(self):
        """Quota query returns the fixed 50/100 token snapshot."""
        provider = MockProvider()
        plan = {"api_key": "sk-test"}

        quota = await provider.query_quota(plan)

        assert quota.quota_used == 50
        assert quota.quota_total == 100
        assert quota.quota_remaining == 50
        assert quota.unit == "tokens"
|
|
|
|
|
|
class TestOpenAIProvider:
    """Tests for OpenAIProvider request-header construction.

    Both tests only call the synchronous ``_build_headers`` helper and await
    nothing. They are therefore plain test functions with no
    ``pytest.mark.asyncio`` marker and no event loop.
    """

    def test_chat_requires_api_key(self):
        """Building headers for a plan without an api_key must not raise."""
        # Imported locally so an import failure is reported by this test only.
        from app.providers.openai_provider import OpenAIProvider

        provider = OpenAIProvider()
        plan = {"api_base": "https://api.openai.com/v1"}

        # A missing api_key must not raise. The provider may either omit the
        # Authorization header or send an empty bearer token; both are accepted.
        headers = provider._build_headers(plan)
        assert "Authorization" not in headers or headers["Authorization"] == "Bearer "

    def test_build_headers_with_extra(self):
        """extra_headers from the plan are included alongside the bearer token."""
        from app.providers.openai_provider import OpenAIProvider

        provider = OpenAIProvider()
        plan = {
            "api_key": "sk-test",
            "extra_headers": {"OpenAI-Organization": "org-123"},
        }

        headers = provider._build_headers(plan)
        assert headers["Authorization"] == "Bearer sk-test"
        assert headers["OpenAI-Organization"] == "org-123"
|
|
|
|
|
|
class TestKimiProvider:
    """Metadata tests for KimiProvider."""

    def test_provider_metadata(self):
        """Name, display name and CHAT capability are declared as expected.

        Only class attributes are read, so this is a plain synchronous test
        and needs no asyncio marker or event loop.
        """
        from app.providers.kimi import KimiProvider

        provider = KimiProvider()
        assert provider.name == "kimi"
        assert provider.display_name == "Kimi (Moonshot)"
        assert Capability.CHAT in provider.capabilities
|
|
|
|
|
|
class TestMiniMaxProvider:
    """Metadata tests for MiniMaxProvider."""

    def test_provider_metadata(self):
        """Name, display name and CHAT capability are declared as expected.

        Only class attributes are read, so this is a plain synchronous test
        and needs no asyncio marker or event loop.
        """
        from app.providers.minimax import MiniMaxProvider

        provider = MiniMaxProvider()
        assert provider.name == "minimax"
        assert provider.display_name == "MiniMax"
        assert Capability.CHAT in provider.capabilities
|
|
|
|
|
|
class TestGoogleProvider:
    """Metadata tests for GoogleProvider."""

    def test_provider_metadata(self):
        """Name, display name and CHAT capability are declared as expected.

        Only class attributes are read, so this is a plain synchronous test
        and needs no asyncio marker or event loop.
        """
        from app.providers.google import GoogleProvider

        provider = GoogleProvider()
        assert provider.name == "google"
        assert provider.display_name == "Google Gemini"
        assert Capability.CHAT in provider.capabilities
|
|
|
|
|
|
class TestZhipuProvider:
    """Metadata tests for ZhipuProvider."""

    def test_provider_metadata(self):
        """Name, display name and CHAT capability are declared as expected.

        The display name may be either the "GLM" branding or the Chinese
        company name, so either substring is accepted. Only class attributes
        are read, so this is a plain synchronous test with no asyncio marker.
        """
        from app.providers.zhipu import ZhipuProvider

        provider = ZhipuProvider()
        assert provider.name == "zhipu"
        assert "GLM" in provider.display_name or "智谱" in provider.display_name
        assert Capability.CHAT in provider.capabilities
|