Files
SnapAndAnaly/backend/app/services/analyze.py
T
congsh 5c028d7952 Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 15:45:50 +08:00

485 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
设计要点:
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
-> 「短事务(写回结果)」。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from sqlalchemy import select
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
from app.models.category import Category, DEFAULT_CATEGORIES
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.tag import Tag
from app.models.todo import Todo, TodoStatus
from app.models.watch_folder import WatchFolder
from app.providers import (
RECOGNITION_MODES,
build_ocr_provider,
build_vision_ocr,
build_vlm_provider,
)
from app.providers.base import VLMResult
from app.schemas.common import ProviderConfig
from app.services.exif_utils import is_exif_location_tag
from app.services.settings_store import get_provider_config, get_setting
logger = get_logger(__name__)
@dataclass
class _PreparedContext:
    """Session-independent snapshot of everything one analysis run needs.

    Instances are built inside a short DB transaction and then used after
    that transaction has ended, so they hold plain values only and never a
    live ORM object.
    """

    # Image location, resolved from the stored screenshot path.
    path: Path
    # Parsed OCR provider settings; None when missing or unparsable.
    ocr_cfg: Optional[ProviderConfig]
    # Parsed VLM provider settings; None when missing or unparsable.
    vlm_cfg: Optional[ProviderConfig]
    # Recognition mode; the code in this module branches on
    # "ocr", "hybrid" and "vision".
    recognition_mode: str
    # Names of the categories known at preparation time.
    category_names: list[str]
    # False when the image lives under a watch folder flagged as sensitive.
    allow_upload: bool
    # Whether the image file was accessible when the context was built.
    exists: bool
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Run the full pipeline (text extraction, then VLM) for one screenshot.

    The function manages several short transactions itself: one to read
    configuration, optional ones to flag progress, and one to write the
    results back. The provider calls happen between those transactions.

    Raises:
        Exception: the error raised by the VLM provider, re-raised after the
            partial results have been persisted so the caller can decide
            whether to retry.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        # The screenshot row no longer exists; nothing to do.
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")

    # ---- Text recognition stage (outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)

    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    vlm_error: Optional[Exception] = None
    if vlm_provider is None:
        ai_status = ProcessStatus.SKIPPED.value
    else:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            vlm_error = exc
        else:
            ai_status = ProcessStatus.DONE.value

    # ---- Write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )

    if vlm_error is not None:
        # Surface the failure only after the statuses were saved, so the
        # caller (the worker, per the original note) can decide on a retry.
        raise vlm_error
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run text extraction only, leaving the stored AI analysis untouched.

    Meant for screenshots whose AI stage succeeded but whose OCR stage
    failed. A failure is reported by raising, so the caller can retry
    (the original note says the worker retries up to max_retries).

    Raises:
        RuntimeError: when the image file is missing, or when text
            extraction still ends in the failed state.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")

    text, status = await _extract_text(screenshot_id, ctx, ocr_provider)
    _persist_ocr_only(screenshot_id, text, status)

    still_failed = status == ProcessStatus.FAILED.value
    if still_failed:
        raise RuntimeError("OCR 识别失败")
