Files
planManage/tests/test_providers.py
T

244 lines
7.3 KiB
Python
Raw Normal View History

"""Provider 模块测试"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.providers import ProviderRegistry
from app.providers.base import BaseProvider, Capability, QuotaInfo
class MockProvider(BaseProvider):
    """Canned-response provider used as a test double in this module.

    Every method returns fixed data built inside the method itself.
    """

    name = "mock"
    display_name = "Mock Provider"
    capabilities = [Capability.CHAT, Capability.IMAGE]

    async def chat(self, messages, model, plan, stream=True, **kwargs):
        """Yield exactly one canned payload.

        With ``stream`` truthy the payload is a ``data: ``-prefixed delta
        line; otherwise it is a single JSON message body.  ``messages``,
        ``model``, ``plan`` and ``kwargs`` are accepted but ignored.
        """
        if not stream:
            yield '{"choices": [{"message": {"content": "hello"}}]}'
            return
        yield 'data: {"choices": [{"delta": {"content": "hello"}}]}'

    async def generate_image(self, prompt, plan, **kwargs):
        """Return a fake image record whose URL embeds ``prompt``."""
        image_url = f"https://example.com/{prompt}.png"
        return {"url": image_url}

    async def query_quota(self, plan):
        """Return a fixed quota snapshot: 50 of 100 tokens used."""
        return QuotaInfo(
            quota_used=50,
            quota_total=100,
            quota_remaining=50,
            unit="tokens",
        )
class TestProviderRegistry:
    """Tests for ProviderRegistry registration, lookup and filtering."""

    @pytest.fixture(autouse=True)
    def _isolated_registry(self):
        """Run every test in this class against a private copy of the registry.

        The tests below write to ``ProviderRegistry._providers``, which is
        shared class-level state.  Without a restore step the mock entry (or
        a wholesale replacement of the table) would leak into whichever test
        runs next and make results depend on execution order.
        """
        saved = ProviderRegistry._providers
        ProviderRegistry._providers = dict(saved)
        yield
        # pytest runs this teardown even when the test body fails.
        ProviderRegistry._providers = saved

    def test_register_and_get(self):
        """A provider stored under a name is returned by ``get``."""
        mock = MockProvider()
        ProviderRegistry._providers["mock"] = mock

        retrieved = ProviderRegistry.get("mock")

        assert retrieved is mock
        assert retrieved.name == "mock"

    def test_get_nonexistent(self):
        """Looking up an unknown name yields ``None``."""
        retrieved = ProviderRegistry.get("nonexistent")

        assert retrieved is None

    def test_all_providers(self):
        """``all`` exposes the registered providers keyed by name."""
        mock = MockProvider()
        ProviderRegistry._providers = {"mock": mock}

        all_providers = ProviderRegistry.all()

        assert "mock" in all_providers
        assert all_providers["mock"] is mock

    def test_by_capability(self):
        """``by_capability`` includes only providers declaring that capability."""
        mock = MockProvider()
        ProviderRegistry._providers = {"mock": mock}

        chat_providers = ProviderRegistry.by_capability(Capability.CHAT)
        assert mock in chat_providers

        # MockProvider declares CHAT and IMAGE only, so VIDEO must exclude it.
        video_providers = ProviderRegistry.by_capability(Capability.VIDEO)
        assert mock not in video_providers
class TestBaseProvider:
    """Tests for the helper methods MockProvider inherits from BaseProvider."""

    def test_build_headers(self):
        """An API key becomes a Bearer token and extra headers are included."""
        built = MockProvider()._build_headers(
            {
                "api_key": "sk-test-key",
                "extra_headers": {"X-Custom": "value"},
            }
        )

        assert built["Authorization"] == "Bearer sk-test-key"
        assert built["X-Custom"] == "value"

    def test_build_headers_no_key(self):
        """Without an API key no Authorization header is produced."""
        built = MockProvider()._build_headers({})

        assert "Authorization" not in built
        assert built["Content-Type"] == "application/json"

    def test_base_url(self):
        """The configured base URL comes back without its trailing slash."""
        stub = MockProvider()

        resolved = stub._base_url({"api_base": "https://api.example.com/v1/"})

        assert resolved == "https://api.example.com/v1"

    def test_base_url_empty(self):
        """A plan without ``api_base`` resolves to the empty string."""
        stub = MockProvider()

        resolved = stub._base_url({})

        assert resolved == ""
class TestMockProvider:
    """Functional tests for the MockProvider test double."""

    @pytest.mark.asyncio
    async def test_chat_stream(self):
        """Streaming chat yields a single chunk containing the greeting."""
        stub = MockProvider()
        credentials = {"api_key": "sk-test"}

        received = [
            piece
            async for piece in stub.chat([], "gpt-4", credentials, stream=True)
        ]

        assert len(received) == 1
        assert "hello" in received[0]

    @pytest.mark.asyncio
    async def test_chat_non_stream(self):
        """Non-streaming chat yields exactly one chunk."""
        stub = MockProvider()
        credentials = {"api_key": "sk-test"}

        received = [
            piece
            async for piece in stub.chat([], "gpt-4", credentials, stream=False)
        ]

        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_generate_image(self):
        """Image generation returns a URL that embeds the prompt text."""
        stub = MockProvider()

        image = await stub.generate_image("a cat", {"api_key": "sk-test"})

        assert "url" in image
        assert "cat" in image["url"]

    @pytest.mark.asyncio
    async def test_query_quota(self):
        """Quota lookup returns the fixed usage figures and unit."""
        stub = MockProvider()

        quota = await stub.query_quota({"api_key": "sk-test"})

        assert quota.quota_used == 50
        assert quota.quota_total == 100
        assert quota.quota_remaining == 50
        assert quota.unit == "tokens"
class TestOpenAIProvider:
    """Tests for the OpenAI provider's header construction."""

    @pytest.mark.asyncio
    async def test_chat_requires_api_key(self):
        """Building headers without an API key must not produce a usable token."""
        from app.providers.openai_provider import OpenAIProvider

        openai = OpenAIProvider()

        # A missing API key should not raise; the Authorization header is
        # expected to be either absent or an empty "Bearer " value.
        built = openai._build_headers({"api_base": "https://api.openai.com/v1"})

        assert "Authorization" not in built or built["Authorization"] == "Bearer "

    @pytest.mark.asyncio
    async def test_build_headers_with_extra(self):
        """Extra headers from the plan are sent alongside the Bearer token."""
        from app.providers.openai_provider import OpenAIProvider

        openai = OpenAIProvider()
        built = openai._build_headers(
            {
                "api_key": "sk-test",
                "extra_headers": {"OpenAI-Organization": "org-123"},
            }
        )

        assert built["Authorization"] == "Bearer sk-test"
        assert built["OpenAI-Organization"] == "org-123"
