Initial commit: snapAna 截图智能整理工具

包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
wjl
2026-05-27 15:45:50 +08:00
commit 5c028d7952
76 changed files with 10467 additions and 0 deletions
+81
View File
@@ -0,0 +1,81 @@
"""Provider 工厂,按设置中的 type 字段实例化。"""
from __future__ import annotations
from typing import Optional
from app.schemas.common import ProviderConfig
from .base import OCRProvider, VLMProvider
from .ocr_http import HttpOCR
from .ocr_paddle import PaddleOCRProvider
from .ocr_tesseract import TesseractOCR
from .ocr_vision import VisionOCR
from .vlm_openai import OpenAICompatVLM
# Provider type identifiers and recognition modes, declared as plain tuples of strings.
OCR_TYPES = ("tesseract", "paddleocr", "http", "vision", "none")
VLM_TYPES = ("openai_compat", "none")
RECOGNITION_MODES = ("ocr", "vision", "hybrid")
def build_ocr_provider(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[OCRProvider]:
    """Instantiate the OCR provider selected by ``cfg.type``.

    Args:
        cfg: Provider configuration, or ``None`` when OCR is not configured.
        allow_upload: Forwarded to the ``"vision"`` branch only; the other
            provider types do not look at it.

    Returns:
        A provider instance, or ``None`` when ``cfg`` is ``None`` or its type
        is ``""``, ``"none"`` or ``"disabled"``.

    Raises:
        ValueError: If the type is ``"http"`` without a ``base_url``, or the
            type is not recognised.
    """
    if cfg is None:
        return None

    kind = cfg.type
    if kind in ("", "none", "disabled"):
        return None

    if kind == "tesseract":
        language = cfg.extra.get("lang", "chi_sim+eng")
        return TesseractOCR(lang=language, cmd=cfg.extra.get("cmd"))

    if kind == "paddleocr":
        return PaddleOCRProvider(lang=cfg.extra.get("lang", "ch"))

    if kind == "http":
        if not cfg.base_url:
            raise ValueError("HTTP OCR 需要配置 base_url")
        # Only a real dict is accepted as extra headers; anything else is dropped.
        custom_headers = cfg.extra.get("headers")
        if not isinstance(custom_headers, dict):
            custom_headers = None
        return HttpOCR(
            base_url=cfg.base_url,
            api_key=cfg.api_key or "",
            text_path=str(cfg.extra.get("text_path", "text")),
            headers=custom_headers,
            timeout=float(cfg.extra.get("timeout", 30)),
        )

    if kind == "vision":
        return build_vision_ocr(cfg, allow_upload=allow_upload)

    raise ValueError(f"未知 OCR Provider 类型: {cfg.type}")
def build_vision_ocr(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[VisionOCR]:
    """Create a vision-model OCR provider from a provider configuration.

    Args:
        cfg: Provider configuration, or ``None``.
        allow_upload: Passed through to the ``VisionOCR`` instance.

    Returns:
        A ``VisionOCR`` instance, or ``None`` when ``cfg`` is ``None`` or its
        type is ``""``, ``"none"`` or ``"disabled"``. A missing ``base_url``
        or ``model`` falls back to the defaults hard-coded below.
    """
    disabled = cfg is None or cfg.type in ("", "none", "disabled")
    if disabled:
        return None

    endpoint = cfg.base_url if cfg.base_url else "http://localhost:11434/v1"
    model_name = cfg.model if cfg.model else "qwen2.5vl:7b"
    request_timeout = float(cfg.extra.get("timeout", 60))
    return VisionOCR(
        base_url=endpoint,
        api_key=cfg.api_key or "",
        model=model_name,
        timeout=request_timeout,
        allow_upload=allow_upload,
    )
def build_vlm_provider(cfg: ProviderConfig | None) -> Optional[VLMProvider]:
    """Create the VLM provider (classification / summary / tags) for ``cfg``.

    Every accepted type name maps to the same ``OpenAICompatVLM`` class.

    Returns:
        A provider instance, or ``None`` when ``cfg`` is ``None`` or its type
        is ``""``, ``"none"`` or ``"disabled"``.

    Raises:
        ValueError: If ``cfg.type`` is not one of the accepted names.
    """
    if cfg is None:
        return None
    if cfg.type in ("", "none", "disabled"):
        return None

    compatible_types = (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    )
    if cfg.type not in compatible_types:
        raise ValueError(f"未知 VLM Provider 类型: {cfg.type}")

    return OpenAICompatVLM(
        base_url=cfg.base_url or "http://localhost:11434/v1",
        api_key=cfg.api_key or "",
        model=cfg.model or "gpt-4o-mini",
        timeout=float(cfg.extra.get("timeout", 60)),
    )
+46
View File
@@ -0,0 +1,46 @@
"""OCR / VLM Provider 抽象接口。"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class VLMResult:
    """Structured analysis result returned by a VLM provider."""

    # Short title for the screenshot.
    title: str = ""
    # Free-text summary.
    summary: str = ""
    # Chosen category name, or None when absent.
    category: str | None = None
    # Retrieval tags.
    tags: list[str] = field(default_factory=list)
    # Extracted todo items, each shaped like {title, kind, note}.
    todos: list[dict[str, str]] = field(default_factory=list)
    # Optional follow-up suggestion.
    suggestion: str = ""
    # The parsed payload the result was built from.
    raw: dict[str, Any] = field(default_factory=dict)
class OCRProvider(ABC):
    """Abstract OCR interface: takes an image path and returns its text."""

    # Provider identifier; subclasses override it.
    name: str = "ocr"

    @abstractmethod
    async def recognize(self, image_path: Path) -> str:
        """Return the text recognised in the image at ``image_path``."""
        ...
class VLMProvider(ABC):
    """Abstract multimodal interface: image plus OCR text in, structured analysis out."""

    # Provider identifier; subclasses override it.
    name: str = "vlm"

    @abstractmethod
    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Analyse one screenshot and return a ``VLMResult``.

        Args:
            image_path: Path of the image to analyse.
            ocr_text: OCR text accompanying the image.
            categories: Candidate category names.
            allow_upload: Whether the image itself may be uploaded.
        """
        ...
+63
View File
@@ -0,0 +1,63 @@
"""通用 HTTP OCR:向自定义 REST 接口 POST 图片并解析文本。"""
from __future__ import annotations
import base64
import json
from pathlib import Path
from typing import Any
import httpx
from .base import OCRProvider
class HttpOCR(OCRProvider):
    """Generic HTTP OCR: POST the image as JSON and read the text from the reply.

    The request body carries the base64-encoded image under both the
    ``image_base64`` and ``image`` keys.

    Constructor options:
        text_path: Dotted path into the response JSON, e.g. ``"data.text"``
            or ``"result"``; defaults to ``"text"``.
        headers: Extra request headers.
    """

    name = "http"

    def __init__(
        self,
        base_url: str,
        api_key: str = "",
        text_path: str = "text",
        headers: dict[str, str] | None = None,
        timeout: float = 30.0,
    ) -> None:
        """Store the endpoint settings; trailing slashes are stripped from ``base_url``."""
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.text_path = text_path
        self.headers = headers or {}
        self.timeout = timeout

    async def recognize(self, image_path: Path) -> str:
        """POST the image to ``base_url`` and return the stripped text at ``text_path``.

        Returns an empty string when the path is missing or holds a falsy
        value. ``raise_for_status`` is called on the response, so an HTTP
        error status surfaces as an exception.
        """
        with open(image_path, "rb") as fh:
            raw_bytes = fh.read()
        image_b64 = base64.b64encode(raw_bytes).decode("ascii")

        # Custom headers may override Content-Type; the bearer token is applied last.
        request_headers = {"Content-Type": "application/json"}
        request_headers.update(self.headers)
        if self.api_key:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        body = {"image_base64": image_b64, "image": image_b64}
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            resp = await client.post(self.base_url, json=body, headers=request_headers)
            resp.raise_for_status()
            data = resp.json()

        extracted = _dig(data, self.text_path)
        return str(extracted or "").strip()
def _dig(obj: Any, path: str) -> Any:
    """Follow a dot-separated ``path`` through nested dicts.

    Returns the value found, or ``None`` as soon as a non-dict is reached
    before the path is exhausted or a key is missing.
    """
    node = obj
    keys = path.split(".")
    for key in keys:
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return None
    return node
+43
View File
@@ -0,0 +1,43 @@
"""PaddleOCR 本地 OCR(可选依赖)。"""
from __future__ import annotations
import asyncio
from pathlib import Path
from .base import OCRProvider
class PaddleOCRProvider(OCRProvider):
    """Local OCR through PaddleOCR (requires ``pip install paddleocr paddlepaddle``)."""

    name = "paddleocr"

    def __init__(self, lang: str = "ch") -> None:
        """Remember the language; the engine itself is created on first use."""
        self.lang = lang
        self._engine = None

    async def recognize(self, image_path: Path) -> str:
        """Run the synchronous recognition in a worker thread."""
        return await asyncio.to_thread(self._sync_recognize, image_path)

    def _sync_recognize(self, image_path: Path) -> str:
        """Recognise text and return the lines joined by newlines.

        Raises:
            RuntimeError: If the ``paddleocr`` package cannot be imported.
        """
        try:
            from paddleocr import PaddleOCR  # type: ignore
        except ImportError as exc:
            raise RuntimeError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc

        # The engine is built once and reused on later calls.
        if self._engine is None:
            self._engine = PaddleOCR(use_angle_cls=True, lang=self.lang, show_log=False)

        result = self._engine.ocr(str(image_path), cls=True)
        collected: list[str] = []
        rows = (result[0] if result else None) or []
        for row in rows:
            if not row or len(row) < 2:
                continue
            # The second element is either a non-empty sequence whose first
            # item is the text, or the text string itself.
            payload = row[1]
            if isinstance(payload, (list, tuple)) and payload:
                collected.append(str(payload[0]))
            elif isinstance(payload, str):
                collected.append(payload)
        return "\n".join(collected).strip()
+39
View File
@@ -0,0 +1,39 @@
"""Tesseract 本地 OCR 实现。"""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Optional
from .base import OCRProvider
class TesseractOCR(OCRProvider):
    """Local OCR through ``pytesseract``.

    The tesseract binary and the language data it needs must already be
    installed on the machine.
    """

    name = "tesseract"

    def __init__(self, lang: str = "chi_sim+eng", cmd: Optional[str] = None) -> None:
        """Store the language string and an optional tesseract executable path."""
        self.lang = lang
        self.cmd = cmd

    async def recognize(self, image_path: Path) -> str:
        """Async wrapper: runs the blocking OCR in a thread to keep the event loop free."""
        worker = self._sync_recognize
        return await asyncio.to_thread(worker, image_path)

    def _sync_recognize(self, image_path: Path) -> str:
        """Open the image, run tesseract on it and return the stripped text.

        Raises:
            RuntimeError: If ``pytesseract`` or Pillow cannot be imported.
        """
        try:
            import pytesseract
            from PIL import Image
        except ImportError as exc:  # pragma: no cover
            raise RuntimeError("未安装 pytesseract / Pillow") from exc

        # A configured command path is written to pytesseract's module-level setting.
        if self.cmd:
            pytesseract.pytesseract.tesseract_cmd = self.cmd

        with Image.open(image_path) as picture:
            recognised = pytesseract.image_to_string(picture, lang=self.lang)
        return recognised.strip()
+52
View File
@@ -0,0 +1,52 @@
"""视觉大模型 OCR:用多模态 API 从截图中提取文字。"""
from __future__ import annotations
from pathlib import Path
from .base import OCRProvider
from .openai_vision_client import chat_completions, safe_parse_json
# System prompt for vision-model OCR. It asks the model to extract all text from
# the screenshot and to reply with JSON only, shaped as {"text": "..."}, using an
# empty string when nothing is recognisable.
_VISION_OCR_SYSTEM = """你是 OCR 助手。用户会给你一张截图,请尽可能完整地提取其中的文字。
只输出 JSON,格式:{"text": "提取到的全部文字,保留换行"}
如果没有可识别文字,text 填空字符串。"""
class VisionOCR(OCRProvider):
    """OCR through an OpenAI-compatible vision model (GLM-4V / GPT-4o / Qwen-VL / Ollama, ...)."""

    name = "vision"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
        allow_upload: bool = True,
    ) -> None:
        """Store the endpoint, credentials, model name, timeout and upload permission."""
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.allow_upload = allow_upload

    async def recognize(self, image_path: Path) -> str:
        """Ask the vision model to extract the text of the image.

        Returns:
            The stripped text. When the model's JSON contains a ``text`` or
            ``ocr_text`` key whose value is empty, an empty string is
            returned. The raw model output is used only when neither key is
            present.

        Raises:
            RuntimeError: If uploads are not allowed for this provider.
        """
        if not self.allow_upload:
            raise RuntimeError("敏感目录禁止上传图片,无法使用视觉 OCR")

        content = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_VISION_OCR_SYSTEM,
            user_text="请提取这张截图中的所有文字。",
            image_path=image_path,
            allow_upload=True,
            timeout=self.timeout,
            json_mode=True,
        )
        parsed = safe_parse_json(content)

        # Guard against a parser result that is not a JSON object.
        if not isinstance(parsed, dict):
            return str(content).strip()

        for key in ("text", "ocr_text"):
            value = parsed.get(key)
            if value:
                return str(value).strip()

        # The model answered with an explicit empty text field: the image has
        # no recognisable text, so do not leak the raw JSON as OCR output.
        if "text" in parsed or "ocr_text" in parsed:
            return ""
        return str(content).strip()
