5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
485 lines
16 KiB
Python
485 lines
16 KiB
Python
"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
|
||
|
||
设计要点:
|
||
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
|
||
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
|
||
-> 「短事务(写回结果)」。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
from sqlalchemy import select
|
||
|
||
from app.core.db import session_scope
|
||
from app.core.logger import get_logger
|
||
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
|
||
from app.models.category import Category, DEFAULT_CATEGORIES
|
||
from app.models.meta import ScreenshotMeta
|
||
from app.models.screenshot import ProcessStatus, Screenshot
|
||
from app.models.setting import (
|
||
DEFAULT_RECOGNITION_MODE,
|
||
KEY_OCR_PROVIDER,
|
||
KEY_RECOGNITION_MODE,
|
||
KEY_VLM_PROVIDER,
|
||
)
|
||
from app.models.tag import Tag
|
||
from app.models.todo import Todo, TodoStatus
|
||
from app.models.watch_folder import WatchFolder
|
||
from app.providers import (
|
||
RECOGNITION_MODES,
|
||
build_ocr_provider,
|
||
build_vision_ocr,
|
||
build_vlm_provider,
|
||
)
|
||
from app.providers.base import VLMResult
|
||
from app.schemas.common import ProviderConfig
|
||
from app.services.exif_utils import is_exif_location_tag
|
||
from app.services.settings_store import get_provider_config, get_setting
|
||
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
||
|
||
|
||
@dataclass
class _PreparedContext:
    """Plain-data snapshot produced by the short "prepare" transaction.

    It holds only values copied out of the session (no live ORM objects), so
    it can be used while the OCR/VLM network calls run outside any
    transaction.
    """

    # Resolved location of the screenshot image (from path_from_storage).
    path: Path
    # Parsed OCR-section provider config; None when unset or unparsable.
    ocr_cfg: ProviderConfig | None
    # Parsed VLM-section provider config; None when unset or unparsable.
    vlm_cfg: ProviderConfig | None
    # Recognition mode, already checked against RECOGNITION_MODES.
    recognition_mode: str
    # Names of all categories, passed to the VLM as candidates.
    category_names: list[str]
    # False when the image lies under a watch folder flagged sensitive.
    allow_upload: bool
    # Whether the image file was accessible when the context was prepared.
    exists: bool
|
||
|
||
|
||
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Analyze one screenshot by id: text extraction, then VLM, then write-back.

    Scheduled by the worker. The function manages several short transactions
    itself, so no database transaction is held while the OCR/VLM network
    calls are in flight. A deleted screenshot is ignored; a missing image
    file is recorded via _persist_missing and the call returns normally.

    Raises:
        Exception: the VLM provider's exception, re-raised after the failed
            status has been persisted, so the worker can decide on retries.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return  # the screenshot row has been deleted
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return

    # Build providers up front; a failed build is logged and yields None.
    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(
        lambda cfg: build_ocr_provider(cfg, allow_upload=ctx.allow_upload),
        traditional_cfg,
        "OCR",
    )
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")

    # Text extraction phase (runs outside any transaction).
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)

    # VLM phase (runs outside any transaction).
    vlm_result: Optional[VLMResult] = None
    failure: Optional[Exception] = None
    if vlm_provider is None:
        ai_status = ProcessStatus.SKIPPED.value
    else:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            failure = exc
        else:
            ai_status = ProcessStatus.DONE.value

    # Write-back phase (one short transaction).
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )

    if failure is not None:
        raise failure  # let the worker decide whether to retry
|
||
|
||
|
||
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run only OCR / vision text extraction; AI analysis fields are kept.

    Intended for screenshots with ai_status=done but ocr_status=failed.
    A deleted screenshot is ignored. Note that a missing image file is
    recorded via _persist_missing, which also marks ai_status as failed.

    Raises:
        RuntimeError: when the image file is missing, or when text extraction
            still fails; the worker retries according to max_retries.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(
        lambda cfg: build_ocr_provider(cfg, allow_upload=ctx.allow_upload),
        traditional_cfg,
        "OCR",
    )

    text, status = await _extract_text(screenshot_id, ctx, ocr_provider)
    _persist_ocr_only(screenshot_id, text, status)

    if status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