class TestKimiProvider:
    """Tests for the Kimi provider."""

    @pytest.mark.asyncio
    async def test_provider_metadata(self):
        """The provider reports its name, display name and chat capability."""
        from app.providers.kimi import KimiProvider

        kimi = KimiProvider()

        assert kimi.name == "kimi"
        assert kimi.display_name == "Kimi (Moonshot)"
        assert Capability.CHAT in kimi.capabilities
class TestMiniMaxProvider:
    """Tests for the MiniMax provider."""

    @pytest.mark.asyncio
    async def test_provider_metadata(self):
        """The provider reports its name, display name and chat capability."""
        from app.providers.minimax import MiniMaxProvider

        minimax = MiniMaxProvider()

        assert minimax.name == "minimax"
        assert minimax.display_name == "MiniMax"
        assert Capability.CHAT in minimax.capabilities
class TestGoogleProvider:
    """Tests for the Google provider."""

    @pytest.mark.asyncio
    async def test_provider_metadata(self):
        """The provider reports its name, display name and chat capability."""
        from app.providers.google import GoogleProvider

        google = GoogleProvider()

        assert google.name == "google"
        assert google.display_name == "Google Gemini"
        assert Capability.CHAT in google.capabilities
class TestZhipuProvider:
    """Tests for the Zhipu provider."""

    @pytest.mark.asyncio
    async def test_provider_metadata(self):
        """The provider reports its name, a recognisable display name and chat capability."""
        from app.providers.zhipu import ZhipuProvider

        zhipu = ZhipuProvider()

        assert zhipu.name == "zhipu"
        # Either the "GLM" tag or the Chinese brand name is accepted.
        assert any(tag in zhipu.display_name for tag in ("GLM", "智谱"))
        assert Capability.CHAT in zhipu.capabilities