"""Connectivity tests for OCR and vision-AI (VLM) providers."""

from __future__ import annotations

import asyncio
import base64
import tempfile
import time
from pathlib import Path
from typing import Any

import httpx

from app.providers import build_ocr_provider
from app.schemas.common import ProviderConfig

# Configuration keys accepted by ``test_provider_config``.
KEY_OCR = "ocr_provider"
KEY_VLM = "vlm_provider"

# 1x1 white image, used for the HTTP OCR / vision tests.
_TINY_PNG_B64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)


class ProviderTestError(Exception):
    """Test failure carrying a user-readable message."""


async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Test connectivity of an OCR or VLM provider.

    Args:
        key: Which provider slot is being tested (``KEY_OCR`` or ``KEY_VLM``).
        cfg: The provider configuration to test.

    Returns:
        A dict with the keys ``ok``, ``message``, ``detail`` and
        ``latency_ms``. Failures are reported through ``ok=False`` rather
        than raised: a ``ProviderTestError`` yields its message verbatim,
        any other exception yields a generic message plus ``repr(exc)`` as
        the detail.
    """
    started = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - started) * 1000)

    try:
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")
        if key == KEY_OCR:
            message, detail = await _test_ocr(cfg)
        elif key == KEY_VLM:
            message, detail = await _test_vlm(cfg)
        else:
            raise ProviderTestError(f"未知配置键: {key}")
    except ProviderTestError as exc:
        return {
            "ok": False,
            "message": str(exc),
            "detail": None,
            "latency_ms": _elapsed_ms(),
        }
    except Exception as exc:  # noqa: BLE001
        # Top-level boundary: any unexpected error becomes a failed result.
        return {
            "ok": False,
            "message": f"测试失败: {exc}",
            "detail": repr(exc),
            "latency_ms": _elapsed_ms(),
        }
    return {
        "ok": True,
        "message": message,
        "detail": detail,
        "latency_ms": _elapsed_ms(),
    }


async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch an OCR test by ``cfg.type``; raise on unsupported types."""
    if cfg.type == "tesseract":
        return await _test_tesseract(cfg)
    if cfg.type == "paddleocr":
        return await _test_paddle(cfg)
    if cfg.type == "http":
        return await _test_http_ocr(cfg)
    if cfg.type == "vision":
        return await _test_openai_compat(cfg, label="视觉 OCR")
    raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")


async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch a VLM test; every supported type uses the OpenAI-style probe."""
    if cfg.type in (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    ):
        return await _test_openai_compat(cfg, label="视觉 AI")
    raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")


async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that a Tesseract provider can be built and report its version."""
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _check() -> str:
        import pytesseract

        if cfg.extra.get("cmd"):
            # NOTE: this assigns pytesseract's module-level command path.
            pytesseract.pytesseract.tesseract_cmd = cfg.extra["cmd"]
        version = pytesseract.get_tesseract_version()
        return str(version)

    # Run the version lookup in a worker thread to keep the event loop free.
    version = await asyncio.to_thread(_check)
    return (
        f"Tesseract 可用,版本 {version}",
        f"lang={cfg.extra.get('lang', 'chi_sim+eng')}",
    )


