5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
82 lines
2.9 KiB
Python
82 lines
2.9 KiB
Python
"""Provider 工厂,按设置中的 type 字段实例化。"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
from app.schemas.common import ProviderConfig
|
|
|
|
from .base import OCRProvider, VLMProvider
|
|
from .ocr_http import HttpOCR
|
|
from .ocr_paddle import PaddleOCRProvider
|
|
from .ocr_tesseract import TesseractOCR
|
|
from .ocr_vision import VisionOCR
|
|
from .vlm_openai import OpenAICompatVLM
|
|
|
|
# Known values for an OCR provider's ``type`` setting.
OCR_TYPES = (
    "tesseract",
    "paddleocr",
    "http",
    "vision",
    "none",
)

# Known values for a VLM provider's ``type`` setting.
# NOTE(review): build_vlm_provider also accepts aliases such as "openai",
# "ollama", "glm", "minimax", "moonshot" and "vision" that are not listed here.
VLM_TYPES = (
    "openai_compat",
    "none",
)

# Recognition modes; presumably OCR-only, vision-model-only, or a combination
# of both -- confirm against the callers that read this tuple.
RECOGNITION_MODES = (
    "ocr",
    "vision",
    "hybrid",
)
|
|
|
|
|
|
def build_ocr_provider(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[OCRProvider]:
    """Build a traditional or vision-based OCR provider from its config.

    Args:
        cfg: Provider settings. ``None``, a blank type, ``"none"`` or
            ``"disabled"`` all mean "OCR switched off".
        allow_upload: Forwarded to ``build_vision_ocr`` for the ``"vision"``
            type; not used by the other provider types.

    Returns:
        A configured OCR provider, or ``None`` when OCR is disabled.

    Raises:
        ValueError: if ``cfg.type`` is not a known provider type, or the
            ``"http"`` type has no ``base_url``.

    ``cfg.type`` is matched case-insensitively after trimming whitespace.
    Blank values (``None`` / ``""``) in ``cfg.extra`` fall back to the
    defaults instead of being passed through, which previously crashed
    ``float(None)`` or produced the literal text path ``"None"``.
    """
    if cfg is None:
        return None

    # Normalize so "Tesseract", " http " or a missing type behave predictably.
    kind = (cfg.type or "").strip().lower()
    if kind in ("", "none", "disabled"):
        return None

    # Tolerate a missing ``extra`` mapping; every option below has a default.
    extra = cfg.extra or {}

    if kind == "tesseract":
        return TesseractOCR(
            lang=extra.get("lang") or "chi_sim+eng",
            # An empty string is treated as "not configured".
            cmd=extra.get("cmd") or None,
        )

    if kind == "paddleocr":
        return PaddleOCRProvider(lang=extra.get("lang") or "ch")

    if kind == "http":
        if not cfg.base_url:
            raise ValueError("HTTP OCR 需要配置 base_url")
        headers = extra.get("headers")
        raw_timeout = extra.get("timeout")
        return HttpOCR(
            base_url=cfg.base_url,
            api_key=cfg.api_key or "",
            text_path=str(extra.get("text_path") or "text"),
            # Only a mapping is a usable header set; anything else is ignored.
            headers=headers if isinstance(headers, dict) else None,
            timeout=30.0 if raw_timeout in (None, "") else float(raw_timeout),
        )

    if kind == "vision":
        return build_vision_ocr(cfg, allow_upload=allow_upload)

    raise ValueError(f"未知 OCR Provider 类型: {cfg.type}")
|
|
|
|
|
|
def build_vision_ocr(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[VisionOCR]:
    """Build a vision-model OCR provider from a ``ProviderConfig``.

    The config may be the same one used for the VLM provider. ``cfg.type``
    is only inspected to decide whether the provider is disabled.

    Args:
        cfg: Provider settings. ``None``, a blank type, ``"none"`` or
            ``"disabled"`` (case-insensitive, whitespace-trimmed) yield ``None``.
        allow_upload: Passed straight through to ``VisionOCR``.

    Returns:
        A ``VisionOCR`` instance, or ``None`` when disabled.

    Defaults used when the config leaves a field blank: base URL
    ``http://localhost:11434/v1`` (presumably a local Ollama server),
    model ``qwen2.5vl:7b``, timeout 60 seconds.
    """
    if cfg is None:
        return None
    # Normalize so " None " or a missing type also count as disabled.
    if (cfg.type or "").strip().lower() in ("", "none", "disabled"):
        return None

    # Tolerate a missing ``extra`` mapping and a blank timeout value.
    extra = cfg.extra or {}
    raw_timeout = extra.get("timeout")
    timeout = 60.0 if raw_timeout in (None, "") else float(raw_timeout)

    return VisionOCR(
        base_url=cfg.base_url or "http://localhost:11434/v1",
        api_key=cfg.api_key or "",
        model=cfg.model or "qwen2.5vl:7b",
        timeout=timeout,
        allow_upload=allow_upload,
    )
|
|
|
|
|
|
def build_vlm_provider(cfg: ProviderConfig | None) -> Optional[VLMProvider]:
    """Build the VLM provider used for AI classification, summaries and tags.

    Args:
        cfg: Provider settings. ``None``, a blank type, ``"none"`` or
            ``"disabled"`` (case-insensitive, whitespace-trimmed) yield ``None``.

    Returns:
        An ``OpenAICompatVLM`` for any of the supported type aliases, or
        ``None`` when disabled.

    Raises:
        ValueError: if ``cfg.type`` is not one of the supported aliases.
    """
    if cfg is None:
        return None

    kind = (cfg.type or "").strip().lower()
    if kind in ("", "none", "disabled"):
        return None

    # Every one of these aliases is served by the same OpenAI-compatible client.
    if kind in ("openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision"):
        # Tolerate a missing ``extra`` mapping and a blank timeout value.
        extra = cfg.extra or {}
        raw_timeout = extra.get("timeout")
        timeout = 60.0 if raw_timeout in (None, "") else float(raw_timeout)
        # NOTE(review): the fallback base_url points at local port 11434
        # (presumably Ollama) while the fallback model is "gpt-4o-mini";
        # confirm this pairing is intended.
        return OpenAICompatVLM(
            base_url=cfg.base_url or "http://localhost:11434/v1",
            api_key=cfg.api_key or "",
            model=cfg.model or "gpt-4o-mini",
            timeout=timeout,
        )

    raise ValueError(f"未知 VLM Provider 类型: {cfg.type}")
|