# ---------------- 短事务工具 ---------------- #
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: gather everything the analysis needs up front.

    Reads the screenshot row, both provider configurations, the recognition
    mode, the category list and the sensitive-folder verdict, and returns
    them as plain data.

    Returns:
        The prepared context, or None when the screenshot row does not exist.
    """
    with session_scope() as db:
        record = db.get(Screenshot, screenshot_id)
        if record is None:
            return None

        file_path = path_from_storage(record.path)
        file_ok = is_accessible_file(file_path)

        ocr_config = _load_provider_config(db, KEY_OCR_PROVIDER)
        vlm_config = _load_provider_config(db, KEY_VLM_PROVIDER)

        chosen_mode = get_setting(db, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        if chosen_mode not in RECOGNITION_MODES:
            # Unknown stored value: fall back to the default mode.
            chosen_mode = DEFAULT_RECOGNITION_MODE

        names = [cat.name for cat in _ensure_default_categories(db)]
        # Images under a sensitive watch folder must not be uploaded.
        uploadable = not _is_sensitive(db, file_path)

        return _PreparedContext(
            path=file_path,
            ocr_cfg=ocr_config,
            vlm_cfg=vlm_config,
            recognition_mode=chosen_mode,
            category_names=names,
            allow_upload=uploadable,
            exists=file_ok,
        )
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
    """Tell whether the OCR-section configuration should drive a classic OCR run.

    True only in "ocr" or "hybrid" mode with an OCR configuration whose type
    is not one of the placeholder values and not "vision" (the vision type
    is handled separately by the text-extraction step).
    """
    cfg = ctx.ocr_cfg
    return (
        ctx.recognition_mode in ("ocr", "hybrid")
        and cfg is not None
        and cfg.type not in ("", "none", "disabled", "vision")
    )
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: classic OCR, vision, or both.

    Args:
        screenshot_id: row id, used for progress marking and log messages.
        ctx: prepared, session-independent analysis context.
        ocr_provider: classic OCR provider, or None to skip that step.

    Returns:
        A ``(text, status)`` pair. The status starts as skipped and becomes
        done or failed depending on which steps ran.
    """
    running = ProcessStatus.RUNNING.value
    done = ProcessStatus.DONE.value
    failed = ProcessStatus.FAILED.value

    text = ""
    status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Classic OCR (Tesseract / Paddle / HTTP)
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=running)
        try:
            text = await ocr_provider.recognize(ctx.path)
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            status = failed
        else:
            status = done

    # 2) Vision-model text recognition
    if mode == "vision":
        wants_vision = True
    elif mode == "hybrid":
        # Hybrid falls back to vision only when classic OCR produced no text.
        wants_vision = not text.strip()
    else:
        # In "ocr" mode the user may have picked the vision type in the OCR section.
        wants_vision = bool(
            mode == "ocr" and ctx.ocr_cfg and ctx.ocr_cfg.type == "vision"
        )
    if not wants_vision:
        return text, status

    def _make_vision(cfg):
        return build_vision_ocr(cfg, allow_upload=ctx.allow_upload)

    vision = _safe_build(_make_vision, _pick_vision_config(ctx), "VisionOCR")
    if vision is None:
        return text, status

    _mark_status(screenshot_id, ocr=running)
    try:
        text = await vision.recognize(ctx.path)
    except Exception as exc:  # noqa: BLE001
        logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
        # Keep an earlier success from the classic step; otherwise mark failed.
        if status != done:
            status = failed
    else:
        status = done
    return text, status
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
    """Choose the configuration used for vision-based text recognition.

    The OCR-section configuration wins when its type is "vision"; in every
    other case the VLM-section configuration is used.

    Returns:
        The selected provider configuration, or None when the chosen section
        has no configuration.
    """
    ocr_cfg = ctx.ocr_cfg
    if ocr_cfg and ocr_cfg.type == "vision":
        return ocr_cfg
    # The former mode check here returned the same value on both arms, so a
    # single fallback is equivalent for every recognition mode.
    return ctx.vlm_cfg
