"""Provider abstract base class and capability enum."""

from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator
from pydantic import BaseModel


class Capability(str, Enum):
    """Closed set of capabilities a provider may advertise.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value, e.g. ``Capability.CHAT == "chat"``.
    """

    CHAT = "chat"
    IMAGE = "image"
    VOICE = "voice"
    VIDEO = "video"
    FILE = "file"
    EMBEDDING = "embedding"


class QuotaInfo(BaseModel):
    """Quota information returned by a provider.

    The numeric counters default to ``0``, ``unit`` to ``"tokens"`` and
    ``raw`` to ``None``, so a provider can populate only the fields it knows.
    """

    # NOTE(review): the three counters presumably share the measurement
    # named by ``unit`` — confirm against the concrete providers.
    quota_used: int = 0
    quota_remaining: int = 0
    quota_total: int = 0
    unit: str = "tokens"
    # Presumably the provider's unparsed response payload; verify per provider.
    raw: dict[str, Any] | None = None


class BaseProvider(ABC):
    """Base class for all platform adapters.

    Subclasses set ``name`` / ``display_name`` / ``capabilities`` and
    implement the methods for the capabilities they support. Only
    :meth:`chat` is abstract. The ``generate_*`` methods raise
    ``NotImplementedError`` unless overridden, and :meth:`query_quota`
    returns ``None`` by default.

    Every method receives a ``plan`` dict; the helpers defined here read
    its ``api_key``, ``extra_headers`` and ``api_base`` keys.
    """

    name: str = ""
    display_name: str = ""
    # NOTE: this default list is one object shared by every subclass that
    # does not override it. Subclasses should assign their own list
    # (``capabilities = [Capability.CHAT, ...]``) instead of mutating it.
    capabilities: list[Capability] = []

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        model: str,
        plan: dict,
        stream: bool = True,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Run a chat completion, yielding SSE-formatted data lines.

        Each yielded string is the content of the ``data`` field of one
        SSE event.
        """
        # The ``yield`` makes this stub an async generator, matching the
        # shape of the overrides.
        yield ""  # pragma: no cover

    async def generate_image(
        self, prompt: str, plan: dict, **kwargs
    ) -> dict[str, Any]:
        """Generate an image from ``prompt``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError(f"{self.name} does not support image generation")

    async def generate_voice(
        self, text: str, plan: dict, **kwargs
    ) -> bytes:
        """Synthesize speech for ``text``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError(f"{self.name} does not support voice synthesis")

    async def generate_video(
        self, prompt: str, plan: dict, **kwargs
    ) -> dict[str, Any]:
        """Generate a video from ``prompt``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError(f"{self.name} does not support video generation")

    async def query_quota(self, plan: dict) -> QuotaInfo | None:
        """Query the platform's quota.

        Returns ``None`` when the platform does not support quota lookup
        through its API, in which case usage is tracked locally instead.
        """
        return None

    def _build_headers(self, plan: dict) -> dict[str, str]:
        """Build request headers: JSON content type, Bearer auth, extra_headers.

        ``Authorization: Bearer <api_key>`` is added only when
        ``plan["api_key"]`` is non-empty. Entries from
        ``plan["extra_headers"]`` (a mapping or an iterable of key/value
        pairs) are applied last and win over the defaults. HTTP field names
        are case-insensitive, so an extra header that differs from an
        existing one only in case (e.g. ``authorization``) replaces it
        rather than producing a duplicate header.
        """
        headers: dict[str, str] = {"Content-Type": "application/json"}
        api_key = plan.get("api_key", "")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        # dict() accepts the same inputs as dict.update (mapping or pairs).
        extra = dict(plan.get("extra_headers") or {})
        for key, value in extra.items():
            folded = key.lower()
            # Drop differently-cased duplicates; an exact-case match is simply
            # overwritten below, keeping its original position.
            stale = [k for k in headers if k != key and k.lower() == folded]
            for existing in stale:
                del headers[existing]
            headers[key] = value
        return headers

    def _base_url(self, plan: dict) -> str:
        """Return ``plan["api_base"]`` without surrounding whitespace or trailing ``/``.

        A missing, ``None`` or empty ``api_base`` yields ``""``.
        """
        # Strip whitespace first: a value like " https://x/ " would otherwise
        # keep its trailing slash and space.
        return (plan.get("api_base") or "").strip().rstrip("/")