Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,484 @@
|
||||
"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
|
||||
|
||||
设计要点:
|
||||
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
|
||||
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
|
||||
-> 「短事务(写回结果)」。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
|
||||
from app.models.category import Category, DEFAULT_CATEGORIES
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.setting import (
|
||||
DEFAULT_RECOGNITION_MODE,
|
||||
KEY_OCR_PROVIDER,
|
||||
KEY_RECOGNITION_MODE,
|
||||
KEY_VLM_PROVIDER,
|
||||
)
|
||||
from app.models.tag import Tag
|
||||
from app.models.todo import Todo, TodoStatus
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.providers import (
|
||||
RECOGNITION_MODES,
|
||||
build_ocr_provider,
|
||||
build_vision_ocr,
|
||||
build_vlm_provider,
|
||||
)
|
||||
from app.providers.base import VLMResult
|
||||
from app.schemas.common import ProviderConfig
|
||||
from app.services.exif_utils import is_exif_location_tag
|
||||
from app.services.settings_store import get_provider_config, get_setting
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PreparedContext:
|
||||
"""从短事务中导出的、不依赖 ORM 会话的纯数据。"""
|
||||
|
||||
path: Path
|
||||
ocr_cfg: Optional[ProviderConfig]
|
||||
vlm_cfg: Optional[ProviderConfig]
|
||||
recognition_mode: str
|
||||
category_names: list[str]
|
||||
allow_upload: bool
|
||||
exists: bool
|
||||
|
||||
|
||||
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Public entry point: run the full OCR + VLM analysis for one screenshot.

    Scheduled by the worker. The function manages several short transactions
    itself and performs the OCR/VLM calls outside of any transaction.

    Args:
        screenshot_id: Primary key of the screenshot to analyse.

    Raises:
        Exception: The VLM provider's error is re-raised *after* the partial
            result has been persisted, so the worker can decide on a retry.
    """
    prepared = _prepare(screenshot_id)
    if prepared is None:
        # The screenshot row has been deleted in the meantime.
        return
    if not prepared.exists:
        _persist_missing(screenshot_id)
        return

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=prepared.allow_upload)

    traditional_cfg = prepared.ocr_cfg if _use_traditional_ocr(prepared) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")
    vlm_provider = _safe_build(build_vlm_provider, prepared.vlm_cfg, "VLM")

    # ---- Text recognition stage (outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, prepared, ocr_provider)

    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    vlm_error: Optional[Exception] = None
    if vlm_provider is None:
        ai_status = ProcessStatus.SKIPPED.value
    else:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=prepared.path,
                ocr_text=ocr_text,
                categories=prepared.category_names,
                allow_upload=prepared.allow_upload,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            vlm_error = exc
        else:
            ai_status = ProcessStatus.DONE.value

    # ---- Write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )

    if vlm_error is not None:
        # Surface the failure so the worker can decide whether to retry.
        raise vlm_error
|
||||
|
||||
|
||||
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run only OCR / vision text recognition, leaving AI results untouched.

    Intended for screenshots whose ai_status is done but whose ocr_status is
    failed.

    Args:
        screenshot_id: Primary key of the screenshot to re-process.

    Raises:
        RuntimeError: When the image file is missing, or when recognition
            still fails; the worker retries according to its max_retries.
    """
    prepared = _prepare(screenshot_id)
    if prepared is None:
        return
    if not prepared.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=prepared.allow_upload)

    traditional_cfg = prepared.ocr_cfg if _use_traditional_ocr(prepared) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")

    text, status = await _extract_text(screenshot_id, prepared, ocr_provider)
    _persist_ocr_only(screenshot_id, text, status)

    if status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
|
||||
|
||||
|
||||
# ---------------- 短事务工具 ---------------- #
|
||||
|
||||
|
||||
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: gather everything the analysis needs up front.

    Reads the provider configurations, the recognition mode, the category
    list and the sensitive-folder verdict, and packs them into a
    session-independent context.

    Returns:
        The prepared context, or None when the screenshot row does not exist.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return None

        image_path = path_from_storage(shot.path)
        file_present = is_accessible_file(image_path)

        ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
        vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)

        # Fall back to the default when the stored mode is not recognised.
        configured_mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        mode = (
            configured_mode
            if configured_mode in RECOGNITION_MODES
            else DEFAULT_RECOGNITION_MODE
        )

        names = [category.name for category in _ensure_default_categories(session)]
        upload_ok = not _is_sensitive(session, image_path)

        return _PreparedContext(
            path=image_path,
            ocr_cfg=ocr_cfg,
            vlm_cfg=vlm_cfg,
            recognition_mode=mode,
            category_names=names,
            allow_upload=upload_ok,
            exists=file_present,
        )
