"""Provider 抽象基类 + 能力枚举""" from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum from typing import Any, AsyncGenerator from pydantic import BaseModel class Capability(str, Enum): CHAT = "chat" IMAGE = "image" VOICE = "voice" VIDEO = "video" FILE = "file" EMBEDDING = "embedding" class QuotaInfo(BaseModel): """Provider 返回的额度信息""" quota_used: int = 0 quota_remaining: int = 0 quota_total: int = 0 unit: str = "tokens" raw: dict[str, Any] | None = None class BaseProvider(ABC): """ 所有平台适配器的基类。 子类需要设置 name / display_name / capabilities, 并实现对应能力的方法。 """ name: str = "" display_name: str = "" capabilities: list[Capability] = [] @abstractmethod async def chat( self, messages: list[dict], model: str, plan: dict, stream: bool = True, **kwargs, ) -> AsyncGenerator[str, None]: """ 聊天补全。返回 SSE 格式的 data 行。 每 yield 一次代表一个 SSE event 的 data 字段内容。 """ yield "" # pragma: no cover async def generate_image( self, prompt: str, plan: dict, **kwargs ) -> dict[str, Any]: raise NotImplementedError(f"{self.name} does not support image generation") async def generate_voice( self, text: str, plan: dict, **kwargs ) -> bytes: raise NotImplementedError(f"{self.name} does not support voice synthesis") async def generate_video( self, prompt: str, plan: dict, **kwargs ) -> dict[str, Any]: raise NotImplementedError(f"{self.name} does not support video generation") async def query_quota(self, plan: dict) -> QuotaInfo | None: """ 查询平台额度。返回 None 表示该平台不支持 API 查询,走本地追踪。 """ return None def _build_headers(self, plan: dict) -> dict[str, str]: """构建请求头: Authorization + extra_headers""" headers = {"Content-Type": "application/json"} api_key = plan.get("api_key", "") if api_key: headers["Authorization"] = f"Bearer {api_key}" extra = plan.get("extra_headers") or {} headers.update(extra) return headers def _base_url(self, plan: dict) -> str: return (plan.get("api_base") or "").rstrip("/")