|
||
|
||
|
||
# ---------------- 短事务工具 ---------------- #
|
||
|
||
|
||
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: snapshot everything the analysis needs as plain data.

    Reads both provider configs and the recognition mode (falling back to
    DEFAULT_RECOGNITION_MODE when the stored value is not in
    RECOGNITION_MODES). It also reads the category names, seeding the
    defaults when the table is empty, and checks whether the image lies in a
    sensitive watch folder.

    Returns:
        The prepared context, or None when no screenshot has this id.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return None

        image_path = path_from_storage(shot.path)
        file_ok = is_accessible_file(image_path)

        ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
        vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)

        stored_mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        mode = stored_mode if stored_mode in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE

        names = [category.name for category in _ensure_default_categories(session)]
        sensitive = _is_sensitive(session, image_path)

        return _PreparedContext(
            path=image_path,
            ocr_cfg=ocr_cfg,
            vlm_cfg=vlm_cfg,
            recognition_mode=mode,
            category_names=names,
            allow_upload=not sensitive,
            exists=file_ok,
        )
|
||
|
||
|
||
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
    """Return True when the OCR-section config should drive a traditional OCR engine.

    Only the "ocr" and "hybrid" recognition modes qualify, and only when an
    OCR config exists whose type is not "", "none", "disabled" or "vision".
    Type "vision" is handled separately by the vision text-recognition path.
    """
    if ctx.recognition_mode in ("ocr", "hybrid"):
        cfg = ctx.ocr_cfg
        return cfg is not None and cfg.type not in ("", "none", "disabled", "vision")
    return False
|
||
|
||
|
||
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: traditional OCR, vision AI, or hybrid.

    Steps:
      1. If a traditional OCR provider was built, run it.
      2. Run vision-model text recognition when the mode is "vision", when
         the mode is "hybrid" and step 1 produced no non-blank text, or when
         the mode is "ocr" but the OCR section selects type "vision".
    A successful vision result replaces any text from step 1. A vision
    failure only marks the status failed if step 1 did not already succeed.

    Returns:
        (text, status) where status is a ProcessStatus value; it stays
        ProcessStatus.SKIPPED when neither recognizer ran.
    """
    text = ""
    status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Traditional OCR (Tesseract / Paddle / HTTP).
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            text = await ocr_provider.recognize(ctx.path)
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            status = ProcessStatus.FAILED.value
        else:
            status = ProcessStatus.DONE.value

    # 2) Decide whether vision-model text recognition is needed.
    if mode == "vision":
        use_vision = True
    elif mode == "hybrid":
        # Hybrid falls back to vision only when OCR yielded nothing usable.
        use_vision = not text.strip()
    elif mode == "ocr":
        # The user picked "vision model text recognition" in the OCR section.
        use_vision = bool(ctx.ocr_cfg and ctx.ocr_cfg.type == "vision")
    else:
        use_vision = False

    if not use_vision:
        return text, status

    vision = _safe_build(
        lambda cfg: build_vision_ocr(cfg, allow_upload=ctx.allow_upload),
        _pick_vision_config(ctx),
        "VisionOCR",
    )
    if vision is None:
        return text, status

    _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
    try:
        text = await vision.recognize(ctx.path)
    except Exception as exc:  # noqa: BLE001
        logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
        if status != ProcessStatus.DONE.value:
            status = ProcessStatus.FAILED.value
    else:
        status = ProcessStatus.DONE.value
    return text, status
|
||
|
||
|
||
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
    """Choose the provider config used for vision-model text recognition.

    An OCR-section config of type "vision" takes precedence. Otherwise the
    VLM-section config is used (which may be None), regardless of the
    recognition mode.
    """
    ocr_cfg = ctx.ocr_cfg
    if ocr_cfg and ocr_cfg.type == "vision":
        return ocr_cfg
    # Every recognition mode falls back to the same VLM config, so no
    # per-mode branching is needed here.
    return ctx.vlm_cfg
|
||
|
||
|
||
def _mark_status(
    screenshot_id: int,
    ocr: Optional[str] = None,
    ai: Optional[str] = None,
) -> None:
    """Short transaction: overwrite the OCR and/or AI status of a screenshot.

    Used to flip a stage to "running" so the frontend can show progress.
    Arguments left as None are not touched. If both are None, no transaction
    is opened, and a missing screenshot row is silently ignored.
    """
    changes = {
        column: value
        for column, value in (("ocr_status", ocr), ("ai_status", ai))
        if value is not None
    }
    if not changes:
        return
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        for column, value in changes.items():
            setattr(shot, column, value)
|
||
|
||
|
||
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot's image file is gone.

    Marks both the OCR and AI stages as failed and stores a placeholder AI
    summary on the meta row (created if needed). Does nothing if the
    screenshot row no longer exists.
    """
    failed = ProcessStatus.FAILED.value
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = shot.ai_status = failed
        _get_or_create_meta(session, screenshot_id).ai_summary = "(文件丢失)"