|
||||
|
||||
|
||||
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
|
||||
"""混合/传统模式下是否启用 OCR 区配置(排除 vision 类型,vision 单独处理)。"""
|
||||
if ctx.recognition_mode not in ("ocr", "hybrid"):
|
||||
return False
|
||||
if ctx.ocr_cfg is None or ctx.ocr_cfg.type in ("", "none", "disabled", "vision"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: classic OCR, vision AI, or hybrid.

    Args:
        screenshot_id: Screenshot being processed (for status updates and logs).
        ctx: Prepared, session-independent context.
        ocr_provider: Classic OCR provider, or None to skip that stage.

    Returns:
        A ``(text, status)`` pair where ``status`` is a ProcessStatus value.
    """
    text = ""
    status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Classic OCR (Tesseract / Paddle / HTTP).
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            text = await ocr_provider.recognize(ctx.path)
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            status = ProcessStatus.FAILED.value
        else:
            status = ProcessStatus.DONE.value

    # 2) Decide whether vision-based text recognition is needed.
    if mode == "vision":
        wants_vision = True
    elif mode == "hybrid":
        # Hybrid only falls back to vision when classic OCR produced nothing.
        wants_vision = not text.strip()
    elif mode == "ocr":
        # The user picked "vision model" inside the OCR section.
        wants_vision = bool(ctx.ocr_cfg) and ctx.ocr_cfg.type == "vision"
    else:
        wants_vision = False

    if not wants_vision:
        return text, status

    vision = _safe_build(
        lambda c: build_vision_ocr(c, allow_upload=ctx.allow_upload),
        _pick_vision_config(ctx),
        "VisionOCR",
    )
    if vision is None:
        return text, status

    _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
    try:
        text = await vision.recognize(ctx.path)
    except Exception as exc:  # noqa: BLE001
        logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
        # Keep an earlier successful classic-OCR status; otherwise mark failed.
        if status != ProcessStatus.DONE.value:
            status = ProcessStatus.FAILED.value
    else:
        status = ProcessStatus.DONE.value

    return text, status
|
||||
|
||||
|
||||
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
|
||||
"""决定视觉识文用哪套配置:优先 OCR 区的 vision,否则 VLM 区。"""
|
||||
if ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
|
||||
return ctx.ocr_cfg
|
||||
if ctx.recognition_mode == "vision" or ctx.recognition_mode == "hybrid":
|
||||
return ctx.vlm_cfg
|
||||
return ctx.vlm_cfg
|
||||
|
||||
|
||||
def _mark_status(
    screenshot_id: int,
    ocr: Optional[str] = None,
    ai: Optional[str] = None,
) -> None:
    """Short transaction: update a screenshot's progress flags for the frontend.

    Args:
        screenshot_id: Row to update; silently ignored when it no longer exists.
        ocr: New ocr_status value, or None to leave it unchanged.
        ai: New ai_status value, or None to leave it unchanged.
    """
    changes = {}
    if ocr is not None:
        changes["ocr_status"] = ocr
    if ai is not None:
        changes["ai_status"] = ai
    if not changes:
        # Nothing to write: avoid opening a transaction at all.
        return

    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        for attribute, value in changes.items():
            setattr(shot, attribute, value)
|
||||
|
||||
|
||||
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot file has gone missing.

    Both status columns are set to failed and the AI summary is replaced by a
    fixed "file missing" marker. Does nothing when the row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is not None:
            failed = ProcessStatus.FAILED.value
            shot.ocr_status = failed
            shot.ai_status = failed
            _get_or_create_meta(session, screenshot_id).ai_summary = "(文件丢失)"
|
||||
|
||||
|
||||
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: write OCR/VLM results back to the database.

    Always updates both status columns and the stored OCR text. When a VLM
    result is present, the AI fields, category, tags and todos are updated
    as well. Does nothing when the screenshot row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return

        shot.ocr_status = ocr_status
        shot.ai_status = ai_status

        meta = _get_or_create_meta(session, screenshot_id)
        # Empty text is stored as NULL rather than as an empty string.
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return

        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(session.scalars(select(Category)).all())
        matched = _resolve_category(session, vlm_result.category, known_categories)
        if matched is not None:
            shot.category_id = matched.id

        _sync_tags(session, shot, vlm_result.tags)
        _sync_todos(session, shot, vlm_result.todos)
|
||||
|
||||
|
||||
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store only the OCR text and status.

    Existing AI fields are left as they are. Does nothing when the screenshot
    row no longer exists.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is not None:
            shot.ocr_status = ocr_status
            # Empty text is stored as NULL rather than as an empty string.
            _get_or_create_meta(session, screenshot_id).ocr_text = ocr_text or None
|
||||
|
||||
|
||||
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Create OCR re-run jobs for screenshots whose AI succeeded but OCR failed.

    Screenshots that already have a pending or running OCR job are skipped so
    that no duplicate job is queued.

    The limit counts *created* jobs. Previously it was applied to the
    candidate query before busy screenshots were filtered out, so a run of
    already-queued screenshots at the head of the table could prevent any
    later candidate from ever being enqueued.

    Args:
        limit: Maximum number of jobs to create in this call; values <= 0
            create nothing.

    Returns:
        The number of jobs created.
    """
    from app.models.job import Job, JobKind, JobStatus

    active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    created = 0
    with session_scope() as session:
        # Screenshot ids that already have an active OCR job.
        busy_ids = set(
            session.scalars(
                select(Job.screenshot_id).where(
                    Job.kind == JobKind.OCR.value,
                    Job.status.in_(active_status),
                )
            ).all()
        )
        # Only ids are needed, so avoid loading full Screenshot entities.
        candidate_ids = session.scalars(
            select(Screenshot.id)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
            )
            .order_by(Screenshot.id.asc())
        ).all()
        for shot_id in candidate_ids:
            if created >= limit:
                break
            if shot_id in busy_ids:
                continue
            session.add(
                Job(
                    screenshot_id=shot_id,
                    kind=JobKind.OCR.value,
                    status=JobStatus.PENDING.value,
                )
            )
            busy_ids.add(shot_id)
            created += 1
    return created
|
||||
|
||||
|
||||
# ---------------- 内部辅助 ---------------- #
|
||||
|
||||
|
||||
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Load and validate the provider configuration stored under ``key``.

    Returns:
        The parsed configuration, or None when nothing is stored or the stored
        value cannot be turned into a ProviderConfig (a warning is logged).
    """
    stored = get_provider_config(session, key)
    if stored:
        try:
            return ProviderConfig(**stored)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
    return None
|
||||
|
||||
|
||||
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
|
||||
if cfg is None:
|
||||
return None
|
||||
try:
|
||||
return builder(cfg)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s Provider 构造失败: %s", label, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _is_sensitive(session, image_path: Path) -> bool:
    """Tell whether the file lies inside any watch folder flagged as sensitive."""
    flagged_folders = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_folders)
|
||||
|
||||
|
||||
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the meta row for a screenshot, creating and flushing it if absent."""
    found = session.get(ScreenshotMeta, screenshot_id)
    if found is not None:
        return found

    created = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(created)
    session.flush()
    return created
|
||||
|
||||
|
||||
def ensure_default_categories() -> None:
    """Public helper: seed the default categories, e.g. at application start-up.

    Opens its own short transaction and delegates to the session-based
    implementation; the returned category list is discarded.
    """
    with session_scope() as db:
        _ensure_default_categories(db)
|
||||
|
||||
|
||||
def _ensure_default_categories(session) -> list[Category]:
    """Seed DEFAULT_CATEGORIES when the table is empty and return all categories.

    When at least one category already exists nothing is inserted and the
    existing rows are returned unchanged.
    """
    current = list(session.scalars(select(Category)).all())
    if current:
        return current

    session.add_all(Category(**spec) for spec in DEFAULT_CATEGORIES)
    session.flush()
    return list(session.scalars(select(Category)).all())