@@ -0,0 +1,107 @@
"""OpenAI 兼容视觉 API 的公共封装:图片编码 + chat/completions 调用。"""
from __future__ import annotations
import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any
import httpx
from PIL import Image
from app.core.config import settings
from app.core.logger import get_logger
logger = get_logger(__name__)
def image_to_data_url(image_path: Path, max_side: int | None = None) -> str:
    """Downscale an image if needed and encode it as a JPEG ``data:`` URL.

    Args:
        image_path: Path of the image to encode.
        max_side: Maximum length of the longer side in pixels; a falsy value
            falls back to ``settings.vlm_max_side``.

    Returns:
        A ``data:image/jpeg;base64,...`` string. The image is converted to RGB
        and saved as JPEG with quality 82.
    """
    max_side = max_side or settings.vlm_max_side
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        w, h = img.size
        scale = max(w, h) / max_side
        if scale > 1:
            # Clamp to at least 1 pixel: for very elongated images the short
            # side would otherwise truncate to 0, which resize() rejects.
            new_size = (max(1, int(w / scale)), max(1, int(h / scale)))
            img = img.resize(new_size, Image.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=82)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
def safe_parse_json(content: str) -> dict[str, Any]:
    """Parse a model's JSON output, tolerating markdown fences and surrounding prose.

    Parsing is attempted on the de-fenced text first, then on the span from
    the first ``{`` to the last ``}``. Only a JSON *object* is accepted: a
    valid JSON list, string or number is not returned as-is, because callers
    rely on getting a dict.

    Returns:
        The parsed dict, or ``{"text": content}`` (with the original,
        unstripped ``content``) when no JSON object can be extracted.
    """
    text = content.strip()
    if text.startswith("```"):
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:].strip()

    candidates = [text]
    start = text.find("{")
    end = text.rfind("}")
    if start >= 0 and end > start:
        candidates.append(text[start : end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return {"text": content}
async def chat_completions(
    *,
    base_url: str,
    api_key: str,
    model: str,
    system_prompt: str,
    user_text: str,
    image_path: Path | None = None,
    allow_upload: bool = True,
    timeout: float = 60.0,
    json_mode: bool = True,
) -> str:
    """Call ``{base_url}/chat/completions`` and return ``message.content``.

    Args:
        base_url: API root; a trailing slash is tolerated.
        api_key: Bearer token; no ``Authorization`` header is sent when empty.
        model: Model name placed in the payload.
        system_prompt: Content of the system message.
        user_text: Text part of the user message.
        image_path: Optional image, attached as a data URL only when
            ``allow_upload`` is true.
        allow_upload: Whether the image may be attached.
        timeout: Timeout handed to the HTTP client.
        json_mode: Whether to request ``response_format`` = ``json_object``.

    Returns:
        The content string of the first choice.

    Raises:
        httpx.HTTPError: If the retry also fails or the final status is an error.
        RuntimeError: If the response JSON lacks a string at
            ``choices[0].message.content``.
    """
    user_content: list[dict[str, Any]] = [{"type": "text", "text": user_text}]
    if image_path is not None and allow_upload:
        data_url = image_to_data_url(image_path)
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})

    payload: dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        "temperature": 0.2,
    }
    if json_mode:
        payload["response_format"] = {"type": "json_object"}

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    url = f"{base_url.rstrip('/')}/chat/completions"

    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.post(url, json=payload, headers=headers)
        except httpx.HTTPError as exc:
            # Best-effort: retry once without response_format.
            logger.warning("视觉 API 请求失败,尝试移除 response_format%s", exc)
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)
        # The server rejected response_format explicitly: retry without it.
        if resp.status_code == 400 and "response_format" in resp.text:
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)
        resp.raise_for_status()
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # TypeError covers shapes such as a null "choices" or a non-object body.
        raise RuntimeError(f"视觉 API 返回结构异常: {data}") from exc
    # A null or non-string content would break callers that expect str.
    if not isinstance(content, str):
        raise RuntimeError(f"视觉 API 返回结构异常: {data}")
    return content
