"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。 设计要点: - 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。 - 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」 -> 「短事务(写回结果)」。 """ from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Optional from sqlalchemy import select from app.core.db import session_scope from app.core.logger import get_logger from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under from app.models.category import Category, DEFAULT_CATEGORIES from app.models.meta import ScreenshotMeta from app.models.screenshot import ProcessStatus, Screenshot from app.models.setting import ( DEFAULT_RECOGNITION_MODE, KEY_OCR_PROVIDER, KEY_RECOGNITION_MODE, KEY_VLM_PROVIDER, ) from app.models.tag import Tag from app.models.todo import Todo, TodoStatus from app.models.watch_folder import WatchFolder from app.providers import ( RECOGNITION_MODES, build_ocr_provider, build_vision_ocr, build_vlm_provider, ) from app.providers.base import VLMResult from app.schemas.common import ProviderConfig from app.services.exif_utils import is_exif_location_tag from app.services.settings_store import get_provider_config, get_setting logger = get_logger(__name__) @dataclass class _PreparedContext: """从短事务中导出的、不依赖 ORM 会话的纯数据。""" path: Path ocr_cfg: Optional[ProviderConfig] vlm_cfg: Optional[ProviderConfig] recognition_mode: str category_names: list[str] allow_upload: bool exists: bool async def analyze_screenshot_by_id(screenshot_id: int) -> None: """对外入口:按 id 分析单张截图。 被 worker 调度。函数内部自己管理多个短事务。 """ ctx = _prepare(screenshot_id) if ctx is None: return # 截图已被删除 if not ctx.exists: _persist_missing(screenshot_id) return ocr_provider = _safe_build( lambda c: build_ocr_provider(c, allow_upload=ctx.allow_upload), ctx.ocr_cfg if _use_traditional_ocr(ctx) else None, "OCR", ) vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM") # ---- 文字识别阶段(在事务外执行)---- ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider) # ---- VLM 阶段(事务外)---- vlm_result: Optional[VLMResult] = None ai_status = ProcessStatus.SKIPPED.value vlm_error: Optional[Exception] = None if vlm_provider is not None: _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value) try: vlm_result = await vlm_provider.analyze( image_path=ctx.path, ocr_text=ocr_text, categories=ctx.category_names, allow_upload=ctx.allow_upload, ) ai_status = ProcessStatus.DONE.value except Exception as exc: # noqa: BLE001 logger.warning("VLM 失败 #%d: %s", screenshot_id, exc) ai_status = ProcessStatus.FAILED.value vlm_error = exc # ---- 写回阶段(短事务)---- _persist_result( screenshot_id=screenshot_id, ocr_text=ocr_text, ocr_status=ocr_status, ai_status=ai_status, vlm_result=vlm_result, ) if vlm_error is not None: raise vlm_error # 让 worker 决定重试 async def analyze_ocr_only_by_id(screenshot_id: int) -> None: """仅补跑 OCR/视觉识文,不改动 AI 分析结果。 用于 ai_status=done 但 ocr_status=failed 的截图。 OCR 仍失败时抛异常,由 worker 按 max_retries 重试。 """ ctx = _prepare(screenshot_id) if ctx is None: return if not ctx.exists: _persist_missing(screenshot_id) raise RuntimeError("截图文件丢失") ocr_provider = _safe_build( lambda c: build_ocr_provider(c, allow_upload=ctx.allow_upload), ctx.ocr_cfg if _use_traditional_ocr(ctx) else None, "OCR", ) ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider) _persist_ocr_only(screenshot_id, ocr_text, ocr_status) if ocr_status == ProcessStatus.FAILED.value: raise RuntimeError("OCR 识别失败") # ---------------- 短事务工具 ---------------- # def _prepare(screenshot_id: int) -> Optional[_PreparedContext]: """短事务:读取 Provider 配置、分类列表、敏感目录判定。""" with session_scope() as session: shot = session.get(Screenshot, screenshot_id) if shot is None: return None image_path = path_from_storage(shot.path) exists = is_accessible_file(image_path) ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER) vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER) mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE) if mode not in RECOGNITION_MODES: mode = DEFAULT_RECOGNITION_MODE categories = _ensure_default_categories(session) category_names = [c.name for c in categories] allow_upload = not _is_sensitive(session, image_path) return _PreparedContext( path=image_path, ocr_cfg=ocr_cfg, vlm_cfg=vlm_cfg, recognition_mode=mode, category_names=category_names, allow_upload=allow_upload, exists=exists, ) def _use_traditional_ocr(ctx: _PreparedContext) -> bool: """混合/传统模式下是否启用 OCR 区配置(排除 vision 类型,vision 单独处理)。""" if ctx.recognition_mode not in ("ocr", "hybrid"): return False if ctx.ocr_cfg is None or ctx.ocr_cfg.type in ("", "none", "disabled", "vision"): return False return True async def _extract_text( screenshot_id: int, ctx: _PreparedContext, ocr_provider, ) -> tuple[str, str]: """按识别模式提取文字:传统 OCR / 视觉 AI / 混合。""" ocr_text = "" ocr_status = ProcessStatus.SKIPPED.value mode = ctx.recognition_mode # 1) 传统 OCR(Tesseract / Paddle / HTTP) if ocr_provider is not None: _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value) try: ocr_text = await ocr_provider.recognize(ctx.path) ocr_status = ProcessStatus.DONE.value except Exception as exc: # noqa: BLE001 logger.warning("OCR 失败 #%d: %s", screenshot_id, exc) ocr_status = ProcessStatus.FAILED.value # 2) 视觉 AI 识文 need_vision = mode == "vision" or ( mode == "hybrid" and not ocr_text.strip() ) if mode == "ocr" and ctx.ocr_cfg and ctx.ocr_cfg.type == "vision": # 用户在 OCR 区选了「视觉模型识文」 need_vision = True if need_vision: vision_cfg = _pick_vision_config(ctx) vision = _safe_build( lambda c: build_vision_ocr(c, allow_upload=ctx.allow_upload), vision_cfg, "VisionOCR", ) if vision is not None: _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value) try: ocr_text = await vision.recognize(ctx.path) ocr_status = ProcessStatus.DONE.value except Exception as exc: # noqa: BLE001 logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc) if ocr_status != ProcessStatus.DONE.value: ocr_status = ProcessStatus.FAILED.value return ocr_text, ocr_status def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]: """决定视觉识文用哪套配置:优先 OCR 区的 vision,否则 VLM 区。""" if ctx.ocr_cfg and ctx.ocr_cfg.type == "vision": return ctx.ocr_cfg if ctx.recognition_mode == "vision" or ctx.recognition_mode == "hybrid": return ctx.vlm_cfg return ctx.vlm_cfg def _mark_status( screenshot_id: int, ocr: Optional[str] = None, ai: Optional[str] = None, ) -> None: """短事务:把截图标记为 running,方便前端看到进度。""" if ocr is None and ai is None: return with session_scope() as session: shot = session.get(Screenshot, screenshot_id) if shot is None: return if ocr is not None: shot.ocr_status = ocr if ai is not None: shot.ai_status = ai def _persist_missing(screenshot_id: int) -> None: """短事务:标记文件已丢失。""" with session_scope() as session: shot = session.get(Screenshot, screenshot_id) if shot is None: return shot.ocr_status = ProcessStatus.FAILED.value shot.ai_status = ProcessStatus.FAILED.value meta = _get_or_create_meta(session, screenshot_id) meta.ai_summary = "(文件丢失)" def _persist_result( screenshot_id: int, ocr_text: str, ocr_status: str, ai_status: str, vlm_result: Optional[VLMResult], ) -> None: """短事务:把 OCR/VLM 结果写回 DB,包括 meta/tags/category/todos。""" with session_scope() as session: shot = session.get(Screenshot, screenshot_id) if shot is None: return shot.ocr_status = ocr_status shot.ai_status = ai_status meta = _get_or_create_meta(session, screenshot_id) meta.ocr_text = ocr_text or None if vlm_result is not None: meta.ai_title = vlm_result.title or None meta.ai_summary = vlm_result.summary or None meta.ai_suggestion = vlm_result.suggestion or None meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False) categories = list(session.scalars(select(Category)).all()) category = _resolve_category(session, vlm_result.category, categories) if category is not None: shot.category_id = category.id _sync_tags(session, shot, vlm_result.tags) _sync_todos(session, shot, vlm_result.todos) def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None: """短事务:仅写回 OCR 文本与状态,保留已有 AI 字段。""" with session_scope() as session: shot = session.get(Screenshot, screenshot_id) if shot is None: return shot.ocr_status = ocr_status meta = _get_or_create_meta(session, screenshot_id) meta.ocr_text = ocr_text or None def enqueue_ocr_jobs(*, limit: int = 500) -> int: """为「AI 已成功、OCR 失败」的截图批量创建 OCR 补跑任务。 跳过已有 pending/running 的 ocr 任务,避免重复入队。 """ from app.models.job import Job, JobKind, JobStatus active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value) created = 0 with session_scope() as session: # 已有活跃 OCR 任务的 screenshot_id busy_ids = set( session.scalars( select(Job.screenshot_id).where( Job.kind == JobKind.OCR.value, Job.status.in_(active_status), ) ).all() ) shots = session.scalars( select(Screenshot) .where( Screenshot.ocr_status == ProcessStatus.FAILED.value, Screenshot.ai_status == ProcessStatus.DONE.value, ) .order_by(Screenshot.id.asc()) .limit(limit) ).all() for shot in shots: if shot.id in busy_ids: continue session.add( Job( screenshot_id=shot.id, kind=JobKind.OCR.value, status=JobStatus.PENDING.value, ) ) busy_ids.add(shot.id) created += 1 return created # ---------------- 内部辅助 ---------------- # def _load_provider_config(session, key: str) -> Optional[ProviderConfig]: raw = get_provider_config(session, key) if not raw: return None try: return ProviderConfig(**raw) except Exception as exc: # noqa: BLE001 logger.warning("Provider 配置 %s 解析失败: %s", key, exc) return None def _safe_build(builder, cfg: Optional[ProviderConfig], label: str): if cfg is None: return None try: return builder(cfg) except Exception as exc: # noqa: BLE001 logger.warning("%s Provider 构造失败: %s", label, exc) return None def _is_sensitive(session, image_path: Path) -> bool: """判断文件是否落在某个标记为敏感的监听目录内。""" sensitive_dirs = session.scalars( select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True)) ).all() child = str(image_path) for d in sensitive_dirs: if path_is_under(d, child): return True return False def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta: meta = session.get(ScreenshotMeta, screenshot_id) if meta is None: meta = ScreenshotMeta(screenshot_id=screenshot_id) session.add(meta) session.flush() return meta def ensure_default_categories() -> None: """对外暴露:启动时 seed 默认分类。""" with session_scope() as session: _ensure_default_categories(session) def _ensure_default_categories(session) -> list[Category]: """首次运行时灌入默认分类,返回最新列表。""" existing = session.scalars(select(Category)).all() if existing: return list(existing) for item in DEFAULT_CATEGORIES: session.add(Category(**item)) session.flush() return list(session.scalars(select(Category)).all()) def _resolve_category( session, name: str | None, categories: list[Category], ) -> Optional[Category]: if not name: return None normalized = name.strip() for c in categories: if c.name == normalized or c.name in normalized or normalized in c.name: return c new_cat = Category(name=normalized[:64], color=None, prompt_hint=None) session.add(new_cat) session.flush() categories.append(new_cat) return new_cat def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None: """根据 AI 给的标签名同步多对多关系;保留 EXIF 地点标签不被覆盖。""" exif_tags = [t for t in (screenshot.tags or []) if is_exif_location_tag(t.name)] exif_names = {t.name for t in exif_tags} seen: set[str] = set(exif_names) tag_objs: list[Tag] = list(exif_tags) for raw_name in tag_names: name = (raw_name or "").strip()[:64] if not name or name in seen: continue seen.add(name) tag = session.scalar(select(Tag).where(Tag.name == name)) if tag is None: tag = Tag(name=name) session.add(tag) session.flush() tag_objs.append(tag) screenshot.tags = tag_objs def _sync_todos( session, screenshot: Screenshot, todos: list[dict[str, str]], ) -> None: """以 AI 输出覆盖该截图未完成的 todos;保留用户已完成/搁置项。""" existing = list(screenshot.todos) for t in existing: if t.status in (TodoStatus.DONE.value, TodoStatus.DROPPED.value): continue session.delete(t) session.flush() for item in todos: title = (item.get("title") or "").strip() if not title: continue session.add( Todo( screenshot_id=screenshot.id, title=title[:512], note=(item.get("note") or "")[:2000] or None, kind=(item.get("kind") or "待办")[:32], status=TodoStatus.PENDING.value, ) ) session.flush()