|
||||
|
||||
|
||||
def _resolve_category(
|
||||
session,
|
||||
name: str | None,
|
||||
categories: list[Category],
|
||||
) -> Optional[Category]:
|
||||
if not name:
|
||||
return None
|
||||
normalized = name.strip()
|
||||
for c in categories:
|
||||
if c.name == normalized or c.name in normalized or normalized in c.name:
|
||||
return c
|
||||
new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
|
||||
session.add(new_cat)
|
||||
session.flush()
|
||||
categories.append(new_cat)
|
||||
return new_cat
|
||||
|
||||
|
||||
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided ones.

    Tags recognised as EXIF location tags are kept so that a re-analysis does
    not wipe them. Names are trimmed, cut to 64 characters and de-duplicated;
    tags that do not exist yet are created and flushed.
    """
    preserved = [
        tag for tag in (screenshot.tags or []) if is_exif_location_tag(tag.name)
    ]
    used_names: set[str] = {tag.name for tag in preserved}
    result: list[Tag] = list(preserved)

    for candidate in tag_names:
        label = (candidate or "").strip()[:64]
        if label and label not in used_names:
            used_names.add(label)
            tag = session.scalar(select(Tag).where(Tag.name == label))
            if tag is None:
                tag = Tag(name=label)
                session.add(tag)
                session.flush()
            result.append(tag)

    screenshot.tags = result
|
||||
|
||||
|
||||
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Overwrite the screenshot's open todos with the AI output.

    Todos the user already completed or dropped are kept; every other
    existing todo is deleted. Each AI item with a non-blank title becomes a
    new pending todo (title, note and kind are length-limited).
    """
    protected = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    stale = [todo for todo in list(screenshot.todos) if todo.status not in protected]
    for todo in stale:
        session.delete(todo)
    session.flush()

    for entry in todos:
        title = (entry.get("title") or "").strip()
        if title:
            session.add(
                Todo(
                    screenshot_id=screenshot.id,
                    title=title[:512],
                    note=(entry.get("note") or "")[:2000] or None,
                    kind=(entry.get("kind") or "待办")[:32],
                    status=TodoStatus.PENDING.value,
                )
            )
    session.flush()
|
||||
@@ -0,0 +1,107 @@
|
||||
"""从图片 EXIF 提取拍摄时间与 GPS 地点标签。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import ExifTags, Image
|
||||
|
||||
from app.core.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# EXIF 地点类标签前缀,重分析时保留不被 AI 覆盖
|
||||
EXIF_TAG_PREFIX = "地点:"
|
||||
|
||||
|
||||
def _ratio_to_float(value) -> float:
|
||||
"""EXIF 有理数 → float。"""
|
||||
if isinstance(value, tuple) and len(value) == 2:
|
||||
num, den = value
|
||||
return float(num) / float(den) if den else 0.0
|
||||
if isinstance(value, Fraction):
|
||||
return float(value)
|
||||
return float(value)
|
||||
|
||||
|
||||
def _dms_to_decimal(dms: tuple, ref: str) -> Optional[float]:
|
||||
"""度分秒 → 十进制度。"""
|
||||
try:
|
||||
deg, minutes, seconds = dms
|
||||
decimal = _ratio_to_float(deg) + _ratio_to_float(minutes) / 60 + _ratio_to_float(seconds) / 3600
|
||||
if ref in ("S", "W"):
|
||||
decimal = -decimal
|
||||
return round(decimal, 6)
|
||||
except (TypeError, ValueError, ZeroDivisionError):
|
||||
return None
|
||||
|
||||
|
||||
def extract_image_metadata(path: Path) -> tuple[Optional[datetime], list[str]]:
    """Read EXIF data and return ``(capture time, location tags)``.

    Any error while opening or parsing the image is logged at debug level and
    whatever was collected up to that point is returned.

    Args:
        path: Image file to inspect.

    Returns:
        The capture time (None when unavailable) and a list of location tag
        strings, each starting with EXIF_TAG_PREFIX.
    """
    captured: Optional[datetime] = None
    location_tags: list[str] = []

    try:
        with Image.open(path) as img:
            exif = img.getexif()
            if not exif:
                return None, []

            has_ifd = hasattr(exif, "get_ifd")

            # Capture time, preferring DateTimeOriginal. DateTimeOriginal and
            # DateTimeDigitized are stored in the Exif sub-IFD rather than in
            # the base IFD returned by getexif(), so look there first and
            # fall back to the base IFD (where DateTime lives).
            exif_ifd = (exif.get_ifd(ExifTags.IFD.Exif) or {}) if has_ifd else {}
            for key in (36867, 36868, 306):  # DateTimeOriginal / DateTimeDigitized / DateTime
                raw = exif_ifd.get(key) or exif.get(key)
                if raw:
                    captured = _parse_exif_datetime(str(raw))
                    if captured:
                        break

            # GPS -> location tag.
            gps_ifd = exif.get_ifd(ExifTags.IFD.GPSInfo) if has_ifd else None
            if gps_ifd:
                lat = _dms_to_decimal(
                    gps_ifd.get(2),
                    gps_ifd.get(1, "N"),
                )
                lon = _dms_to_decimal(
                    gps_ifd.get(4),
                    gps_ifd.get(3, "E"),
                )
                if lat is not None and lon is not None:
                    location_tags.append(f"{EXIF_TAG_PREFIX}{lat},{lon}")

            # Some devices write a human-readable place name into
            # ImageDescription / XPComment.
            for key, val in exif.items():
                tag_name = ExifTags.TAGS.get(key, "")
                if tag_name in ("ImageDescription", "XPComment") and val:
                    if tag_name == "XPComment" and isinstance(val, bytes):
                        # Windows XP* tags are UTF-16LE byte strings; str()
                        # on them would yield the "b'...'" repr instead.
                        val = val.decode("utf-16-le", errors="ignore").rstrip("\x00")
                    text = str(val).strip()[:64]
                    if text and _looks_like_place(text):
                        location_tags.append(f"{EXIF_TAG_PREFIX}{text}")

    except Exception as exc:  # noqa: BLE001
        logger.debug("读取 EXIF 失败 %s: %s", path.name, exc)

    return captured, location_tags
|
||||
|
||||
|
||||
def _parse_exif_datetime(raw: str) -> Optional[datetime]:
|
||||
"""解析 EXIF 时间字符串。"""
|
||||
for fmt in ("%Y:%m:%d %H:%M:%S", "%Y-%m-%d %H:%M:%S"):
|
||||
try:
|
||||
return datetime.strptime(raw.strip(), fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _looks_like_place(text: str) -> bool:
|
||||
"""粗判字符串是否像地名(含中文或常见地址关键词)。"""
|
||||
keywords = ("市", "省", "区", "县", "镇", "村", "路", "街", "国", "GPS")
|
||||
return any(k in text for k in keywords) or any("\u4e00" <= c <= "\u9fff" for c in text)
|
||||
|
||||
|
||||
def is_exif_location_tag(name: str) -> bool:
    """Tell whether a tag name carries the EXIF location prefix."""
    marker = EXIF_TAG_PREFIX
    return name.startswith(marker)
|
||||
@@ -0,0 +1,147 @@
|
||||
"""将磁盘上的截图文件入库 + 排队分析。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.path_utils import (
|
||||
is_accessible_dir,
|
||||
is_accessible_file,
|
||||
path_from_storage,
|
||||
path_to_storage,
|
||||
)
|
||||
from app.core.logger import get_logger
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.services.exif_utils import extract_image_metadata
|
||||
from app.services.thumbnail import file_hash, generate_thumbnail, is_supported
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def ingest_path(session: Session, path: Path) -> Optional[Screenshot]:
    """Ingest a single file and queue it for analysis.

    A file whose content hash is already known is not ingested again; the
    existing row's path is updated when the file was renamed or moved.

    Args:
        session: Active ORM session; the caller is responsible for committing.
        path: File to ingest.

    Returns:
        The new or already-existing Screenshot, or None when the file is not
        accessible, not a supported type, or cannot be read.
    """
    if not is_accessible_file(path) or not path.is_file():
        return None
    if not is_supported(path):
        return None

    stored_path = path_to_storage(path)

    try:
        digest = file_hash(path)
    except OSError as exc:
        logger.warning("无法读取文件 %s: %s", path, exc)
        return None

    existing = session.scalar(select(Screenshot).where(Screenshot.file_hash == digest))
    if existing:
        # Same content renamed/moved: just update the stored path.
        if existing.path != stored_path:
            existing.path = stored_path
            session.flush()
        return existing

    try:
        with Image.open(path) as img:
            width, height = img.size
    except Exception as exc:  # noqa: BLE001
        logger.warning("无法读取图片尺寸 %s: %s", path, exc)
        width, height = 0, 0

    # The file may disappear between hashing and stat (e.g. on a network
    # share); treat that like any other unreadable file instead of letting
    # the OSError abort a whole directory scan.
    try:
        stat = path.stat()
    except OSError as exc:
        logger.warning("无法读取文件 %s: %s", path, exc)
        return None

    # Default to the file's mtime; an EXIF capture time takes precedence.
    captured_at = datetime.fromtimestamp(stat.st_mtime)
    exif_time, location_tags = extract_image_metadata(path)
    if exif_time is not None:
        captured_at = exif_time

    try:
        thumb = generate_thumbnail(path)
        thumb_path = thumb.as_posix()
    except Exception as exc:  # noqa: BLE001
        logger.warning("生成缩略图失败 %s: %s", path, exc)
        thumb_path = None

    shot = Screenshot(
        path=stored_path,
        file_hash=digest,
        width=width,
        height=height,
        size=stat.st_size,
        captured_at=captured_at,
        thumb_path=thumb_path,
        ocr_status=ProcessStatus.PENDING.value,
        ai_status=ProcessStatus.PENDING.value,
    )
    session.add(shot)
    # Flush so that shot.id is available for the tags and the job below.
    session.flush()

    if location_tags:
        _attach_location_tags(session, shot, location_tags)

    job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
    session.add(job)
    logger.info("入库 #%d %s", shot.id, path.name)
    return shot