+107
View File
@@ -0,0 +1,107 @@
"""OpenAI 兼容 VLM 实现:覆盖 Ollama / GLM / MiniMax / Moonshot / OpenRouter / OpenAI。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from app.core.logger import get_logger
from .base import VLMProvider, VLMResult
from .openai_vision_client import chat_completions, safe_parse_json
logger = get_logger(__name__)
# System prompt for screenshot analysis. It instructs the model to answer in
# concise Chinese with JSON only, using the keys title, summary, category, tags,
# todos (each with title / kind / note) and suggestion.
_SYSTEM_PROMPT = """你是一个截图整理助手。用户会给你一张截图(可能附带 OCR 文本)。
请用简洁的中文,按以下 JSON 结构返回分析结果,**只输出 JSON,不要解释**:
{
"title": "一句话标题,不超过 24 个字",
"summary": "2-3 句话总结这张截图的内容、要点或笑点",
"category": "从给定分类列表中选一个最贴切的名字;如果都不符合就填'其他'",
"tags": ["3-6 个能帮助检索的细分标签"],
"todos": [
{"title": "如果截图里出现'待看/待读/待办/想试试/记一下'的内容,抽成一条 todo", "kind": "待读|待看|待办|学习", "note": "可空"}
],
"suggestion": "可选:给用户的进一步行动建议或同类资源提示,可空"
}
要求:
- 标题要可读,不要复述"这是一张..."
- summary 不要超过 80 字。
- todos 没有可识别项时给空数组。"""
class OpenAICompatVLM(VLMProvider):
    """VLM provider that calls ``/chat/completions`` with the image as a base64 data URL."""

    name = "openai_compat"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
    ) -> None:
        """Store the connection settings; trailing slashes are stripped from ``base_url``."""
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Query the model and convert its JSON reply into a ``VLMResult``.

        The user message lists the candidate categories followed by the OCR
        text (or a placeholder when it is empty). The image is passed on only
        when ``allow_upload`` is true.
        """
        category_line = f"可选分类:{', '.join(categories)}\n\n"
        ocr_section = f"OCR 文本(可能不完整或为空):\n{ocr_text or '(无)'}"
        user_message = category_line + ocr_section

        attached_image = image_path if allow_upload else None
        reply = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_SYSTEM_PROMPT,
            user_text=user_message,
            image_path=attached_image,
            allow_upload=allow_upload,
            timeout=self.timeout,
            json_mode=True,
        )
        return _to_vlm_result(safe_parse_json(reply))
