Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
"""Provider 工厂,按设置中的 type 字段实例化。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.common import ProviderConfig
|
||||
|
||||
from .base import OCRProvider, VLMProvider
|
||||
from .ocr_http import HttpOCR
|
||||
from .ocr_paddle import PaddleOCRProvider
|
||||
from .ocr_tesseract import TesseractOCR
|
||||
from .ocr_vision import VisionOCR
|
||||
from .vlm_openai import OpenAICompatVLM
|
||||
|
||||
# Accepted values for the OCR provider type, the VLM provider type and the
# recognition mode settings.
OCR_TYPES = ("tesseract", "paddleocr", "http", "vision", "none")
VLM_TYPES = ("openai_compat", "none")
RECOGNITION_MODES = ("ocr", "vision", "hybrid")
|
||||
|
||||
|
||||
def build_ocr_provider(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[OCRProvider]:
    """Instantiate the OCR provider selected by ``cfg.type``.

    Args:
        cfg: Provider settings; ``None`` means OCR is not configured.
        allow_upload: Forwarded to the vision OCR builder only.

    Returns:
        The provider instance, or ``None`` when ``cfg`` is missing or its
        type is ``""``, ``"none"`` or ``"disabled"``.

    Raises:
        ValueError: If the type is unknown, or the HTTP provider is selected
            without a ``base_url``.
    """
    if cfg is None:
        return None

    kind = cfg.type
    if kind in ("", "none", "disabled"):
        return None

    if kind == "vision":
        return build_vision_ocr(cfg, allow_upload=allow_upload)

    if kind == "tesseract":
        options = cfg.extra
        return TesseractOCR(
            lang=options.get("lang", "chi_sim+eng"),
            cmd=options.get("cmd"),
        )

    if kind == "paddleocr":
        return PaddleOCRProvider(lang=cfg.extra.get("lang", "ch"))

    if kind == "http":
        endpoint = cfg.base_url
        if not endpoint:
            raise ValueError("HTTP OCR 需要配置 base_url")
        options = cfg.extra
        # Only a real dict is accepted as extra headers; anything else is dropped.
        custom_headers = options.get("headers")
        if not isinstance(custom_headers, dict):
            custom_headers = None
        return HttpOCR(
            base_url=endpoint,
            api_key=cfg.api_key or "",
            text_path=str(options.get("text_path", "text")),
            headers=custom_headers,
            timeout=float(options.get("timeout", 30)),
        )

    raise ValueError(f"未知 OCR Provider 类型: {cfg.type}")
|
||||
|
||||
|
||||
def build_vision_ocr(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[VisionOCR]:
    """Create a vision-model OCR provider from a provider config.

    The same config shape is used for the VLM provider, so one endpoint
    configuration can serve both.

    Args:
        cfg: Provider settings; ``None`` means nothing is configured.
        allow_upload: Stored on the provider; it refuses to run when False.

    Returns:
        A ``VisionOCR`` instance, or ``None`` when ``cfg`` is missing or its
        type is ``""``, ``"none"`` or ``"disabled"``.
    """
    if cfg is None:
        return None
    if cfg.type in ("", "none", "disabled"):
        return None

    # Fall back to a local endpoint and a default model name when unset.
    endpoint = cfg.base_url or "http://localhost:11434/v1"
    model_name = cfg.model or "qwen2.5vl:7b"
    token = cfg.api_key or ""
    request_timeout = float(cfg.extra.get("timeout", 60))

    return VisionOCR(
        base_url=endpoint,
        api_key=token,
        model=model_name,
        timeout=request_timeout,
        allow_upload=allow_upload,
    )
|
||||
|
||||
|
||||
def build_vlm_provider(cfg: ProviderConfig | None) -> Optional[VLMProvider]:
    """Create the VLM provider used for AI classification, summaries and tags.

    Args:
        cfg: Provider settings; ``None`` means no VLM is configured.

    Returns:
        An ``OpenAICompatVLM`` instance for any of the OpenAI-compatible type
        names, or ``None`` when ``cfg`` is missing or its type is ``""``,
        ``"none"`` or ``"disabled"``.

    Raises:
        ValueError: If the type is not one of the recognised names.
    """
    if cfg is None:
        return None

    kind = cfg.type
    if kind in ("", "none", "disabled"):
        return None

    # Every listed type is served by the same chat/completions client.
    openai_like = ("openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision")
    if kind not in openai_like:
        raise ValueError(f"未知 VLM Provider 类型: {cfg.type}")

    endpoint = cfg.base_url or "http://localhost:11434/v1"
    token = cfg.api_key or ""
    model_name = cfg.model or "gpt-4o-mini"
    request_timeout = float(cfg.extra.get("timeout", 60))
    return OpenAICompatVLM(
        base_url=endpoint,
        api_key=token,
        model=model_name,
        timeout=request_timeout,
    )
|
||||
@@ -0,0 +1,46 @@
|
||||
"""OCR / VLM Provider 抽象接口。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class VLMResult:
    """Structured analysis of one screenshot as returned by a VLM."""

    # One-line headline.
    title: str = ""
    # Short summary of the screenshot content.
    summary: str = ""
    # Chosen category name, or None when the model gave none.
    category: str | None = None
    # Search-oriented tags.
    tags: list[str] = field(default_factory=list)
    # To-do entries; each dict carries the keys "title", "kind" and "note".
    todos: list[dict[str, str]] = field(default_factory=list)
    # Optional follow-up suggestion for the user.
    suggestion: str = ""
    # The parsed model output the result was built from.
    raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OCRProvider(ABC):
    """Abstract OCR backend: given an image path, produce its text."""

    # Backend identifier; concrete providers assign their own value.
    name: str = "ocr"

    @abstractmethod
    async def recognize(self, image_path: Path) -> str:
        """Return the text recognised in the image at ``image_path``."""
        ...