|
||||
|
||||
|
||||
def _attach_location_tags(session: Session, shot: Screenshot, tag_names: list[str]) -> None:
    """Attach EXIF location tags to a freshly ingested screenshot.

    Names are trimmed and cut to 64 characters; blank names are ignored and
    duplicates are attached only once (the same rule the AI tag sync applies),
    so identical text from two EXIF fields cannot link one tag twice. Missing
    tags are created and flushed. The screenshot's tag list is replaced.
    """
    seen: set[str] = set()
    tag_objs: list[Tag] = []
    for raw in tag_names:
        name = (raw or "").strip()[:64]
        if not name or name in seen:
            continue
        seen.add(name)
        tag = session.scalar(select(Tag).where(Tag.name == name))
        if tag is None:
            tag = Tag(name=name)
            session.add(tag)
            session.flush()
        tag_objs.append(tag)
    shot.tags = tag_objs
|
||||
|
||||
|
||||
def ingest_directory(
    session: Session,
    root: Path | str,
    recursive: bool = True,
) -> tuple[int, int]:
    """Walk a directory and ingest every supported file.

    String roots are converted with ``path_from_storage`` first. Work is
    committed every 50 counted files, and once more at the end, to avoid one
    huge transaction.

    Args:
        session: Active ORM session; committed by this function.
        root: Directory to scan.
        recursive: Descend into sub-directories when True.

    Returns:
        ``(added, skipped)``; ``(0, 0)`` when the directory is not accessible.
        A file counts as added when no screenshot with the same stored path
        existed before ingestion.
    """
    root_p = path_from_storage(str(root)) if isinstance(root, str) else root
    if not is_accessible_dir(root_p):
        return 0, 0

    entries: Iterable[Path] = root_p.rglob("*") if recursive else root_p.iterdir()

    added = 0
    skipped = 0
    for path in entries:
        if not path.is_file() or not is_supported(path):
            continue

        stored = path_to_storage(path)
        known_before = session.scalar(
            select(Screenshot.id).where(Screenshot.path == stored)
        )
        if ingest_path(session, path) is None:
            skipped += 1
            continue

        if known_before is None:
            added += 1
        else:
            skipped += 1

        # Commit in batches to avoid one giant transaction.
        if (added + skipped) % 50 == 0:
            session.commit()

    session.commit()
    return added, skipped
|
||||
@@ -0,0 +1,207 @@
|
||||
"""Provider 连通性测试:OCR / 视觉 AI。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.providers import build_ocr_provider
|
||||
from app.schemas.common import ProviderConfig
|
||||
|
||||
|
||||
class ProviderTestError(Exception):
    """Provider test failure whose message is meant to be shown to the user."""
|
||||
|
||||
|
||||
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Check connectivity of an OCR or VLM provider.

    Never raises: every failure is folded into the returned dict.

    Args:
        key: Setting key selecting the provider family (KEY_OCR or KEY_VLM).
        cfg: Provider configuration to test.

    Returns:
        A dict with ``ok``, ``message``, ``detail`` and ``latency_ms``.
    """
    started = time.perf_counter()

    def elapsed_ms() -> int:
        return int((time.perf_counter() - started) * 1000)

    try:
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")

        if key == KEY_OCR:
            message, detail = await _test_ocr(cfg)
        elif key == KEY_VLM:
            message, detail = await _test_vlm(cfg)
        else:
            raise ProviderTestError(f"未知配置键: {key}")
    except ProviderTestError as exc:
        # Expected failure: the message is already user-readable.
        return {
            "ok": False,
            "message": str(exc),
            "detail": None,
            "latency_ms": elapsed_ms(),
        }
    except Exception as exc:  # noqa: BLE001
        return {
            "ok": False,
            "message": f"测试失败: {exc}",
            "detail": repr(exc),
            "latency_ms": elapsed_ms(),
        }

    return {"ok": True, "message": message, "detail": detail, "latency_ms": elapsed_ms()}