def _to_vlm_result(data: dict[str, Any]) -> VLMResult:
    """Convert the parsed model JSON into a ``VLMResult``, tolerating bad fields.

    Todo items may be dicts (a truthy ``title`` is required) or bare strings.
    Titles are capped at 512 characters, a missing or null ``kind`` defaults
    to ``"待办"``, and a missing or null ``note`` becomes ``""``. The result
    title is capped at 128 characters and at most 8 truthy tags are kept.
    ``data`` itself is stored in ``raw``.
    """
    todos_raw = data.get("todos") or []
    todos: list[dict[str, str]] = []
    if isinstance(todos_raw, list):
        for item in todos_raw:
            if isinstance(item, dict) and item.get("title"):
                todos.append(
                    {
                        "title": str(item.get("title", ""))[:512],
                        # `or ""` first, so an explicit null does not become
                        # the literal string "None".
                        "kind": str(item.get("kind") or "") or "待办",
                        "note": str(item.get("note", "") or ""),
                    }
                )
            elif isinstance(item, str) and item:
                # Bare-string todos follow the same rules as dict todos:
                # empty titles are skipped and the length is capped.
                todos.append({"title": item[:512], "kind": "待办", "note": ""})

    tags_raw = data.get("tags") or []
    if not isinstance(tags_raw, list):
        tags_raw = []

    return VLMResult(
        title=str(data.get("title", "") or "")[:128],
        summary=str(data.get("summary", "") or ""),
        category=str(data.get("category") or "") or None,
        tags=[str(t) for t in tags_raw if t][:8],
        todos=todos,
        suggestion=str(data.get("suggestion", "") or ""),
        raw=data,
    )