Files
planManage/app/providers/base.py
T

88 lines
2.4 KiB
Python
Raw Normal View History

"""Provider 抽象基类 + 能力枚举"""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, AsyncGenerator
from pydantic import BaseModel
class Capability(str, Enum):
    """Capabilities a provider adapter can advertise.

    Because of the ``str`` mixin, every member is itself a ``str`` that
    compares equal to its lowercase value (``Capability.CHAT == "chat"``),
    and ``Capability("chat")`` looks a member up by value.
    """

    CHAT = "chat"
    IMAGE = "image"
    VOICE = "voice"
    VIDEO = "video"
    FILE = "file"
    EMBEDDING = "embedding"
class QuotaInfo(BaseModel):
    """Quota information returned by a provider."""

    # Counters below all default to 0 when the provider omits them.
    quota_used: int = 0
    quota_remaining: int = 0
    quota_total: int = 0
    # Unit the counters are expressed in; defaults to "tokens".
    unit: str = "tokens"
    # Optional extra payload; presumably the provider's raw response
    # body — TODO confirm against the concrete providers.
    raw: dict[str, Any] | None = None
class BaseProvider(ABC):
    """Base class for all platform adapters.

    Subclasses should set ``name`` / ``display_name`` / ``capabilities``
    and implement the methods for the capabilities they support. Only
    ``chat`` is abstract; ``generate_image`` / ``generate_voice`` /
    ``generate_video`` raise ``NotImplementedError`` unless overridden,
    and ``query_quota`` returns ``None`` unless overridden.
    """

    name: str = ""
    display_name: str = ""
    # This default list is a single object shared by every subclass that
    # does not assign its own, so subclasses should rebind the attribute
    # (``capabilities = [...]``) rather than mutate it in place.
    capabilities: list[Capability] = []

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        model: str,
        plan: dict,
        stream: bool = True,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Chat completion, streamed as SSE data lines.

        Each yielded string is the content of the ``data`` field of one
        SSE event.
        """
        # The ``yield`` makes this an async generator function, the same
        # shape subclasses are expected to implement.
        yield ""  # pragma: no cover

    async def generate_image(
        self, prompt: str, plan: dict, **kwargs
    ) -> dict[str, Any]:
        """Generate an image from ``prompt``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        # Fall back to the class name so the message still identifies the
        # provider when a subclass forgot to set ``name``.
        provider = self.name or type(self).__name__
        raise NotImplementedError(f"{provider} does not support image generation")

    async def generate_voice(
        self, text: str, plan: dict, **kwargs
    ) -> bytes:
        """Synthesize speech for ``text``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        provider = self.name or type(self).__name__
        raise NotImplementedError(f"{provider} does not support voice synthesis")

    async def generate_video(
        self, prompt: str, plan: dict, **kwargs
    ) -> dict[str, Any]:
        """Generate a video from ``prompt``.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        provider = self.name or type(self).__name__
        raise NotImplementedError(f"{provider} does not support video generation")

    async def query_quota(self, plan: dict) -> QuotaInfo | None:
        """Query the platform's quota.

        Returns:
            A ``QuotaInfo``, or ``None`` meaning the platform does not
            support querying quota via API and usage is tracked locally.
        """
        return None

    def _build_headers(self, plan: dict) -> dict[str, str]:
        """Build request headers from ``plan``.

        Starts with ``Content-Type: application/json``, adds
        ``Authorization: Bearer <api_key>`` when ``plan["api_key"]`` is
        truthy, then applies ``plan["extra_headers"]`` (a mapping or an
        iterable of key/value pairs) on top. Extra headers override the
        defaults case-insensitively, so at most one header per name is sent.
        """
        headers: dict[str, str] = {"Content-Type": "application/json"}
        api_key = plan.get("api_key", "")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        overrides = dict(plan.get("extra_headers") or {})
        if overrides:
            # HTTP field names are case-insensitive: an override such as
            # "authorization" must replace "Authorization" instead of
            # producing a second, conflicting header.
            override_names = {str(key).lower() for key in overrides}
            headers = {
                key: value
                for key, value in headers.items()
                if key.lower() not in override_names
            }
            headers.update(overrides)
        return headers

    def _base_url(self, plan: dict) -> str:
        """Return ``plan["api_base"]`` without trailing slashes ('' if unset)."""
        return (plan.get("api_base") or "").rstrip("/")