|
||||
|
||||
|
||||
KEY_OCR = "ocr_provider"
|
||||
KEY_VLM = "vlm_provider"
|
||||
|
||||
# 1x1 白图,用于 HTTP OCR / 视觉测试
|
||||
_TINY_PNG_B64 = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
|
||||
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch an OCR connectivity test according to ``cfg.type``.

    Returns:
        A ``(message, detail)`` pair from the type-specific test.

    Raises:
        ProviderTestError: For an unsupported OCR type.
    """
    kind = cfg.type
    if kind == "vision":
        return await _test_openai_compat(cfg, label="视觉 OCR")

    handlers = {
        "tesseract": _test_tesseract,
        "paddleocr": _test_paddle,
        "http": _test_http_ocr,
    }
    handler = handlers.get(kind)
    if handler is None:
        raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
    return await handler(cfg)
|
||||
|
||||
|
||||
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Run the connectivity test for a VLM provider.

    All supported VLM types are probed through the OpenAI-compatible test.

    Raises:
        ProviderTestError: For an unsupported VLM type.
    """
    supported = ("openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision")
    if cfg.type not in supported:
        raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
    return await _test_openai_compat(cfg, label="视觉 AI")
|
||||
|
||||
|
||||
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that a Tesseract provider can be built and report its version.

    The version lookup runs in a worker thread. When ``cfg.extra["cmd"]`` is
    set, it is assigned to pytesseract's module-level ``tesseract_cmd``.

    Raises:
        ProviderTestError: When the provider cannot be constructed.
    """
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _query_version() -> str:
        import pytesseract

        custom_cmd = cfg.extra.get("cmd")
        if custom_cmd:
            pytesseract.pytesseract.tesseract_cmd = cfg.extra["cmd"]
        return str(pytesseract.get_tesseract_version())

    version = await asyncio.to_thread(_query_version)
    message = f"Tesseract 可用,版本 {version}"
    detail = f"lang={cfg.extra.get('lang', 'chi_sim+eng')}"
    return message, detail
|
||||
|
||||
|
||||
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that PaddleOCR is importable and that a provider can be built.

    The import check runs in a worker thread.

    Raises:
        ProviderTestError: When the package is not installed or the provider
            cannot be constructed.
    """

    def _probe_import() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc
        return "PaddleOCR 模块已安装"

    detail = await asyncio.to_thread(_probe_import)
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", detail
|
||||
|
||||
|
||||
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Probe an HTTP OCR endpoint by sending it a tiny PNG.

    The sample image is written to a temporary file that is removed again
    afterwards, whether or not the call succeeded.

    Raises:
        ProviderTestError: When no URL is configured or the provider cannot
            be constructed.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    # Write the tiny PNG to a temporary file, then call the provider with it.
    from pathlib import Path
    import tempfile

    png_bytes = base64.b64decode(_TINY_PNG_B64)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(png_bytes)
        sample = Path(handle.name)

    try:
        recognized = await provider.recognize(sample)
    finally:
        # Best-effort cleanup: a leftover temp file must not fail the test.
        try:
            sample.unlink(missing_ok=True)
        except OSError:
            pass

    preview = (recognized or "").strip()[:80] or "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {preview}"
|
||||
|
||||
|
||||
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
    """Probe an OpenAI-compatible endpoint.

    First tries ``GET /models``; a 200 response with a parsable body counts
    as success (with a note when the configured model is not listed). On any
    other outcome, including a 200 whose body is not valid JSON, a minimal
    chat completion is attempted instead.

    Args:
        cfg: Provider configuration (base URL, API key, model, extra.timeout).
        label: Human-readable provider label used in the messages.

    Returns:
        A ``(message, detail)`` pair describing the successful probe.

    Raises:
        ProviderTestError: When the chat probe returns an error status, the
            server cannot be reached, or the response has an unexpected shape.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))

    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # 1) Try /models (Ollama and other OpenAI-compatible servers).
    models_url = f"{base_url}/models"
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.get(models_url, headers=headers)
            if resp.status_code == 200:
                ids = _extract_model_ids(resp.json())
                if model and ids and model not in ids:
                    return (
                        f"{label} 服务可达",
                        f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
                    )
                return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"
        except (httpx.HTTPError, ValueError):
            # Transport error or non-JSON body (ValueError covers the JSON
            # decode error): fall through to the chat probe.
            pass

        # 2) Minimal chat completion as a liveness probe.
        chat_url = f"{base_url}/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "请只回复 OK"}],
            "max_tokens": 16,
            "temperature": 0,
        }
        try:
            resp = await client.post(chat_url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            content = data["choices"][0]["message"]["content"]
            return f"{label} 对话成功", f"模型 {model} 回复: {str(content).strip()[:60]}"
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:200]
            raise ProviderTestError(
                f"API 返回 {exc.response.status_code}: {body}"
            ) from exc
        except httpx.HTTPError as exc:
            raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # The server answered, but not with a chat-completion payload.
            raise ProviderTestError(f"{label} 响应格式异常: {exc!r}") from exc
|
||||
|
||||
|
||||
def _extract_model_ids(data: Any) -> list[str]:
|
||||
"""从 /models 响应中提取 model id 列表。"""
|
||||
if not isinstance(data, dict):
|
||||
return []
|
||||
items = data.get("data") or data.get("models") or []
|
||||
ids: list[str] = []
|
||||
if isinstance(items, list):
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
mid = item.get("id") or item.get("name") or item.get("model")
|
||||
if mid:
|
||||
ids.append(str(mid))
|
||||
elif isinstance(item, str):
|
||||
ids.append(item)
|
||||
return ids
|
||||
|
||||
|
||||
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
    """Fill in the saved API key when the tested configuration has none.

    Args:
        cfg: Configuration submitted for testing.
        existing: Previously saved configuration, if any.

    Returns:
        A new ProviderConfig; its ``api_key`` comes from ``existing`` when
        ``cfg`` carries an empty one and ``existing`` is a dict.
    """
    merged = cfg.model_dump()
    key_missing = not merged.get("api_key")
    if key_missing and isinstance(existing, dict):
        merged["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**merged)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""截图列表搜索:FTS + 子串模糊(兼容中文标签/标题)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import Screenshot
|
||||
from app.models.tag import Tag
|
||||
|
||||
|
||||
def fts_query_string(raw: str) -> str:
    """Turn user input into an FTS5 query string with prefix matching.

    The input is split on whitespace; double quotes are removed from each
    word, which is then wrapped as ``"word"*``. Input without any word is
    returned unchanged.
    """
    words = raw.replace("\n", " ").split()
    if not words:
        return raw

    unquoted = (word.replace('"', "").strip() for word in words)
    return " ".join(f'"{word}"*' for word in unquoted if word)
|
||||
|
||||
|
||||
def collect_search_ids(session: Session, q: str, *, limit: int = 5000) -> set[int]:
    """Collect screenshot ids matching ``q`` via FTS5 plus substring search.

    Three sources are unioned:

    1. the ``screenshots_fts`` FTS5 index (best effort; a failing FTS query
       is logged at debug level and ignored);
    2. a case-insensitive substring match on the OCR / AI text columns of
       ``ScreenshotMeta``;
    3. a case-insensitive substring match on tag names.

    Args:
        session: Open SQLAlchemy session.
        q: Raw user query. Surrounding whitespace is ignored.
        limit: Maximum number of rows fetched from *each* source.

    Returns:
        The set of matching screenshot ids; empty when ``q`` is blank.
    """
    import logging

    q = q.strip()
    if not q:
        return set()

    ids: set[int] = set()
    # Escape LIKE metacharacters so that a literal "%" or "_" typed by the
    # user is matched as text instead of acting as a wildcard.
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    like = f"%{escaped}%"

    # 1) FTS5 full-text index.
    try:
        fts_sql = text(
            "SELECT rowid FROM screenshots_fts WHERE screenshots_fts MATCH :q LIMIT :lim"
        )
        rows = session.execute(fts_sql, {"q": fts_query_string(q), "lim": limit}).fetchall()
        ids.update(int(row[0]) for row in rows)
    except Exception:  # noqa: BLE001
        # Deliberately best effort: the substring searches below still run.
        # Leave a trace so a broken index or query is diagnosable.
        logging.getLogger(__name__).debug(
            "FTS5 查询失败,仅使用 LIKE 子串搜索", exc_info=True
        )

    # 2) Substring match on OCR / AI text, so a short query can match inside
    #    a longer word that the FTS tokenizer would not split.
    meta_ids = session.scalars(
        select(ScreenshotMeta.screenshot_id).where(
            or_(
                ScreenshotMeta.ocr_text.ilike(like, escape="\\"),
                ScreenshotMeta.ai_title.ilike(like, escape="\\"),
                ScreenshotMeta.ai_summary.ilike(like, escape="\\"),
                ScreenshotMeta.ai_suggestion.ilike(like, escape="\\"),
            )
        ).limit(limit)
    ).all()
    ids.update(int(i) for i in meta_ids)

    # 3) Substring match on tag names.
    tag_ids = session.scalars(
        select(Screenshot.id)
        .join(Screenshot.tags)
        .where(Tag.name.ilike(like, escape="\\"))
        .limit(limit)
    ).all()
    ids.update(int(i) for i in tag_ids)

    return ids
|
||||
@@ -0,0 +1,50 @@
|
||||
"""读取/写入键值设置。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.setting import Setting
|
||||
|
||||
|
||||
def get_setting(session: Session, key: str, default: Any = None) -> Any:
    """Read one setting and return its JSON-decoded value.

    Args:
        session: Open SQLAlchemy session.
        key: Primary key of the ``Setting`` row.
        default: Value returned when the row is missing or cannot be decoded.

    Returns:
        The decoded value, or ``default``.
    """
    row = session.get(Setting, key)
    if row is None:
        return default
    try:
        return json.loads(row.value_json)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (malformed text); TypeError
        # is what json.loads raises for a NULL / non-string stored value.
        return default