async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that ``paddleocr`` is importable and a provider can be built."""

    def _check() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc
        return "PaddleOCR 模块已安装"

    detail = await asyncio.to_thread(_check)
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", detail


async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Send a tiny PNG through the HTTP OCR provider and preview the reply.

    Raises:
        ProviderTestError: If no base URL is configured or the provider
            cannot be built.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    # Write the tiny PNG to a temporary file, then run recognition on it.
    # The path is bound before writing so the file is removed even when the
    # write itself fails.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            tmp.write(base64.b64decode(_TINY_PNG_B64))
        text = await provider.recognize(tmp_path)
    finally:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not fail the test.
            pass

    preview = (text or "").strip()[:80] or "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {preview}"


async def _test_openai_compat(
    cfg: ProviderConfig, *, label: str
) -> tuple[str, str | None]:
    """Probe an OpenAI-style endpoint: ``/models`` first, then a minimal chat.

    Args:
        cfg: Provider configuration. Missing ``base_url`` and ``model`` fall
            back to ``http://localhost:11434/v1`` and ``gpt-4o-mini``;
            ``extra["timeout"]`` defaults to 30.
        label: Human-readable provider label used in the result messages.

    Returns:
        A ``(message, detail)`` tuple describing the successful probe.

    Raises:
        ProviderTestError: If the chat probe fails or returns an unusable
            response.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))

    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    async with httpx.AsyncClient(timeout=timeout) as client:
        # 1) Try /models (Ollama, OpenAI-compatible).
        result = await _probe_models(
            client, f"{base_url}/models", headers, model=model, label=label
        )
        if result is not None:
            return result
        # 2) Minimal chat liveness probe.
        return await _probe_chat(
            client, base_url, headers, model=model, label=label
        )


async def _probe_models(
    client: httpx.AsyncClient,
    models_url: str,
    headers: dict[str, str],
    *,
    model: str,
    label: str,
) -> tuple[str, str | None] | None:
    """Query ``/models``; return a result, or ``None`` when inconclusive."""
    try:
        resp = await client.get(models_url, headers=headers)
        if resp.status_code != 200:
            return None
        data = resp.json()
    except (httpx.HTTPError, ValueError):
        # Transport errors and non-JSON 200 bodies (JSON decode errors are
        # ValueError subclasses) both fall through to the chat probe.
        return None

    ids = _extract_model_ids(data)
    if model and ids and model not in ids:
        return (
            f"{label} 服务可达",
            f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
        )
    return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"


async def _probe_chat(
    client: httpx.AsyncClient,
    base_url: str,
    headers: dict[str, str],
    *,
    model: str,
    label: str,
) -> tuple[str, str | None]:
    """Send a minimal chat completion; raise ``ProviderTestError`` on failure."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": "请只回复 OK"}],
        "max_tokens": 16,
        "temperature": 0,
    }
    try:
        resp = await client.post(
            f"{base_url}/chat/completions", json=payload, headers=headers
        )
        resp.raise_for_status()
    except httpx.HTTPStatusError as exc:
        body = exc.response.text[:200]
        raise ProviderTestError(
            f"API 返回 {exc.response.status_code}: {body}"
        ) from exc
    except httpx.HTTPError as exc:
        raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc

    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, LookupError, TypeError) as exc:
        # A 2xx reply that is not JSON or lacks the expected structure gets
        # a readable error with a preview of the body.
        raise ProviderTestError(
            f"{label} 响应格式异常: {resp.text[:200]}"
        ) from exc
    return (
        f"{label} 对话成功",
        f"模型 {model} 回复: {str(content).strip()[:60]}",
    )


def _extract_model_ids(data: Any) -> list[str]:
    """Extract the list of model ids from a ``/models`` response.

    Reads the list under ``data`` or ``models``; each entry may be a string
    or a dict carrying ``id``, ``name`` or ``model``. Anything else is
    ignored, and a non-dict payload yields an empty list.
    """
    if not isinstance(data, dict):
        return []
    items = data.get("data") or data.get("models") or []
    if not isinstance(items, list):
        return []

    ids: list[str] = []
    for item in items:
        if isinstance(item, dict):
            mid = item.get("id") or item.get("name") or item.get("model")
            if mid:
                ids.append(str(mid))
        elif isinstance(item, str):
            ids.append(item)
    return ids


def merge_provider_api_key(
    cfg: ProviderConfig, existing: dict | None
) -> ProviderConfig:
    """Fill in the saved API key when the config under test has none.

    Args:
        cfg: The configuration submitted for testing.
        existing: The previously saved configuration as a dict, if any.

    Returns:
        A new ``ProviderConfig`` built from ``cfg``; when ``cfg`` has an
        empty ``api_key`` and ``existing`` is a dict, its ``api_key`` is used.
    """
    payload = cfg.model_dump()
    if (not payload.get("api_key")) and isinstance(existing, dict):
        payload["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**payload)