|
||||
|
||||
|
||||
class VLMProvider(ABC):
    """Abstract multimodal backend: image plus OCR text in, structured analysis out."""

    # Backend identifier; concrete providers assign their own value.
    name: str = "vlm"

    @abstractmethod
    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Analyse one screenshot.

        Args:
            image_path: Location of the screenshot file.
            ocr_text: Text already extracted from the image (may be empty).
            categories: Category names the model may choose from.
            allow_upload: Whether the image itself may be sent to the model.

        Returns:
            The structured analysis.
        """
        ...
|
||||
@@ -0,0 +1,63 @@
|
||||
"""通用 HTTP OCR:向自定义 REST 接口 POST 图片并解析文本。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class HttpOCR(OCRProvider):
    """Generic HTTP OCR: POST the image as base64 JSON and read text from the reply.

    The request body carries the encoded image under both ``image_base64``
    and ``image``. The text is taken from the JSON response.

    Options normally supplied through the provider's ``extra`` settings:
        - text_path: dotted path into the response, e.g. "data.text" or
          "result"; defaults to "text".
        - headers: dict of additional request headers.
    """

    name = "http"

    def __init__(
        self,
        base_url: str,
        api_key: str = "",
        text_path: str = "text",
        headers: dict[str, str] | None = None,
        timeout: float = 30.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.text_path = text_path
        self.headers = headers or {}
        self.timeout = timeout

    async def recognize(self, image_path: Path) -> str:
        """Send the image to the endpoint and return the stripped text found at ``text_path``."""
        with open(image_path, "rb") as fh:
            raw_bytes = fh.read()
        b64_image = base64.b64encode(raw_bytes).decode("ascii")

        # Custom headers may override Content-Type; the bearer token is applied last.
        request_headers = {"Content-Type": "application/json"}
        request_headers.update(self.headers)
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        body = {"image_base64": b64_image, "image": b64_image}

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(self.base_url, json=body, headers=request_headers)
            response.raise_for_status()
            decoded = response.json()

        found = _dig(decoded, self.text_path)
        return str(found or "").strip()
|
||||
|
||||
|
||||
def _dig(obj: Any, path: str) -> Any:
|
||||
"""按点分路径从嵌套 dict 取值。"""
|
||||
cur = obj
|
||||
for part in path.split("."):
|
||||
if not isinstance(cur, dict):
|
||||
return None
|
||||
cur = cur.get(part)
|
||||
return cur
|
||||
@@ -0,0 +1,43 @@
|
||||
"""PaddleOCR 本地 OCR(可选依赖)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class PaddleOCRProvider(OCRProvider):
    """Local OCR through PaddleOCR (optional dependency).

    Requires ``pip install paddleocr paddlepaddle``.
    """

    name = "paddleocr"

    def __init__(self, lang: str = "ch") -> None:
        self.lang = lang
        # The engine is created on first use and then reused.
        self._engine = None

    async def recognize(self, image_path: Path) -> str:
        """Run the blocking recognition in a worker thread."""
        worker = asyncio.to_thread(self._sync_recognize, image_path)
        return await worker

    def _sync_recognize(self, image_path: Path) -> str:
        """Recognise text synchronously and return the lines joined by newlines.

        Raises:
            RuntimeError: If the paddleocr package is not installed.
        """
        try:
            from paddleocr import PaddleOCR  # type: ignore
        except ImportError as exc:
            raise RuntimeError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc

        if self._engine is None:
            self._engine = PaddleOCR(use_angle_cls=True, lang=self.lang, show_log=False)

        result = self._engine.ocr(str(image_path), cls=True)
        if not result or not result[0]:
            return ""

        texts: list[str] = []
        for entry in result[0]:
            if not entry or len(entry) < 2:
                continue
            # The second element is either a sequence whose first item is the
            # text, or the text itself.
            payload = entry[1]
            if isinstance(payload, (list, tuple)):
                if payload:
                    texts.append(str(payload[0]))
            elif isinstance(payload, str):
                texts.append(payload)
        return "\n".join(texts).strip()
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Tesseract 本地 OCR 实现。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class TesseractOCR(OCRProvider):
    """Local OCR that calls the tesseract binary through pytesseract.

    tesseract-ocr and the Chinese language pack must be installed beforehand.
    """

    name = "tesseract"

    def __init__(self, lang: str = "chi_sim+eng", cmd: Optional[str] = None) -> None:
        self.lang = lang
        self.cmd = cmd

    async def recognize(self, image_path: Path) -> str:
        """Async wrapper: run the blocking OCR call off the event loop."""
        worker = asyncio.to_thread(self._sync_recognize, image_path)
        return await worker

    def _sync_recognize(self, image_path: Path) -> str:
        """Recognise text synchronously and return it stripped.

        Raises:
            RuntimeError: If pytesseract or Pillow is not installed.
        """
        try:
            import pytesseract
            from PIL import Image
        except ImportError as exc:  # pragma: no cover
            raise RuntimeError("未安装 pytesseract / Pillow") from exc

        # A configured command path overrides pytesseract's module-level setting.
        binary = self.cmd
        if binary:
            pytesseract.pytesseract.tesseract_cmd = binary

        with Image.open(image_path) as picture:
            recognised = pytesseract.image_to_string(picture, lang=self.lang)
        return recognised.strip()
|
||||
@@ -0,0 +1,52 @@
|
||||
"""视觉大模型 OCR:用多模态 API 从截图中提取文字。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .base import OCRProvider
|
||||
from .openai_vision_client import chat_completions, safe_parse_json
|
||||
|
||||
|
||||
_VISION_OCR_SYSTEM = """你是 OCR 助手。用户会给你一张截图,请尽可能完整地提取其中的文字。
|
||||
只输出 JSON,格式:{"text": "提取到的全部文字,保留换行"}
|
||||
如果没有可识别文字,text 填空字符串。"""
|
||||
|
||||
|
||||
class VisionOCR(OCRProvider):
    """OCR through an OpenAI-compatible vision model (GLM-4V / GPT-4o / Qwen-VL / Ollama, ...)."""

    name = "vision"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
        allow_upload: bool = True,
    ) -> None:
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.allow_upload = allow_upload

    async def recognize(self, image_path: Path) -> str:
        """Ask the vision model to extract the text of the screenshot.

        Args:
            image_path: Location of the screenshot file.

        Returns:
            The extracted text, stripped. An empty string when the model
            reports that the image contains no text.

        Raises:
            RuntimeError: If uploading images is not allowed for this provider.
        """
        if not self.allow_upload:
            raise RuntimeError("敏感目录禁止上传图片,无法使用视觉 OCR")

        content = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_VISION_OCR_SYSTEM,
            user_text="请提取这张截图中的所有文字。",
            image_path=image_path,
            allow_upload=True,
            timeout=self.timeout,
            json_mode=True,
        )
        if not content:
            return ""

        parsed = safe_parse_json(content)
        if isinstance(parsed, dict):
            for key in ("text", "ocr_text"):
                value = parsed.get(key)
                if value:
                    return str(value).strip()
            # The model answered in the requested shape but found no text:
            # return nothing instead of leaking the raw JSON as OCR output.
            if "text" in parsed or "ocr_text" in parsed:
                return ""

        # Unexpected shape: fall back to the raw model output.
        return str(content).strip()