|
||||
|
||||
|
||||
def set_setting(session: Session, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` as JSON, inserting or updating the row.

    The value is serialized with ``ensure_ascii=False``. The session is
    flushed before returning.

    Args:
        session: Open SQLAlchemy session.
        key: Primary key of the ``Setting`` row.
        value: Any JSON-serializable value.
    """
    existing = session.get(Setting, key)
    encoded = json.dumps(value, ensure_ascii=False)
    if existing is not None:
        existing.value_json = encoded
    else:
        session.add(Setting(key=key, value_json=encoded))
    session.flush()
|
||||
|
||||
|
||||
def all_settings(session: Session) -> dict[str, Any]:
    """Return every stored setting, keyed by setting key.

    Values are JSON-decoded. A value that cannot be decoded is returned raw
    (exactly as stored) rather than being dropped.

    Args:
        session: Open SQLAlchemy session.

    Returns:
        Mapping of setting key to decoded (or raw) value.
    """
    items: dict[str, Any] = {}
    for row in session.query(Setting).all():
        try:
            items[row.key] = json.loads(row.value_json)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError is raised for
            # a NULL / non-string stored value. Either way expose the raw
            # value instead of failing the whole listing.
            items[row.key] = row.value_json
    return items
|
||||
|
||||
|
||||
def get_provider_config(session: Session, key: str) -> Optional[dict[str, Any]]:
    """Read a provider configuration stored under ``key``.

    Args:
        session: Open SQLAlchemy session.
        key: Setting key holding the provider configuration.

    Returns:
        The stored value when it decodes to a dict, otherwise ``None``.
    """
    stored = get_setting(session, key, None)
    return stored if isinstance(stored, dict) else None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""生成并缓存缩略图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.path_utils import path_to_storage
|
||||
|
||||
|
||||
# Lower-case file suffixes (including the leading dot) accepted by is_supported().
SUPPORTED_EXTS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".gif",
    ".bmp",
    ".tif",
    ".tiff",
}


def is_supported(path: Path) -> bool:
    """Return whether ``path`` has one of the supported image suffixes.

    The comparison is case-insensitive and looks only at the suffix; the file
    itself is not opened.
    """
    suffix = path.suffix.lower()
    return suffix in SUPPORTED_EXTS
|
||||
|
||||
|
||||
def generate_thumbnail(image_path: Path, max_side: int | None = None) -> Path:
    """Generate a WebP thumbnail in the cache directory and return its path.

    The cache key is an MD5 of the stored path, the source file's mtime (ns)
    and ``max_side``, so a modified source file yields a new thumbnail. An
    existing cache file is returned without re-encoding.

    Args:
        image_path: Source image file.
        max_side: Maximum width/height in pixels; falls back to
            ``settings.thumb_size`` when ``None`` or ``0``.

    Returns:
        Path of the cached thumbnail.

    Raises:
        OSError: If the source file cannot be stat'ed/opened or the thumbnail
            cannot be written.
    """
    import uuid

    max_side = max_side or settings.thumb_size
    stat = image_path.stat()
    key = hashlib.md5(
        f"{path_to_storage(image_path)}|{stat.st_mtime_ns}|{max_side}".encode("utf-8")
    ).hexdigest()
    out = settings.thumb_dir / f"{key}.webp"
    if out.exists():
        return out

    # Make sure the cache directory is there even if it was removed at runtime.
    out.parent.mkdir(parents=True, exist_ok=True)

    # Encode into a uniquely named temp file and rename it into place, so an
    # interrupted or failed save can never leave a truncated file under the
    # final cache name (which later calls would return as a valid hit).
    tmp = out.with_name(f"{key}.{uuid.uuid4().hex}.tmp")
    try:
        with Image.open(image_path) as img:
            thumb = img.convert("RGB")
            thumb.thumbnail((max_side, max_side), Image.LANCZOS)
            thumb.save(tmp, format="WEBP", quality=80)
        try:
            tmp.replace(out)
        except OSError:
            # A concurrent call may already have produced the same thumbnail
            # (and the file may be held open by a reader); that result is
            # equally valid, so only fail when no thumbnail exists at all.
            if not out.exists():
                raise
    finally:
        tmp.unlink(missing_ok=True)
    return out
|
||||
|
||||
|
||||
def file_hash(image_path: Path, chunk: int = 1024 * 1024) -> str:
    """Return the SHA-256 hex digest of a file's contents.

    The file is read in pieces of ``chunk`` bytes so large files are never
    loaded into memory at once.

    Args:
        image_path: File to hash.
        chunk: Read size in bytes.

    Returns:
        Lower-case hexadecimal SHA-256 digest.
    """
    digest = hashlib.sha256()
    with open(image_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for piece in iter(lambda: stream.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()
|
||||
@@ -0,0 +1,122 @@
|
||||
"""watchdog 监听被关注的目录。
|
||||
|
||||
中文路径与 OneDrive 同步盘下 NTFS 事件偶发不稳,因此默认使用 PollingObserver。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||
from watchdog.observers.polling import PollingObserver
|
||||
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.core.path_utils import is_accessible_dir, path_from_storage
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.services.ingest import ingest_path
|
||||
from app.services.thumbnail import is_supported
|
||||
|
||||
|
||||
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _ScreenshotEventHandler(FileSystemEventHandler):
    """Watchdog handler: new image file -> ingest -> wake the analysis worker.

    Handler methods are invoked on the observer's thread, so the worker is
    woken by scheduling the ``notify`` coroutine onto the asyncio loop.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """Store the target event loop and the coroutine function to schedule.

        Args:
            loop: Event loop on which ``notify()`` is scheduled.
            notify: Zero-argument callable returning a coroutine.
        """
        self._loop = loop
        self._notify = notify

    def on_created(self, event: FileSystemEvent) -> None:
        """Handle a newly created file; directories are ignored."""
        if event.is_directory:
            return
        self._handle(Path(event.src_path))

    def on_moved(self, event: FileSystemEvent) -> None:
        """Handle a file moved/renamed into place, using its destination path."""
        if event.is_directory:
            return
        self._handle(Path(getattr(event, "dest_path", event.src_path)))

    def _handle(self, path: Path) -> None:
        """Ingest ``path`` if it is a supported image and wake the worker.

        Any failure is logged and swallowed: letting an exception escape would
        propagate into the observer thread that dispatches events.
        NOTE(review): that presumably stops further event delivery - confirm
        against the watchdog version in use.
        """
        if not is_supported(path):
            return
        try:
            # Wait for the writer to finish (screenshot tools often create an
            # empty file first and fill it afterwards).
            try:
                self._wait_file_ready(path)
            except FileNotFoundError:
                return
            with session_scope() as session:
                shot = ingest_path(session, path)
                ingested = shot is not None
            # Notify only after the session scope has exited, so the worker
            # is not woken before the ingest transaction has been completed
            # and then finds nothing to claim.
            if ingested:
                asyncio.run_coroutine_threadsafe(self._notify(), self._loop)
        except Exception:  # noqa: BLE001
            logger.exception("处理新截图失败: %s", path)

    @staticmethod
    def _wait_file_ready(path: Path, retries: int = 10, interval: float = 0.3) -> None:
        """Poll until the file size is non-zero and unchanged between two checks.

        Returns silently once the size is stable, and also when ``retries``
        checks have elapsed without the size stabilizing.

        Args:
            path: File being written.
            retries: Maximum number of size checks.
            interval: Seconds to sleep between checks.

        Raises:
            FileNotFoundError: If the file does not exist at a check.
        """
        import time

        last = -1
        for _ in range(retries):
            if not path.exists():
                raise FileNotFoundError(path)
            size = path.stat().st_size
            if size > 0 and size == last:
                return
            last = size
            time.sleep(interval)
|
||||
|
||||
|
||||
class WatcherService:
    """Owns the lifecycle of the observer that watches all enabled folders."""

    def __init__(self) -> None:
        """Create an idle service; nothing is watched until ``start``."""
        self._observer: PollingObserver | None = None
        self._lock = threading.Lock()
        self._loop: asyncio.AbstractEventLoop | None = None
        self._notify_cb = None

    def start(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """(Re)start watching the enabled folders listed in the database.

        Any running observer is stopped first. Folders that are missing or
        inaccessible are skipped with a warning.

        Args:
            loop: Event loop handed to the event handler.
            notify: Callback handed to the event handler.
        """
        with self._lock:
            self._loop, self._notify_cb = loop, notify
            self._stop_locked()
            observer = self._observer = PollingObserver(timeout=2.0)
            event_handler = _ScreenshotEventHandler(loop, notify)
            with session_scope() as session:
                enabled = session.scalars(
                    select(WatchFolder).where(WatchFolder.enabled.is_(True))
                ).all()
                # Copy plain values out while the session is still open.
                targets = [(folder.path, folder.recursive) for folder in enabled]
            for stored_path, recursive in targets:
                directory = path_from_storage(stored_path)
                if is_accessible_dir(directory):
                    logger.info("开始监听 %s (recursive=%s)", stored_path, recursive)
                    observer.schedule(event_handler, str(directory), recursive=recursive)
                else:
                    logger.warning("监听目录不存在或不可访问,跳过: %s", stored_path)
            observer.start()

    def reload(self) -> None:
        """Restart with the previous loop/callback; no-op if never started."""
        loop, notify = self._loop, self._notify_cb
        if loop is not None and notify is not None:
            self.start(loop, notify)

    def stop(self) -> None:
        """Stop the running observer, if any."""
        with self._lock:
            self._stop_locked()

    def _stop_locked(self) -> None:
        """Stop and forget the observer; the caller must hold ``self._lock``."""
        observer = self._observer
        if observer is None:
            return
        try:
            observer.stop()
            observer.join(timeout=3)
        finally:
            # Drop the reference even if stopping or joining failed.
            self._observer = None
|
||||
|
||||
|
||||
# Module-level WatcherService instance created at import time.
watcher_service = WatcherService()
|
||||
@@ -0,0 +1,237 @@
|
||||
"""异步任务调度器:从 jobs 表取任务并并发执行。
|
||||
|
||||
事务规则:
|
||||
- 调度循环只用短事务 claim 任务、汇总状态。
|
||||
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
|
||||
绝不在 worker 这一层包裹长事务。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import case, func, or_, select
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
|
||||
|
||||
|
||||
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnalyzeWorker:
    """Single-instance background worker that runs pending rows of the jobs table.

    The scheduling loop claims one PENDING job at a time in a short
    transaction, waits for a free concurrency slot and runs the job as its
    own asyncio task. All database access in this class uses short
    ``session_scope`` transactions.
    """

    # Seconds to pause after an unexpected error in the scheduling loop.
    _ERROR_BACKOFF_SECONDS = 2.0

    def __init__(self) -> None:
        """Create an idle worker; nothing runs until ``start`` is awaited."""
        self._task: Optional[asyncio.Task] = None
        self._event = asyncio.Event()
        self._stop = False
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        self._inflight: int = 0
        self._lock = asyncio.Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to running job tasks. The event loop itself only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Reset jobs left RUNNING by a previous run, then launch the main loop."""
        with session_scope() as session:
            running = session.scalars(
                select(Job).where(Job.status == JobStatus.RUNNING.value)
            ).all()
            for job in running:
                job.status = JobStatus.PENDING.value
        self._stop = False
        self._loop = asyncio.get_running_loop()
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Ask the main loop to stop, cancel it and wait for it to finish."""
        self._stop = True
        self._event.set()
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def notify(self) -> None:
        """Wake the worker; must be called on the event-loop thread."""
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread.

        ``asyncio.Event.set()`` is not thread-safe, so the call is scheduled
        onto the event-loop thread with ``loop.call_soon_threadsafe``. Does
        nothing when the worker has no loop yet or the loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        loop.call_soon_threadsafe(self._event.set)

    async def status(self) -> dict[str, int]:
        """Return queue statistics for the API.

        Returns:
            A count per ``JobStatus`` value (0 when absent), plus
            ``"inflight"`` (jobs currently being processed by this worker)
            and ``"has_more"`` (1 if at least one PENDING job exists, else 0).
        """
        with session_scope() as session:
            pending = session.scalar(
                select(Job.id)
                .where(Job.status == JobStatus.PENDING.value)
                .limit(1)
            )
            rows = session.execute(
                select(Job.status, func.count())
                .group_by(Job.status)
            ).all()
        counts = {st.value: 0 for st in JobStatus}
        for status, cnt in rows:
            counts[status] = int(cnt)
        counts["inflight"] = self._inflight
        counts["has_more"] = 1 if pending is not None else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Put RUNNING jobs back to PENDING.

        Args:
            minutes: Only jobs started more than this many minutes ago are
                reset (values below 1 are treated as 1).
            reset_all: Reset every RUNNING job regardless of its start time.

        Returns:
            Number of jobs that were reset.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(q).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
        # NOTE(review): notify() is only safe on the event-loop thread; if this
        # method is reachable from another thread, notify_threadsafe() should
        # be used instead - confirm against callers.
        self.notify()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs with a fresh retry counter.

        Args:
            job_ids: Restrict the retry to these job ids; ``None`` or an empty
                list retries every FAILED job.

        Returns:
            Number of jobs that were re-queued.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                q = q.where(Job.id.in_(job_ids))
            failed = session.scalars(q).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
        # NOTE(review): same thread-affinity caveat as in reset_stale_running.
        self.notify()
        return count

    async def _run(self) -> None:
        """Main scheduling loop: claim jobs and dispatch them until stopped."""
        idle_rounds = 0
        while not self._stop:
            try:
                job = self._claim_one()
                if job is None:
                    idle_rounds += 1
                    # While idle, periodically reset zombie RUNNING rows so the
                    # database does not show "running" with nothing in flight.
                    if idle_rounds >= 3 and self._inflight == 0:
                        idle_rounds = 0
                        if self.reset_stale_running(minutes=5):
                            continue
            except Exception:  # noqa: BLE001
                # A transient database error must not end the loop for good:
                # nothing would restart it and the queue would stall silently.
                logger.exception(
                    "调度循环出错,%s 秒后重试", self._ERROR_BACKOFF_SECONDS
                )
                await asyncio.sleep(self._ERROR_BACKOFF_SECONDS)
                continue

            if job is None:
                # Sleep until notified; the timeout doubles as a periodic poll.
                self._event.clear()
                try:
                    await asyncio.wait_for(self._event.wait(), timeout=10)
                except asyncio.TimeoutError:
                    pass
                continue

            idle_rounds = 0
            await self._semaphore.acquire()
            async with self._lock:
                self._inflight += 1
            task = asyncio.create_task(
                self._process(job["id"], job["screenshot_id"], job["kind"])
            )
            self._tasks.add(task)
            task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Drop a finished job task and log an exception that escaped it."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("任务收尾失败: %r", exc, exc_info=exc)

    def _claim_one(self) -> Optional[dict]:
        """Claim one PENDING job in a short transaction.

        Only jobs whose retry count is below ``settings.max_retries`` (or
        unset) are eligible. FULL jobs are taken first, then VLM, then all
        other kinds; ties are broken by ascending job id. The claimed job is
        marked RUNNING with ``started_at`` set to the current UTC time.

        Returns:
            ``{"id", "screenshot_id", "kind"}`` of the claimed job, or
            ``None`` when no job is eligible.
        """
        with session_scope() as session:
            job = session.scalar(
                select(Job)
                .where(
                    Job.status == JobStatus.PENDING.value,
                    or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
                )
                .order_by(
                    case(
                        (Job.kind == JobKind.FULL.value, 0),
                        (Job.kind == JobKind.VLM.value, 1),
                        else_=2,
                    ),
                    Job.id.asc(),
                )
                .limit(1)
            )
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            # Return plain values so nothing depends on the closed session.
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one claimed job and record its outcome.

        OCR jobs call ``analyze_ocr_only_by_id``; every other kind calls
        ``analyze_screenshot_by_id``. The concurrency slot and the in-flight
        counter are always released, and the scheduling loop is woken
        afterwards.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
                self._finish(job_id, success=True, kind=kind)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                self._finish(job_id, success=False, error=str(exc), kind=kind)
        finally:
            self._semaphore.release()
            async with self._lock:
                self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Record a job's outcome in a short transaction.

        On success the job becomes DONE. On failure the retry counter is
        incremented; the job goes back to PENDING until ``settings.max_retries``
        is reached, at which point it becomes FAILED and, for non-OCR jobs,
        the screenshot's ``ai_status`` is set to FAILED as well.

        Args:
            job_id: Job to update; a missing job is ignored.
            success: Whether the job completed without error.
            error: Error text stored in ``last_error`` (truncated to 1000
                characters) on failure.
            kind: Job kind, used to decide whether ``ai_status`` is touched.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR back-fill must not affect ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()
|
||||
|
||||
|
||||
# Module-level AnalyzeWorker instance created at import time.
worker = AnalyzeWorker()
|
||||
Reference in New Issue
Block a user