|
||
|
||
|
||
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: write OCR/VLM results back to the database.

    Always stores both statuses and the OCR text (empty text becomes NULL).

    When a VLM result is given, it also:
      - overwrites the AI title, summary and suggestion (empty values become
        NULL) and the raw JSON;
      - resolves or creates the category;
      - re-syncs tags and todos.

    With vlm_result=None the existing AI fields are left unchanged. Does
    nothing if the screenshot row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = ocr_status
        shot.ai_status = ai_status

        meta = _get_or_create_meta(session, screenshot_id)
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return  # no fresh AI output: keep the previous AI fields

        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(session.scalars(select(Category)).all())
        matched = _resolve_category(session, vlm_result.category, known_categories)
        if matched is not None:
            shot.category_id = matched.id

        _sync_tags(session, shot, vlm_result.tags)
        _sync_todos(session, shot, vlm_result.todos)
|
||
|
||
|
||
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store only the OCR text and status.

    AI fields are left untouched and empty text is stored as NULL. Does
    nothing if the screenshot row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is not None:
            shot.ocr_status = ocr_status
            _get_or_create_meta(session, screenshot_id).ocr_text = ocr_text or None
|
||
|
||
|
||
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Queue OCR re-run jobs for screenshots whose AI succeeded but OCR failed.

    Screenshots that already have a pending or running OCR job are skipped so
    they are not queued twice. The exclusion happens in SQL *before*
    ``limit`` is applied, so already-queued screenshots at the front of the
    id ordering cannot use up the whole batch and starve the rest.

    Args:
        limit: Maximum number of jobs to create in this call; values <= 0
            create nothing.

    Returns:
        The number of jobs created.
    """
    # Function-scoped import; presumably avoids an import cycle (not verified here).
    from app.models.job import Job, JobKind, JobStatus

    if limit is not None and limit <= 0:
        # Guard explicitly: SQLite treats a negative LIMIT as "no limit".
        return 0

    active_ocr_job_shots = select(Job.screenshot_id).where(
        Job.kind == JobKind.OCR.value,
        Job.status.in_((JobStatus.PENDING.value, JobStatus.RUNNING.value)),
        # NOT IN against a set containing NULL matches nothing; keep NULLs out.
        Job.screenshot_id.is_not(None),
    )

    with session_scope() as session:
        shot_ids = session.scalars(
            select(Screenshot.id)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
                Screenshot.id.not_in(active_ocr_job_shots),
            )
            .order_by(Screenshot.id.asc())
            .limit(limit)
        ).all()
        session.add_all(
            [
                Job(
                    screenshot_id=shot_id,
                    kind=JobKind.OCR.value,
                    status=JobStatus.PENDING.value,
                )
                for shot_id in shot_ids
            ]
        )
    return len(shot_ids)
|
||
|
||
|
||
# ---------------- 内部辅助 ---------------- #
|
||
|
||
|
||
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Read the provider settings stored under ``key`` and parse them.

    Returns None when nothing (or an empty value) is stored, or when the
    stored data cannot be turned into a ProviderConfig; that parse failure
    is logged rather than raised.
    """
    stored = get_provider_config(session, key)
    if stored:
        try:
            return ProviderConfig(**stored)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
    return None
|
||
|
||
|
||
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
    """Call ``builder(cfg)``, tolerating a missing config and builder errors.

    Returns None without calling the builder when ``cfg`` is None. If the
    builder raises, the error is logged under ``label`` and None is returned.
    Callers treat None as "skip this stage", so a broken provider config
    disables that stage instead of aborting the whole analysis.
    """
    if cfg is not None:
        try:
            return builder(cfg)
        except Exception as exc:  # noqa: BLE001
            logger.warning("%s Provider 构造失败: %s", label, exc)
    return None