|
||||
@@ -0,0 +1,107 @@
|
||||
"""OpenAI 兼容视觉 API 的公共封装:图片编码 + chat/completions 调用。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def image_to_data_url(image_path: Path, max_side: int | None = None) -> str:
    """Downscale an image, encode it as JPEG and return it as a data URL.

    Args:
        image_path: Location of the image file.
        max_side: Upper bound for the longer side in pixels; when falsy,
            ``settings.vlm_max_side`` is used.

    Returns:
        A ``data:image/jpeg;base64,...`` URL.
    """
    max_side = max_side or settings.vlm_max_side
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        w, h = img.size
        scale = max(w, h) / max_side
        if scale > 1:
            # Clamp to 1 px: for very elongated images the short side would
            # otherwise truncate to 0, which resize() rejects.
            target = (max(1, int(w / scale)), max(1, int(h / scale)))
            img = img.resize(target, Image.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=82)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
|
||||
|
||||
|
||||
def safe_parse_json(content: str) -> dict[str, Any]:
    """Parse a model's JSON reply, tolerating markdown fences and surrounding prose.

    Args:
        content: Raw model output.

    Returns:
        The decoded JSON object. When no JSON object can be recovered
        (including replies that decode to a list, string or number), returns
        ``{"text": content}`` so callers can always treat the result as a dict.
    """
    text = (content or "").strip()
    if text.startswith("```"):
        # Drop the fence characters and an optional "json" language tag.
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:].strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Second chance: decode the outermost {...} span embedded in other text.
    start = text.find("{")
    end = text.rfind("}")
    if start >= 0 and end > start:
        try:
            candidate = json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate

    return {"text": content or ""}
|
||||
|
||||
|
||||
async def chat_completions(
    *,
    base_url: str,
    api_key: str,
    model: str,
    system_prompt: str,
    user_text: str,
    image_path: Path | None = None,
    allow_upload: bool = True,
    timeout: float = 60.0,
    json_mode: bool = True,
) -> str:
    """POST to ``{base_url}/chat/completions`` and return ``message.content``.

    Args:
        base_url: API root; a trailing slash is ignored.
        api_key: Bearer token; no Authorization header is sent when empty.
        model: Model name placed in the request.
        system_prompt: Content of the system message.
        user_text: Text part of the user message.
        image_path: Optional image to attach as a data URL.
        allow_upload: The image is attached only when this is True.
        timeout: Timeout passed to the HTTP client.
        json_mode: Whether to request ``response_format = json_object``.

    Returns:
        The content of the first choice; an empty string when the server
        returns a null content.

    Raises:
        RuntimeError: If the response body lacks the expected structure.
        httpx.HTTPError: If the request (after the retry) fails or the final
            response has an error status.
    """
    user_content: list[dict[str, Any]] = [{"type": "text", "text": user_text}]
    if image_path is not None and allow_upload:
        data_url = image_to_data_url(image_path)
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})

    payload: dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        "temperature": 0.2,
    }
    if json_mode:
        payload["response_format"] = {"type": "json_object"}

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    url = f"{base_url.rstrip('/')}/chat/completions"
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.post(url, json=payload, headers=headers)
        except httpx.HTTPError as exc:
            # Best-effort single retry without response_format.
            logger.warning("视觉 API 请求失败,尝试移除 response_format:%s", exc)
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)

        # Some servers reject response_format with a 400; retry once without
        # it, but only if it is still part of the request.
        if (
            resp.status_code == 400
            and "response_format" in payload
            and "response_format" in resp.text
        ):
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)

        resp.raise_for_status()
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # TypeError covers bodies that are not a dict or have null members.
        raise RuntimeError(f"视觉 API 返回结构异常: {data}") from exc

    # A null content would break callers that expect a string.
    return "" if content is None else content
|
||||
@@ -0,0 +1,107 @@
|
||||
"""OpenAI 兼容 VLM 实现:覆盖 Ollama / GLM / MiniMax / Moonshot / OpenRouter / OpenAI。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.core.logger import get_logger
|
||||
|
||||
from .base import VLMProvider, VLMResult
|
||||
from .openai_vision_client import chat_completions, safe_parse_json
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_SYSTEM_PROMPT = """你是一个截图整理助手。用户会给你一张截图(可能附带 OCR 文本)。
|
||||
请用简洁的中文,按以下 JSON 结构返回分析结果,**只输出 JSON,不要解释**:
|
||||
|
||||
{
|
||||
"title": "一句话标题,不超过 24 个字",
|
||||
"summary": "2-3 句话总结这张截图的内容、要点或笑点",
|
||||
"category": "从给定分类列表中选一个最贴切的名字;如果都不符合就填'其他'",
|
||||
"tags": ["3-6 个能帮助检索的细分标签"],
|
||||
"todos": [
|
||||
{"title": "如果截图里出现'待看/待读/待办/想试试/记一下'的内容,抽成一条 todo", "kind": "待读|待看|待办|学习", "note": "可空"}
|
||||
],
|
||||
"suggestion": "可选:给用户的进一步行动建议或同类资源提示,可空"
|
||||
}
|
||||
|
||||
要求:
|
||||
- 标题要可读,不要复述"这是一张..."。
|
||||
- summary 不要超过 80 字。
|
||||
- todos 没有可识别项时给空数组。"""
|
||||
|
||||
|
||||
class OpenAICompatVLM(VLMProvider):
    """VLM provider that talks to an OpenAI-compatible chat/completions endpoint.

    The image, when allowed, is passed to the shared client, which sends it
    as a base64 data URL.
    """

    name = "openai_compat"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Query the model and convert its JSON reply into a ``VLMResult``.

        Args:
            image_path: Location of the screenshot file.
            ocr_text: Text already extracted from the image (may be empty).
            categories: Category names offered to the model.
            allow_upload: When False the image is withheld and only the text
                prompt is sent.
        """
        category_line = f"可选分类:{', '.join(categories)}"
        ocr_section = f"OCR 文本(可能不完整或为空):\n{ocr_text or '(无)'}"
        prompt = category_line + "\n\n" + ocr_section

        attached_image = image_path if allow_upload else None
        reply = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_SYSTEM_PROMPT,
            user_text=prompt,
            image_path=attached_image,
            allow_upload=allow_upload,
            timeout=self.timeout,
            json_mode=True,
        )
        return _to_vlm_result(safe_parse_json(reply))
|
||||
|
||||
|
||||
def _to_vlm_result(data: dict[str, Any]) -> VLMResult:
    """Convert parsed model JSON into a ``VLMResult``, tolerating bad fields.

    Args:
        data: Parsed model output. Anything that is not a dict is treated as
            an empty reply.

    Returns:
        A result whose title is capped at 128 characters and whose tags are
        capped at 8 entries; missing or null fields become empty values.
    """
    if not isinstance(data, dict):
        data = {}

    tags_raw = data.get("tags")
    if not isinstance(tags_raw, list):
        tags_raw = []

    return VLMResult(
        title=str(data.get("title") or "")[:128],
        summary=str(data.get("summary") or ""),
        category=str(data.get("category") or "") or None,
        tags=[str(t) for t in tags_raw if t][:8],
        todos=_coerce_todos(data.get("todos")),
        suggestion=str(data.get("suggestion") or ""),
        raw=data,
    )


def _coerce_todos(todos_raw: Any) -> list[dict[str, str]]:
    """Normalise the model's ``todos`` value into dicts with title/kind/note."""
    if not isinstance(todos_raw, list):
        return []

    todos: list[dict[str, str]] = []
    for item in todos_raw:
        if isinstance(item, dict):
            title = item.get("title")
            if not title:
                continue
            todos.append(
                {
                    "title": str(title)[:512],
                    # `or ""` before str() so a JSON null does not become the
                    # literal text "None".
                    "kind": str(item.get("kind") or "") or "待办",
                    "note": str(item.get("note") or ""),
                }
            )
        elif isinstance(item, str) and item.strip():
            # Plain-string entries get the same title cap as dict entries;
            # blank strings are skipped, matching the dict branch.
            todos.append({"title": item[:512], "kind": "待办", "note": ""})
    return todos
|
||||
Reference in New Issue
Block a user