def _mark_status(
    screenshot_id: int,
    ocr: Optional[str] = None,
    ai: Optional[str] = None,
) -> None:
    """Short transaction: update the OCR and/or AI status of a screenshot.

    Used to flag a stage as running so progress is visible before the final
    results are written. Arguments left as None are not touched; when both
    are None no transaction is opened at all.
    """
    requested = {"ocr_status": ocr, "ai_status": ai}
    changes = {attr: value for attr, value in requested.items() if value is not None}
    if not changes:
        return
    with session_scope() as db:
        record = db.get(Screenshot, screenshot_id)
        if record is None:
            return
        for attr, value in changes.items():
            setattr(record, attr, value)
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot's file is gone.

    Both stage statuses are set to failed and the AI summary is replaced by
    a fixed "file missing" marker. Does nothing when the row does not exist.
    """
    failed = ProcessStatus.FAILED.value
    with session_scope() as db:
        record = db.get(Screenshot, screenshot_id)
        if record is not None:
            record.ocr_status = failed
            record.ai_status = failed
            _get_or_create_meta(db, screenshot_id).ai_summary = "(文件丢失)"
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: write OCR and VLM outcomes back to the database.

    Statuses and OCR text are always stored. Title, summary, suggestion, raw
    JSON, category, tags and todos are only touched when a VLM result is
    available, so a failed or skipped AI stage keeps earlier AI data.
    Does nothing when the screenshot row no longer exists.
    """
    with session_scope() as db:
        record = db.get(Screenshot, screenshot_id)
        if record is None:
            return

        record.ocr_status = ocr_status
        record.ai_status = ai_status

        meta = _get_or_create_meta(db, screenshot_id)
        # Empty text is stored as NULL rather than as an empty string.
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return

        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(db.scalars(select(Category)).all())
        matched = _resolve_category(db, vlm_result.category, known_categories)
        if matched is not None:
            record.category_id = matched.id

        _sync_tags(db, record, vlm_result.tags)
        _sync_todos(db, record, vlm_result.todos)
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store OCR text and status only.

    AI-related columns are left exactly as they are. Does nothing when the
    screenshot row no longer exists.
    """
    with session_scope() as db:
        record = db.get(Screenshot, screenshot_id)
        if record is not None:
            record.ocr_status = ocr_status
            # Empty text is stored as NULL rather than as an empty string.
            _get_or_create_meta(db, screenshot_id).ocr_text = ocr_text or None
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Create OCR re-run jobs for screenshots whose AI stage succeeded but OCR failed.

    Screenshots that already have a pending or running OCR job are excluded
    inside the query itself, so ``limit`` caps the number of jobs actually
    created. (Filtering after the LIMIT would let already-queued rows use up
    the window and starve later screenshots.)

    Args:
        limit: maximum number of jobs to create in this call.

    Returns:
        The number of jobs that were added.
    """
    from app.models.job import Job, JobKind, JobStatus

    active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    # Correlated EXISTS: true for a screenshot that already has an active OCR job.
    has_active_ocr_job = (
        select(Job.screenshot_id)
        .where(
            Job.screenshot_id == Screenshot.id,
            Job.kind == JobKind.OCR.value,
            Job.status.in_(active_status),
        )
        .exists()
    )

    created = 0
    with session_scope() as session:
        shot_ids = session.scalars(
            select(Screenshot.id)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
                ~has_active_ocr_job,
            )
            .order_by(Screenshot.id.asc())
            .limit(limit)
        ).all()
        for shot_id in shot_ids:
            session.add(
                Job(
                    screenshot_id=shot_id,
                    kind=JobKind.OCR.value,
                    status=JobStatus.PENDING.value,
                )
            )
            created += 1
    return created
# ---------------- 内部辅助 ---------------- #
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Read the provider settings stored under ``key`` and parse them.

    Returns:
        The parsed configuration, or None when nothing is stored or the
        stored value cannot be turned into a ``ProviderConfig`` (the parse
        error is logged as a warning, not raised).
    """
    stored = get_provider_config(session, key)
    if stored:
        try:
            return ProviderConfig(**stored)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
    return None
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
    """Build a provider from ``cfg`` without letting construction errors escape.

    Args:
        builder: callable that takes the configuration and returns a provider.
        cfg: provider configuration; None means nothing is built.
        label: short name used in the warning logged on failure.

    Returns:
        Whatever ``builder`` returns, or None when ``cfg`` is None or the
        builder raised.
    """
    provider = None
    if cfg is not None:
        try:
            provider = builder(cfg)
        except Exception as exc:  # noqa: BLE001
            logger.warning("%s Provider 构造失败: %s", label, exc)
    return provider
def _is_sensitive(session, image_path: Path) -> bool:
    """Tell whether the file lies inside any watch folder flagged as sensitive."""
    flagged_dirs = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_dirs)
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the metadata row for a screenshot, creating it when absent.

    A newly created row is added to the session and flushed before it is
    returned.
    """
    found = session.get(ScreenshotMeta, screenshot_id)
    if found is not None:
        return found
    fresh = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(fresh)
    session.flush()
    return fresh
