Files
SnapAndAnaly/backend/app/services/analyze.py
T

485 lines
16 KiB
Python
Raw Normal View History

"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
设计要点:
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
-> 「短事务(写回结果)」。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from sqlalchemy import select
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
from app.models.category import Category, DEFAULT_CATEGORIES
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.tag import Tag
from app.models.todo import Todo, TodoStatus
from app.models.watch_folder import WatchFolder
from app.providers import (
RECOGNITION_MODES,
build_ocr_provider,
build_vision_ocr,
build_vlm_provider,
)
from app.providers.base import VLMResult
from app.schemas.common import ProviderConfig
from app.services.exif_utils import is_exif_location_tag
from app.services.settings_store import get_provider_config, get_setting
logger = get_logger(__name__)
@dataclass
class _PreparedContext:
"""从短事务中导出的、不依赖 ORM 会话的纯数据。"""
path: Path
ocr_cfg: Optional[ProviderConfig]
vlm_cfg: Optional[ProviderConfig]
recognition_mode: str
category_names: list[str]
allow_upload: bool
exists: bool
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Public entry point: run the full analysis for one screenshot.

    The function manages its own short transactions: it loads the context,
    runs text extraction and the VLM outside of any transaction, and finally
    writes everything back in one short transaction.

    Args:
        screenshot_id: Primary key of the screenshot to analyse.

    Raises:
        Exception: Whatever the VLM provider raised. The failure is persisted
            first and then re-raised so the caller can decide whether to retry.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        # The screenshot row no longer exists.
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")

    # ---- Text recognition stage (runs outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)

    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    vlm_error: Optional[Exception] = None
    ai_status = ProcessStatus.SKIPPED.value
    if vlm_provider is not None:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            vlm_error = exc
        else:
            ai_status = ProcessStatus.DONE.value

    # ---- Write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )

    if vlm_error is not None:
        # Surface the failure only after the state was saved, so the worker
        # can decide whether to retry.
        raise vlm_error
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run only the text recognition step for one screenshot.

    Intended for screenshots whose AI analysis succeeded but whose OCR
    failed. On the normal path only the OCR text and status are written back.

    Args:
        screenshot_id: Primary key of the screenshot to reprocess.

    Raises:
        RuntimeError: If the image file is missing (after the row has been
            marked accordingly), or if recognition failed again, so that the
            caller can apply its retry policy.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")

    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)
    _persist_ocr_only(screenshot_id, ocr_text, ocr_status)

    if ocr_status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
