Files
SnapAndAnaly/backend/app/services/provider_test.py
T
congsh 5c028d7952 Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 15:45:50 +08:00

208 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Provider 连通性测试:OCR / 视觉 AI。"""
from __future__ import annotations
import asyncio
import base64
import time
from typing import Any
import httpx
from app.providers import build_ocr_provider
from app.schemas.common import ProviderConfig
class ProviderTestError(Exception):
    """Raised when a provider test fails; carries a user-readable message."""
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Check connectivity of the OCR or VLM provider described by ``cfg``.

    Args:
        key: Configuration key selecting the probe; compared against
            ``KEY_OCR`` and ``KEY_VLM``.
        cfg: Provider configuration to test.

    Returns:
        A dict with the keys ``ok``, ``message``, ``detail`` and
        ``latency_ms``. Failures are reported through ``ok=False`` instead
        of being raised: a ``ProviderTestError`` contributes its text as the
        message (no detail), any other ``Exception`` is reported with its
        ``repr`` as the detail.
    """
    t0 = time.perf_counter()
    try:
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")
        if key == KEY_OCR:
            probe = _test_ocr
        elif key == KEY_VLM:
            probe = _test_vlm
        else:
            raise ProviderTestError(f"未知配置键: {key}")
        message, detail = await probe(cfg)
    except ProviderTestError as err:
        # Expected, user-facing failure: the exception text is the message.
        succeeded, message, detail = False, str(err), None
    except Exception as err:  # noqa: BLE001
        # Boundary handler: anything unexpected is still returned as a result.
        succeeded, message, detail = False, f"测试失败: {err}", repr(err)
    else:
        succeeded = True

    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    return {
        "ok": succeeded,
        "message": message,
        "detail": detail,
        "latency_ms": elapsed_ms,
    }