def ensure_default_categories() -> None:
    """Public helper: seed the default categories in a transaction of its own.

    The original note says this is called at application start-up.
    """
    with session_scope() as db:
        _ensure_default_categories(db)
def _ensure_default_categories(session) -> list[Category]:
    """Insert the default categories when the table is empty.

    Returns:
        The categories currently stored: the existing rows when there are
        any, otherwise the rows read back after seeding.
    """
    current = list(session.scalars(select(Category)).all())
    if current:
        return current
    session.add_all(Category(**spec) for spec in DEFAULT_CATEGORIES)
    session.flush()
    return list(session.scalars(select(Category)).all())
def _resolve_category(
    session,
    name: str | None,
    categories: list[Category],
) -> Optional[Category]:
    """Map a category name suggested by the AI onto a stored category.

    Matching order:
      1. an exact name match anywhere in ``categories``;
      2. the first category whose name contains, or is contained in, the
         suggested name;
      3. otherwise a new category is created (name cut to 64 characters),
         flushed, and appended to ``categories``.

    Args:
        session: active ORM session; only used when a category is created.
        name: suggested category name; may be None, empty or blank.
        categories: known categories; extended in place on creation.

    Returns:
        The matched or newly created category, or None when ``name`` is
        None, empty or whitespace-only.
    """
    normalized = (name or "").strip()
    if not normalized:
        # A blank name must not match: "" is a substring of every string,
        # so the fuzzy pass would otherwise return the first category.
        return None

    # Exact matches take priority over partial ones, regardless of order.
    for c in categories:
        if c.name == normalized:
            return c

    for c in categories:
        # Skip empty category names for the same substring reason as above.
        if c.name and (c.name in normalized or normalized in c.name):
            return c

    new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
    session.add(new_cat)
    session.flush()
    categories.append(new_cat)
    return new_cat
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided names.

    Tags recognised as EXIF location tags are kept and stay first. Each
    provided name is stripped and cut to 64 characters; blanks and
    duplicates are skipped. Missing ``Tag`` rows are created and flushed.
    """
    kept = [tag for tag in (screenshot.tags or []) if is_exif_location_tag(tag.name)]
    used_names: set[str] = {tag.name for tag in kept}
    result: list[Tag] = list(kept)

    cleaned_names = ((raw or "").strip()[:64] for raw in tag_names)
    for name in cleaned_names:
        if not name or name in used_names:
            continue
        used_names.add(name)
        tag = session.scalar(select(Tag).where(Tag.name == name))
        if tag is None:
            tag = Tag(name=name)
            session.add(tag)
            session.flush()
        result.append(tag)

    screenshot.tags = result
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Replace the screenshot's open todos with the ones produced by the AI.

    Todos whose status is done or dropped are kept; every other existing
    todo is deleted. Each incoming item with a non-blank title becomes a new
    pending todo (title cut to 512 characters, note to 2000, kind to 32 with
    a default kind when none is given).
    """
    preserved_states = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    for todo in list(screenshot.todos):
        if todo.status not in preserved_states:
            session.delete(todo)
    session.flush()

    for entry in todos:
        title = (entry.get("title") or "").strip()
        if not title:
            continue
        note = (entry.get("note") or "")[:2000] or None
        kind = (entry.get("kind") or "待办")[:32]
        session.add(
            Todo(
                screenshot_id=screenshot.id,
                title=title[:512],
                note=note,
                kind=kind,
                status=TodoStatus.PENDING.value,
            )
        )
    session.flush()