# ---------------- Short-transaction helpers ---------------- #
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: gather everything the analysis needs into plain data.

    Reads the provider configurations, the recognition mode, the category
    list and the sensitive-folder verdict for the screenshot.

    Args:
        screenshot_id: Primary key of the screenshot.

    Returns:
        A populated ``_PreparedContext``, or None when the screenshot row
        does not exist.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return None

        image_path = path_from_storage(shot.path)
        file_ok = is_accessible_file(image_path)

        ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
        vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)

        stored_mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        # Unknown values fall back to the default mode.
        mode = stored_mode if stored_mode in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE

        names = [category.name for category in _ensure_default_categories(session)]
        sensitive = _is_sensitive(session, image_path)

        return _PreparedContext(
            path=image_path,
            ocr_cfg=ocr_cfg,
            vlm_cfg=vlm_cfg,
            recognition_mode=mode,
            category_names=names,
            allow_upload=not sensitive,
            exists=file_ok,
        )
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
"""混合/传统模式下是否启用 OCR 区配置(排除 vision 类型,vision 单独处理)。"""
if ctx.recognition_mode not in ("ocr", "hybrid"):
return False
if ctx.ocr_cfg is None or ctx.ocr_cfg.type in ("", "none", "disabled", "vision"):
return False
return True
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: classic OCR, vision AI, or both.

    Args:
        screenshot_id: Primary key of the screenshot, used for status updates
            and log messages.
        ctx: Prepared analysis context.
        ocr_provider: Classic OCR provider, or None to skip that stage.

    Returns:
        ``(text, status)`` where ``status`` is a ``ProcessStatus`` value. It
        stays at "skipped" when no recognizer ran at all.
    """
    text = ""
    status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Classic OCR (Tesseract / Paddle / HTTP)
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            text = await ocr_provider.recognize(ctx.path)
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            status = ProcessStatus.FAILED.value
        else:
            status = ProcessStatus.DONE.value

    # 2) Decide whether the vision model should read the text.
    if mode == "vision":
        use_vision = True
    elif mode == "hybrid":
        # Hybrid falls back to vision only when classic OCR produced nothing.
        use_vision = not text.strip()
    elif mode == "ocr":
        # The user picked "vision model" inside the OCR section.
        use_vision = bool(ctx.ocr_cfg and ctx.ocr_cfg.type == "vision")
    else:
        use_vision = False

    if not use_vision:
        return text, status

    vision = _safe_build(
        lambda cfg: build_vision_ocr(cfg, allow_upload=ctx.allow_upload),
        _pick_vision_config(ctx),
        "VisionOCR",
    )
    if vision is None:
        return text, status

    _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
    try:
        text = await vision.recognize(ctx.path)
    except Exception as exc:  # noqa: BLE001
        logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
        # Keep an earlier successful status; otherwise record the failure.
        if status != ProcessStatus.DONE.value:
            status = ProcessStatus.FAILED.value
    else:
        status = ProcessStatus.DONE.value

    return text, status
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
"""决定视觉识文用哪套配置:优先 OCR 区的 vision,否则 VLM 区。"""
if ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
return ctx.ocr_cfg
if ctx.recognition_mode == "vision" or ctx.recognition_mode == "hybrid":
return ctx.vlm_cfg
return ctx.vlm_cfg
def _mark_status(
screenshot_id: int,
ocr: Optional[str] = None,
ai: Optional[str] = None,
) -> None:
"""短事务:把截图标记为 running,方便前端看到进度。"""
if ocr is None and ai is None:
return
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return
if ocr is not None:
shot.ocr_status = ocr
if ai is not None:
shot.ai_status = ai
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot's image file is gone.

    Both the OCR and AI statuses are set to failed and the summary is
    replaced by a "file missing" placeholder. Does nothing when the row no
    longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is not None:
            failed = ProcessStatus.FAILED.value
            shot.ocr_status = failed
            shot.ai_status = failed
            _get_or_create_meta(session, screenshot_id).ai_summary = "(文件丢失)"
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: store the OCR/VLM outcome (meta, category, tags, todos).

    Args:
        screenshot_id: Primary key of the screenshot.
        ocr_text: Recognized text; an empty string is stored as NULL.
        ocr_status: Final ``ProcessStatus`` value for the OCR stage.
        ai_status: Final ``ProcessStatus`` value for the AI stage.
        vlm_result: VLM output, or None when the VLM was skipped or failed;
            in that case the AI fields, category, tags and todos are left
            as they are.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return

        shot.ocr_status = ocr_status
        shot.ai_status = ai_status

        meta = _get_or_create_meta(session, screenshot_id)
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return

        # Empty strings from the model are normalised to NULL.
        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(session.scalars(select(Category)).all())
        matched = _resolve_category(session, vlm_result.category, known_categories)
        if matched is not None:
            shot.category_id = matched.id

        _sync_tags(session, shot, vlm_result.tags)
        _sync_todos(session, shot, vlm_result.todos)
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store only the OCR text and status.

    AI-related fields are not touched. An empty ``ocr_text`` is stored as
    NULL. Does nothing when the screenshot row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is not None:
            shot.ocr_status = ocr_status
            _get_or_create_meta(session, screenshot_id).ocr_text = ocr_text or None
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Create OCR re-run jobs for screenshots whose AI succeeded but whose OCR failed.

    Screenshots that already have a pending or running OCR job are skipped,
    so nothing is enqueued twice.

    ``limit`` caps the number of jobs *created*, not the number of rows
    scanned. Capping the scan instead would make every later call a no-op
    once the first ``limit`` candidates all have active jobs, even though
    further candidates are waiting.

    Args:
        limit: Maximum number of jobs to create in this call. ``0`` creates
            nothing.

    Returns:
        The number of jobs created.
    """
    if limit == 0:
        return 0

    from app.models.job import Job, JobKind, JobStatus

    active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    created = 0
    with session_scope() as session:
        # Screenshot ids that already have an active OCR job.
        busy_ids = set(
            session.scalars(
                select(Job.screenshot_id).where(
                    Job.kind == JobKind.OCR.value,
                    Job.status.in_(active_status),
                )
            ).all()
        )
        # Only ids are needed, so avoid loading full Screenshot rows.
        candidate_ids = session.scalars(
            select(Screenshot.id)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
            )
            .order_by(Screenshot.id.asc())
        ).all()
        for shot_id in candidate_ids:
            if shot_id in busy_ids:
                continue
            session.add(
                Job(
                    screenshot_id=shot_id,
                    kind=JobKind.OCR.value,
                    status=JobStatus.PENDING.value,
                )
            )
            busy_ids.add(shot_id)
            created += 1
            # NOTE(review): a negative limit never equals the counter and so
            # leaves the batch uncapped; this presumably mirrors how SQLite
            # treated a negative LIMIT in the previous query — confirm.
            if created == limit:
                break
    return created