# Configuration keys that select which provider test_provider_config probes.
KEY_OCR = "ocr_provider"
KEY_VLM = "vlm_provider"
# Base64-encoded 1x1 PNG (described as a white image in the original note),
# intended as the probe picture for HTTP OCR / vision tests.
_TINY_PNG_B64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch to the OCR probe that matches ``cfg.type``.

    Returns:
        The ``(message, detail)`` pair produced by the selected probe.

    Raises:
        ProviderTestError: If ``cfg.type`` is not a supported OCR type.
    """
    kind = cfg.type
    # Probes that only take the config; "vision" is handled separately below
    # because it reuses the OpenAI-compatible probe with its own label.
    simple_probes = (
        ("tesseract", _test_tesseract),
        ("paddleocr", _test_paddle),
        ("http", _test_http_ocr),
    )
    for name, probe in simple_probes:
        if kind == name:
            return await probe(cfg)
    if kind == "vision":
        return await _test_openai_compat(cfg, label="视觉 OCR")
    raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Probe a vision-language provider through the OpenAI-compatible test.

    Raises:
        ProviderTestError: If ``cfg.type`` is not one of the supported types.
    """
    supported_types = (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    )
    if cfg.type not in supported_types:
        raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
    return await _test_openai_compat(cfg, label="视觉 AI")
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that the Tesseract OCR provider can be built and the binary runs.

    Args:
        cfg: Provider configuration. ``cfg.extra["cmd"]`` (optional) overrides
            the tesseract executable path; ``cfg.extra["lang"]`` is only echoed
            in the returned detail (default ``chi_sim+eng``).

    Returns:
        ``(message, detail)`` where the message contains the detected
        Tesseract version and the detail contains the language setting.

    Raises:
        ProviderTestError: If the provider cannot be constructed or the
            ``pytesseract`` package is not installed.
    """
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _check() -> str:
        # Report a missing package with an actionable hint, consistent with
        # the PaddleOCR probe, instead of leaking a raw ImportError.
        try:
            import pytesseract
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 pytesseract,请执行: pip install pytesseract"
            ) from exc
        if cfg.extra.get("cmd"):
            # Note: this sets a module-level attribute of pytesseract.
            pytesseract.pytesseract.tesseract_cmd = cfg.extra["cmd"]
        return str(pytesseract.get_tesseract_version())

    # The version query is a blocking call, so run it off the event loop.
    version = await asyncio.to_thread(_check)
    return f"Tesseract 可用,版本 {version}", f"lang={cfg.extra.get('lang', 'chi_sim+eng')}"
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that PaddleOCR is importable and its provider can be built.

    Returns:
        ``(message, detail)``; the detail is the note produced by the import
        check.

    Raises:
        ProviderTestError: If ``paddleocr`` is not installed or the provider
            cannot be constructed.
    """

    def _probe_import() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as err:
            raise ProviderTestError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from err
        return "PaddleOCR 模块已安装"

    # Run the import in a worker thread rather than on the event loop.
    import_note = await asyncio.to_thread(_probe_import)
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", import_note
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Send a tiny PNG through the HTTP OCR provider to check reachability.

    Returns:
        ``(message, detail)`` where the detail previews up to 80 characters
        of the recognized text, or a placeholder when the reply is empty.

    Raises:
        ProviderTestError: If no base URL is configured or the provider
            cannot be constructed.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    import tempfile
    from pathlib import Path

    # The provider is called with a file path, so persist the probe image to
    # a temporary file that outlives the ``with`` block (delete=False).
    png_bytes = base64.b64decode(_TINY_PNG_B64)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(png_bytes)
        image_path = Path(handle.name)

    try:
        text = await provider.recognize(image_path)
    finally:
        # Best-effort cleanup: a failed delete must not mask the test result.
        try:
            image_path.unlink(missing_ok=True)
        except OSError:
            pass

    snippet = (text or "").strip()[:80]
    if not snippet:
        snippet = "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {snippet}"
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
    """Probe an OpenAI-compatible endpoint: ``/models`` first, then a tiny chat.

    Args:
        cfg: Provider configuration. Falls back to
            ``http://localhost:11434/v1`` when ``base_url`` is empty and to
            ``gpt-4o-mini`` when ``model`` is empty; ``cfg.extra["timeout"]``
            sets the HTTP timeout in seconds (default 30).
        label: Human-readable provider name used as the message prefix.

    Returns:
        ``(message, detail)``. A 200 reply from ``/models`` counts as success
        even when the configured model is not listed; the detail then names
        up to eight available models.

    Raises:
        ProviderTestError: If the chat probe returns an error status, the
            server cannot be reached, or the chat reply is not in the
            expected ``choices[0].message.content`` shape.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    async with httpx.AsyncClient(timeout=timeout) as client:
        # 1) Try GET /models (Ollama and other OpenAI-compatible servers).
        models_url = f"{base_url}/models"
        try:
            resp = await client.get(models_url, headers=headers)
            if resp.status_code == 200:
                ids = _extract_model_ids(resp.json())
                if model and ids and model not in ids:
                    return (
                        f"{label} 服务可达",
                        f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
                    )
                return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"
        except (httpx.HTTPError, ValueError):
            # ValueError covers a 200 reply whose body is not JSON (e.g. an
            # HTML page from a proxy); fall through to the chat probe instead
            # of aborting the whole test.
            pass

        # 2) Minimal chat completion as a liveness probe.
        chat_url = f"{base_url}/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "请只回复 OK"}],
            "max_tokens": 16,
            "temperature": 0,
        }
        try:
            resp = await client.post(chat_url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            content = data["choices"][0]["message"]["content"]
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:200]
            raise ProviderTestError(
                f"API 返回 {exc.response.status_code}: {body}"
            ) from exc
        except httpx.HTTPError as exc:
            raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # The server answered, but not with an OpenAI-style chat payload;
            # report that explicitly rather than surfacing a bare KeyError.
            raise ProviderTestError(
                f"{label} 响应格式异常: {exc!r}"
            ) from exc
        return f"{label} 对话成功", f"模型 {model} 回复: {str(content).strip()[:60]}"
def _extract_model_ids(data: Any) -> list[str]:
"""从 /models 响应中提取 model id 列表。"""
if not isinstance(data, dict):
return []
items = data.get("data") or data.get("models") or []
ids: list[str] = []
if isinstance(items, list):
for item in items:
if isinstance(item, dict):
mid = item.get("id") or item.get("name") or item.get("model")
if mid:
ids.append(str(mid))
elif isinstance(item, str):
ids.append(item)
return ids
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
    """Return a copy of ``cfg`` with a blank API key filled from saved settings.

    Args:
        cfg: Configuration submitted for testing.
        existing: Previously saved settings; consulted only when it is a dict
            and ``cfg`` has no API key.

    Returns:
        A new ``ProviderConfig`` built from ``cfg``'s dumped fields, with
        ``api_key`` taken from ``existing`` when the submitted one is empty.
    """
    fields = cfg.model_dump()
    key_missing = not fields.get("api_key")
    if key_missing and isinstance(existing, dict):
        fields["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**fields)