5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
208 lines
7.4 KiB
Python
208 lines
7.4 KiB
Python
"""Provider 连通性测试:OCR / 视觉 AI。"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import time
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.providers import build_ocr_provider
|
||
from app.schemas.common import ProviderConfig
|
||
|
||
|
||
class ProviderTestError(Exception):
    """Signals a failed provider connectivity test.

    The exception message is intended to be shown to the user unchanged.
    """
|
||
|
||
|
||
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Run the connectivity test for an OCR or VLM provider configuration.

    Args:
        key: which configuration is being tested; expected to be ``KEY_OCR``
            or ``KEY_VLM``.
        cfg: the provider configuration to probe.

    Returns:
        A dict with keys ``ok``, ``message``, ``detail`` and ``latency_ms``.
        Ordinary exceptions are converted into a result rather than raised:
        a ``ProviderTestError`` yields its own message with ``detail=None``;
        any other ``Exception`` yields a generic message with ``repr(exc)``
        as ``detail``. ``latency_ms`` is the wall time since the call began.
    """
    t_start = time.perf_counter()

    def elapsed_ms() -> int:
        return int((time.perf_counter() - t_start) * 1000)

    try:
        # A provider type of "not used" has nothing to test.
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")
        runner = {KEY_OCR: _test_ocr, KEY_VLM: _test_vlm}.get(key)
        if runner is None:
            raise ProviderTestError(f"未知配置键: {key}")
        message, detail = await runner(cfg)
    except ProviderTestError as err:
        return {"ok": False, "message": str(err), "detail": None, "latency_ms": elapsed_ms()}
    except Exception as err:  # noqa: BLE001
        return {
            "ok": False,
            "message": f"测试失败: {err}",
            "detail": repr(err),
            "latency_ms": elapsed_ms(),
        }
    else:
        return {"ok": True, "message": message, "detail": detail, "latency_ms": elapsed_ms()}
|
||
|
||
|
||
# Configuration keys accepted by test_provider_config().
KEY_OCR = "ocr_provider"
KEY_VLM = "vlm_provider"

# Base64-encoded 1x1 PNG, written to a temp file and sent by the HTTP OCR
# test. (The original comment describes it as a single white pixel.)
_TINY_PNG_B64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
    "AAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)
|
||
|
||
|
||
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch the OCR connectivity test that matches ``cfg.type``.

    Returns the ``(message, detail)`` pair from the selected check.

    Raises:
        ProviderTestError: if ``cfg.type`` is not a supported OCR type.
    """
    handlers = {
        "tesseract": _test_tesseract,
        "paddleocr": _test_paddle,
        "http": _test_http_ocr,
    }
    handler = handlers.get(cfg.type)
    if handler is not None:
        return await handler(cfg)
    if cfg.type == "vision":
        # Vision-model OCR is probed through the OpenAI-compatible check.
        return await _test_openai_compat(cfg, label="视觉 OCR")
    raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
|
||
|
||
|
||
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Run the vision-AI connectivity test.

    All supported VLM types are probed through the OpenAI-compatible check.

    Raises:
        ProviderTestError: if ``cfg.type`` is not a supported VLM type.
    """
    openai_like = {"openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision"}
    if cfg.type not in openai_like:
        raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
    return await _test_openai_compat(cfg, label="视觉 AI")
|
||
|
||
|
||
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that Tesseract can be invoked through pytesseract.

    Builds the OCR provider to validate the config. Then, in a worker thread,
    it imports pytesseract and queries the Tesseract version. When
    ``cfg.extra["cmd"]`` is set, it is used as the Tesseract executable for
    this check only. The previous process-wide ``tesseract_cmd`` is restored
    afterwards, so testing a (possibly unsaved) config does not change the
    command used by other OCR work in the same process.

    Returns:
        ``(message, detail)``; ``detail`` reports the configured ``lang``
        (``chi_sim+eng`` when unset).

    Raises:
        ProviderTestError: if the provider cannot be built, pytesseract is not
            installed, or the Tesseract executable cannot be found.
    """
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _check() -> str:
        try:
            import pytesseract
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 pytesseract,请执行: pip install pytesseract"
            ) from exc

        module = pytesseract.pytesseract
        previous_cmd = module.tesseract_cmd
        custom_cmd = cfg.extra.get("cmd")
        if custom_cmd:
            module.tesseract_cmd = custom_cmd
        try:
            # NOTE(review): some pytesseract versions cache the result of
            # get_tesseract_version() after the first success, so a changed
            # cmd may not be re-verified — confirm against the installed version.
            return str(pytesseract.get_tesseract_version())
        except pytesseract.TesseractNotFoundError as exc:
            raise ProviderTestError(
                f"未找到 Tesseract 可执行文件: {module.tesseract_cmd}"
            ) from exc
        finally:
            # Undo the temporary override so the test has no lasting side effect.
            module.tesseract_cmd = previous_cmd

    version = await asyncio.to_thread(_check)
    return f"Tesseract 可用,版本 {version}", f"lang={cfg.extra.get('lang', 'chi_sim+eng')}"
|
||
|
||
|
||
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that PaddleOCR is importable and that a provider can be built.

    The import runs in a worker thread. After it succeeds, the provider is
    constructed to validate ``cfg``.

    Returns:
        ``("PaddleOCR 可用", <install note>)``.

    Raises:
        ProviderTestError: if PaddleOCR is not installed (with a pip hint) or
            the provider cannot be constructed.
    """

    def _probe_module() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as missing:
            raise ProviderTestError("未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle") from missing
        else:
            return "PaddleOCR 模块已安装"

    install_note = await asyncio.to_thread(_probe_module)
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", install_note
|
||
|
||
|
||
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Send a tiny PNG through the configured HTTP OCR provider.

    The embedded 1x1 PNG is written to a temporary ``.png`` file, and the
    provider's ``recognize`` is called on that path. The file is always
    removed afterwards, including when writing it fails.

    Returns:
        ``("HTTP OCR 接口可达", "响应预览: ...")``. The preview holds the first
        80 characters of the stripped response, or a placeholder when it is
        empty.

    Raises:
        ProviderTestError: if no base URL is configured or the provider cannot
            be constructed.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    import tempfile
    from pathlib import Path

    # delete=False: the file is closed before recognize() opens it by path,
    # so removal is done manually below.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    # Bind the path before writing, so the finally clause also cleans up
    # when the write itself fails.
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            tmp.write(base64.b64decode(_TINY_PNG_B64))
        text = await provider.recognize(tmp_path)
    finally:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not mask the result.
            pass

    preview = (text or "").strip()[:80] or "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {preview}"