# ---------------- Internal helpers ---------------- #
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Load the provider configuration stored under ``key``.

    Returns:
        The parsed ``ProviderConfig``, or None when nothing is stored or the
        stored value cannot be turned into a ``ProviderConfig`` (a warning is
        logged in that case).
    """
    stored = get_provider_config(session, key)
    if not stored:
        return None
    try:
        parsed = ProviderConfig(**stored)
    except Exception as exc:  # noqa: BLE001
        logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
        parsed = None
    return parsed
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
if cfg is None:
return None
try:
return builder(cfg)
except Exception as exc: # noqa: BLE001
logger.warning("%s Provider 构造失败: %s", label, exc)
return None
def _is_sensitive(session, image_path: Path) -> bool:
    """Tell whether the file lies inside any watch folder flagged as sensitive."""
    flagged_dirs = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_dirs)
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the meta row of a screenshot, creating and flushing it if absent."""
    existing = session.get(ScreenshotMeta, screenshot_id)
    if existing is not None:
        return existing
    created = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(created)
    session.flush()
    return created
def ensure_default_categories() -> None:
    """Public wrapper: seed the default categories in a transaction of its own.

    Meant to be called at start-up; nothing is inserted when categories
    already exist.
    """
    with session_scope() as db:
        _ensure_default_categories(db)
def _ensure_default_categories(session) -> list[Category]:
    """Insert the default categories when the table is empty; return all categories."""
    query = select(Category)
    rows = list(session.scalars(query).all())
    if not rows:
        # First run: seed the defaults, then re-read so the flushed rows are
        # the ones returned.
        for spec in DEFAULT_CATEGORIES:
            session.add(Category(**spec))
        session.flush()
        rows = list(session.scalars(query).all())
    return rows
def _resolve_category(
session,
name: str | None,
categories: list[Category],
) -> Optional[Category]:
if not name:
return None
normalized = name.strip()
for c in categories:
if c.name == normalized or c.name in normalized or normalized in c.name:
return c
new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
session.add(new_cat)
session.flush()
categories.append(new_cat)
return new_cat
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided ones, keeping EXIF location tags.

    Names are stripped and truncated to 64 characters; blanks and duplicates
    are dropped. Tags that do not exist yet are created and flushed.
    """
    kept = [t for t in (screenshot.tags or []) if is_exif_location_tag(t.name)]
    used_names: set[str] = {t.name for t in kept}
    result: list[Tag] = list(kept)

    for candidate in tag_names:
        label = (candidate or "").strip()[:64]
        if label and label not in used_names:
            used_names.add(label)
            tag = session.scalar(select(Tag).where(Tag.name == label))
            if tag is None:
                tag = Tag(name=label)
                session.add(tag)
                session.flush()
            result.append(tag)

    screenshot.tags = result
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Overwrite the screenshot's open todos with the AI output.

    Todos the user already completed or dropped are kept; every other
    existing todo is deleted. Each AI item with a non-blank title then
    becomes a new pending todo (title capped at 512 characters, note at
    2000, kind at 32).
    """
    preserved = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    for todo in list(screenshot.todos):
        if todo.status not in preserved:
            session.delete(todo)
    session.flush()

    for entry in todos:
        title = (entry.get("title") or "").strip()
        if title:
            note = (entry.get("note") or "")[:2000] or None
            kind = (entry.get("kind") or "待办")[:32]
            session.add(
                Todo(
                    screenshot_id=screenshot.id,
                    title=title[:512],
                    note=note,
                    kind=kind,
                    status=TodoStatus.PENDING.value,
                )
            )
    session.flush()