|
||
|
||
|
||
def _is_sensitive(session, image_path: Path) -> bool:
    """Return True if ``image_path`` lies under any watch folder flagged is_sensitive."""
    flagged_dirs = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_dirs)
|
||
|
||
|
||
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the ScreenshotMeta row for ``screenshot_id``, creating it if absent.

    A newly created row is added to the session and flushed immediately.
    """
    existing = session.get(ScreenshotMeta, screenshot_id)
    if existing is not None:
        return existing
    created = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(created)
    session.flush()
    return created
|
||
|
||
|
||
def ensure_default_categories() -> None:
    """Public entry point: seed the default categories (used at startup).

    Opens its own transaction; defaults are inserted only when the category
    table is empty.
    """
    with session_scope() as db:
        _ensure_default_categories(db)
|
||
|
||
|
||
def _ensure_default_categories(session) -> list[Category]:
    """Seed DEFAULT_CATEGORIES on first run and return all categories.

    If any category already exists, nothing is inserted and the existing rows
    are returned. Otherwise the defaults are added, flushed, and re-queried.
    """
    current = list(session.scalars(select(Category)).all())
    if current:
        return current
    for spec in DEFAULT_CATEGORIES:
        session.add(Category(**spec))
    session.flush()
    return list(session.scalars(select(Category)).all())
|
||
|
||
|
||
def _resolve_category(
    session,
    name: str | None,
    categories: list[Category],
) -> Optional[Category]:
    """Map a category name suggested by the VLM onto a Category row.

    Matching order:
      1. exact match on the stripped name;
      2. substring match in either direction, first hit in ``categories``
         order (e.g. a suggestion that contains an existing name maps onto it).

    If nothing matches, a new Category is created with its name truncated to
    64 characters. It is added and flushed, then appended to ``categories`` so
    the caller's list stays current.

    Returns:
        The matched or newly created category, or None when ``name`` is
        None, empty, or whitespace-only.
    """
    normalized = (name or "").strip()
    if not normalized:
        # Checked after stripping: a blank name would otherwise substring-match
        # ("" is in every string) the first category in the list.
        return None

    for category in categories:
        if category.name == normalized:
            return category
    for category in categories:
        # Exact matches were tried first, so a shorter category listed earlier
        # (e.g. "工作") no longer captures an exact hit such as "工作记录".
        if category.name in normalized or normalized in category.name:
            return category

    new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
    session.add(new_cat)
    session.flush()
    categories.append(new_cat)
    return new_cat
|
||
|
||
|
||
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided names, keeping EXIF location tags.

    Existing tags accepted by ``is_exif_location_tag`` are kept and come
    first. Each AI name is stripped and truncated to 64 characters. Blank
    names and repeats are skipped, including names already used by the kept
    EXIF tags. Tag rows that do not exist yet are created and flushed on the
    fly.
    """
    kept = [tag for tag in (screenshot.tags or []) if is_exif_location_tag(tag.name)]
    used_names: set[str] = {tag.name for tag in kept}
    new_tags: list[Tag] = list(kept)

    for raw in tag_names:
        cleaned = (raw or "").strip()[:64]
        if not cleaned or cleaned in used_names:
            continue
        used_names.add(cleaned)

        tag = session.scalar(select(Tag).where(Tag.name == cleaned))
        if tag is None:
            tag = Tag(name=cleaned)
            session.add(tag)
            session.flush()
        new_tags.append(tag)

    screenshot.tags = new_tags
|
||
|
||
|
||
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Replace the screenshot's open todos with the AI output.

    Todos with status DONE or DROPPED (items the user completed or shelved)
    are kept; every other existing todo is deleted.

    Each AI item with a non-blank title becomes a new PENDING todo:
      - title is truncated to 512 characters;
      - note is truncated to 2000 characters (empty note becomes NULL);
      - kind is truncated to 32 characters and defaults to "待办".
    """
    preserved = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    for stale in [todo for todo in screenshot.todos if todo.status not in preserved]:
        session.delete(stale)
    session.flush()

    for entry in todos:
        title = (entry.get("title") or "").strip()
        if not title:
            continue
        note = (entry.get("note") or "")[:2000]
        session.add(
            Todo(
                screenshot_id=screenshot.id,
                title=title[:512],
                note=note or None,
                kind=(entry.get("kind") or "待办")[:32],
                status=TodoStatus.PENDING.value,
            )
        )
    session.flush()
|