"""Provider 模块测试""" import pytest from unittest.mock import AsyncMock, MagicMock, patch from app.providers import ProviderRegistry from app.providers.base import BaseProvider, Capability, QuotaInfo class MockProvider(BaseProvider): """测试用 Mock Provider""" name = "mock" display_name = "Mock Provider" capabilities = [Capability.CHAT, Capability.IMAGE] async def chat(self, messages, model, plan, stream=True, **kwargs): if stream: yield 'data: {"choices": [{"delta": {"content": "hello"}}]}' else: yield '{"choices": [{"message": {"content": "hello"}}]}' async def generate_image(self, prompt, plan, **kwargs): return {"url": f"https://example.com/{prompt}.png"} async def query_quota(self, plan): return QuotaInfo(quota_used=50, quota_total=100, quota_remaining=50, unit="tokens") class TestProviderRegistry: """Provider 注册表测试""" def test_register_and_get(self): """测试注册和获取 Provider""" mock = MockProvider() ProviderRegistry._providers["mock"] = mock retrieved = ProviderRegistry.get("mock") assert retrieved is mock assert retrieved.name == "mock" def test_get_nonexistent(self): """测试获取不存在的 Provider""" retrieved = ProviderRegistry.get("nonexistent") assert retrieved is None def test_all_providers(self): """测试获取所有 Provider""" mock = MockProvider() ProviderRegistry._providers = {"mock": mock} all_providers = ProviderRegistry.all() assert "mock" in all_providers assert all_providers["mock"] is mock def test_by_capability(self): """测试按能力筛选 Provider""" mock = MockProvider() ProviderRegistry._providers = {"mock": mock} chat_providers = ProviderRegistry.by_capability(Capability.CHAT) assert mock in chat_providers video_providers = ProviderRegistry.by_capability(Capability.VIDEO) assert mock not in video_providers class TestBaseProvider: """BaseProvider 测试""" def test_build_headers(self): """测试构建请求头""" provider = MockProvider() plan = { "api_key": "sk-test-key", "extra_headers": {"X-Custom": "value"}, } headers = provider._build_headers(plan) assert headers["Authorization"] == "Bearer sk-test-key" assert headers["X-Custom"] == "value" def test_build_headers_no_key(self): """测试无 API Key 时的请求头""" provider = MockProvider() plan = {} headers = provider._build_headers(plan) assert "Authorization" not in headers assert headers["Content-Type"] == "application/json" def test_base_url(self): """测试获取基础 URL""" provider = MockProvider() plan = {"api_base": "https://api.example.com/v1/"} url = provider._base_url(plan) assert url == "https://api.example.com/v1" def test_base_url_empty(self): """测试空基础 URL""" provider = MockProvider() plan = {} url = provider._base_url(plan) assert url == "" class TestMockProvider: """Mock Provider 功能测试""" @pytest.mark.asyncio async def test_chat_stream(self): """测试流式聊天""" provider = MockProvider() plan = {"api_key": "sk-test"} chunks = [] async for chunk in provider.chat([], "gpt-4", plan, stream=True): chunks.append(chunk) assert len(chunks) == 1 assert "hello" in chunks[0] @pytest.mark.asyncio async def test_chat_non_stream(self): """测试非流式聊天""" provider = MockProvider() plan = {"api_key": "sk-test"} chunks = [] async for chunk in provider.chat([], "gpt-4", plan, stream=False): chunks.append(chunk) assert len(chunks) == 1 @pytest.mark.asyncio async def test_generate_image(self): """测试图片生成""" provider = MockProvider() plan = {"api_key": "sk-test"} result = await provider.generate_image("a cat", plan) assert "url" in result assert "cat" in result["url"] @pytest.mark.asyncio async def test_query_quota(self): """测试额度查询""" provider = MockProvider() plan = {"api_key": "sk-test"} info = await provider.query_quota(plan) assert info.quota_used == 50 assert info.quota_total == 100 assert info.quota_remaining == 50 assert info.unit == "tokens" class TestOpenAIProvider: """OpenAI Provider 测试""" @pytest.mark.asyncio async def test_chat_requires_api_key(self): """测试需要 API Key""" from app.providers.openai_provider import OpenAIProvider provider = OpenAIProvider() plan = {"api_base": "https://api.openai.com/v1"} # 没有 API key 不应该抛出错误,但 headers 中不应该有 Authorization headers = provider._build_headers(plan) assert "Authorization" not in headers or headers["Authorization"] == "Bearer " @pytest.mark.asyncio async def test_build_headers_with_extra(self): """测试额外的请求头""" from app.providers.openai_provider import OpenAIProvider provider = OpenAIProvider() plan = { "api_key": "sk-test", "extra_headers": {"OpenAI-Organization": "org-123"}, } headers = provider._build_headers(plan) assert headers["Authorization"] == "Bearer sk-test" assert headers["OpenAI-Organization"] == "org-123" class TestKimiProvider: """Kimi Provider 测试""" @pytest.mark.asyncio async def test_provider_metadata(self): """测试 Provider 元数据""" from app.providers.kimi import KimiProvider provider = KimiProvider() assert provider.name == "kimi" assert provider.display_name == "Kimi (Moonshot)" assert Capability.CHAT in provider.capabilities class TestMiniMaxProvider: """MiniMax Provider 测试""" @pytest.mark.asyncio async def test_provider_metadata(self): """测试 Provider 元数据""" from app.providers.minimax import MiniMaxProvider provider = MiniMaxProvider() assert provider.name == "minimax" assert provider.display_name == "MiniMax" assert Capability.CHAT in provider.capabilities class TestGoogleProvider: """Google Provider 测试""" @pytest.mark.asyncio async def test_provider_metadata(self): """测试 Provider 元数据""" from app.providers.google import GoogleProvider provider = GoogleProvider() assert provider.name == "google" assert provider.display_name == "Google Gemini" assert Capability.CHAT in provider.capabilities class TestZhipuProvider: """智谱 Provider 测试""" @pytest.mark.asyncio async def test_provider_metadata(self): """测试 Provider 元数据""" from app.providers.zhipu import ZhipuProvider provider = ZhipuProvider() assert provider.name == "zhipu" assert "GLM" in provider.display_name or "智谱" in provider.display_name assert Capability.CHAT in provider.capabilities