|
||
|
||
|
||
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
    """Probe an OpenAI-compatible endpoint (also used for Ollama and similar).

    Two probes share one HTTP client:

    1. ``GET {base_url}/models`` — best-effort. On HTTP 200 with a JSON body,
       the service is reported reachable, with a hint if ``model`` is not
       among the listed ids. Transport errors, non-200 statuses and non-JSON
       bodies fall through to step 2.
    2. ``POST {base_url}/chat/completions`` with a minimal prompt.

    Args:
        cfg: provider config. ``base_url`` defaults to
            ``http://localhost:11434/v1``, ``model`` to ``gpt-4o-mini`` and
            ``extra["timeout"]`` (seconds) to 30. ``api_key``, when set, is
            sent as a Bearer token.
        label: display name used as the prefix of the returned message.

    Returns:
        ``(message, detail)`` suitable for display.

    Raises:
        ProviderTestError: if the chat probe returns an error status, cannot
            connect, or returns a body without ``choices[0].message.content``.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))

    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    async with httpx.AsyncClient(timeout=timeout) as client:
        # 1) Try /models (Ollama, OpenAI-compatible).
        models_url = f"{base_url}/models"
        try:
            resp = await client.get(models_url, headers=headers)
            if resp.status_code == 200:
                ids = _extract_model_ids(resp.json())
                if model and ids and model not in ids:
                    return (
                        f"{label} 服务可达",
                        f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
                    )
                return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"
        except (httpx.HTTPError, ValueError):
            # Best-effort probe. Network failures and non-JSON 200 bodies
            # (e.g. an HTML page at a wrong base URL) fall through to chat.
            pass

        # 2) Minimal chat liveness check.
        chat_url = f"{base_url}/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "请只回复 OK"}],
            "max_tokens": 16,
            "temperature": 0,
        }
        try:
            resp = await client.post(chat_url, json=payload, headers=headers)
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:200]
            raise ProviderTestError(
                f"API 返回 {exc.response.status_code}: {body}"
            ) from exc
        except httpx.HTTPError as exc:
            raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc

        try:
            content = resp.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # A 2xx reply that is not a chat-completion payload.
            raise ProviderTestError(f"API 响应格式异常: {resp.text[:200]}") from exc

    reply = "" if content is None else str(content)
    return f"{label} 对话成功", f"模型 {model} 回复: {reply.strip()[:60]}"
|
||
|
||
|
||
def _extract_model_ids(data: Any) -> list[str]:
    """Collect model identifiers from a ``/models`` response body.

    The list is read from ``data["data"]``, or from ``data["models"]`` when
    the former is falsy. String entries are taken as-is. For dict entries,
    the first truthy value of ``id``, ``name`` or ``model`` is used,
    converted with ``str``. Anything else is skipped. A non-dict ``data``, or
    a non-list entries value, yields an empty list.
    """
    if not isinstance(data, dict):
        return []
    entries = data.get("data") or data.get("models") or []
    if not isinstance(entries, list):
        return []

    found: list[str] = []
    for entry in entries:
        if isinstance(entry, str):
            found.append(entry)
            continue
        if not isinstance(entry, dict):
            continue
        ident = entry.get("id") or entry.get("name") or entry.get("model")
        if ident:
            found.append(str(ident))
    return found
|
||
|
||
|
||
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
    """Return a copy of ``cfg`` whose empty ``api_key`` is filled from ``existing``.

    If ``cfg`` already has a truthy ``api_key``, or ``existing`` is not a dict,
    the config is rebuilt unchanged. Otherwise ``existing["api_key"]`` (or
    ``""`` if absent) is used, so a test can reuse a key that was saved
    earlier.
    """
    fields = cfg.model_dump()
    if fields.get("api_key") or not isinstance(existing, dict):
        return ProviderConfig(**fields)
    fields["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**fields)
|