Initial commit: snapAna 截图智能整理工具

包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
wjl
2026-05-27 15:45:50 +08:00
commit 5c028d7952
76 changed files with 10467 additions and 0 deletions
View File
+484
View File
@@ -0,0 +1,484 @@
"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
设计要点:
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
-> 「短事务(写回结果)」。
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from sqlalchemy import select
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
from app.models.category import Category, DEFAULT_CATEGORIES
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.tag import Tag
from app.models.todo import Todo, TodoStatus
from app.models.watch_folder import WatchFolder
from app.providers import (
RECOGNITION_MODES,
build_ocr_provider,
build_vision_ocr,
build_vlm_provider,
)
from app.providers.base import VLMResult
from app.schemas.common import ProviderConfig
from app.services.exif_utils import is_exif_location_tag
from app.services.settings_store import get_provider_config, get_setting
logger = get_logger(__name__)
@dataclass
class _PreparedContext:
    """Plain data exported from the short "prepare" transaction.

    Holds no ORM objects, so it stays usable after the session is closed.
    """

    # Image location, resolved from the stored path.
    path: Path
    # Parsed OCR provider settings; None when absent or unparsable.
    ocr_cfg: Optional[ProviderConfig]
    # Parsed VLM provider settings; None when absent or unparsable.
    vlm_cfg: Optional[ProviderConfig]
    # Recognition mode; always one of RECOGNITION_MODES (falls back to the default).
    recognition_mode: str
    # Names of all categories known at preparation time.
    category_names: list[str]
    # False when the image lives under a watch folder flagged as sensitive.
    allow_upload: bool
    # Whether the image file was accessible when the context was prepared.
    exists: bool
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Public entry point: run the full analysis of one screenshot by id.

    Scheduled by the worker. The function manages several short transactions
    itself; the OCR and VLM calls run outside of any transaction.

    Raises:
        Exception: whatever the VLM provider raised, re-raised after the
            partial result has been written back, so the worker can decide
            whether to retry.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return  # the screenshot row has been deleted
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return
    # The traditional OCR provider is only built when the mode/config asks for it.
    ocr_provider = _safe_build(
        lambda c: build_ocr_provider(c, allow_upload=ctx.allow_upload),
        ctx.ocr_cfg if _use_traditional_ocr(ctx) else None,
        "OCR",
    )
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")
    # ---- Text recognition stage (runs outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)
    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    ai_status = ProcessStatus.SKIPPED.value
    vlm_error: Optional[Exception] = None
    if vlm_provider is not None:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
            ai_status = ProcessStatus.DONE.value
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            # Remember the error; it is re-raised only after the write-back below.
            vlm_error = exc
    # ---- Write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )
    if vlm_error is not None:
        raise vlm_error  # let the worker decide whether to retry
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run text recognition only, leaving the AI analysis fields untouched.

    Intended for screenshots whose ai_status is done but whose ocr_status is
    failed.

    Raises:
        RuntimeError: when the image file is missing, or when recognition
            fails again (the caller is expected to handle retries).
    """
    prepared = _prepare(screenshot_id)
    if prepared is None:
        # The screenshot row no longer exists; nothing to do.
        return
    if not prepared.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    def _build_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=prepared.allow_upload)

    traditional_cfg = prepared.ocr_cfg if _use_traditional_ocr(prepared) else None
    provider = _safe_build(_build_ocr, traditional_cfg, "OCR")

    text, status = await _extract_text(screenshot_id, prepared, provider)
    _persist_ocr_only(screenshot_id, text, status)
    if status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
# ---------------- 短事务工具 ---------------- #
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: gather everything the analysis needs as plain data.

    Reads the provider configs, the recognition mode, the category names and
    the sensitive-folder verdict for the screenshot's file.

    Returns:
        The prepared context, or None when no screenshot row has this id.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return None

        image_path = path_from_storage(shot.path)
        file_present = is_accessible_file(image_path)

        ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
        vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)

        # Unknown mode values silently fall back to the default mode.
        mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        if mode not in RECOGNITION_MODES:
            mode = DEFAULT_RECOGNITION_MODE

        names = [cat.name for cat in _ensure_default_categories(session)]
        upload_ok = not _is_sensitive(session, image_path)

        return _PreparedContext(
            path=image_path,
            ocr_cfg=ocr_cfg,
            vlm_cfg=vlm_cfg,
            recognition_mode=mode,
            category_names=names,
            allow_upload=upload_ok,
            exists=file_present,
        )
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
    """Tell whether the OCR-section config should drive a traditional OCR provider.

    True only in "ocr" or "hybrid" mode with an OCR config whose type is a real
    engine; the "vision" type is excluded because it is handled separately.
    """
    if ctx.recognition_mode not in ("ocr", "hybrid"):
        return False
    cfg = ctx.ocr_cfg
    return cfg is not None and cfg.type not in ("", "none", "disabled", "vision")
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: traditional OCR, vision AI, or both.

    Args:
        screenshot_id: Row id, used for status updates and log messages.
        ctx: Prepared, session-independent context.
        ocr_provider: Traditional OCR provider, or None to skip that step.

    Returns:
        A ``(text, status)`` pair where status is a ProcessStatus value:
        skipped when no recognizer ran, done when any recognizer succeeded,
        failed otherwise.
    """
    ocr_text = ""
    ocr_status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode
    # 1) Traditional OCR (Tesseract / Paddle / HTTP)
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            ocr_text = await ocr_provider.recognize(ctx.path)
            ocr_status = ProcessStatus.DONE.value
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            ocr_status = ProcessStatus.FAILED.value
    # 2) Vision-AI text recognition: always in "vision" mode, and in "hybrid"
    #    mode only when the traditional step produced no text.
    need_vision = mode == "vision" or (
        mode == "hybrid" and not ocr_text.strip()
    )
    if mode == "ocr" and ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
        # The user picked "vision model" as the engine in the OCR section.
        need_vision = True
    if need_vision:
        vision_cfg = _pick_vision_config(ctx)
        vision = _safe_build(
            lambda c: build_vision_ocr(c, allow_upload=ctx.allow_upload),
            vision_cfg,
            "VisionOCR",
        )
        if vision is not None:
            _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
            try:
                ocr_text = await vision.recognize(ctx.path)
                ocr_status = ProcessStatus.DONE.value
            except Exception as exc:  # noqa: BLE001
                logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
                # Keep an earlier "done" from the traditional step; only
                # downgrade to failed when nothing has succeeded so far.
                if ocr_status != ProcessStatus.DONE.value:
                    ocr_status = ProcessStatus.FAILED.value
    return ocr_text, ocr_status
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
    """Choose the config used for vision-based text recognition.

    The OCR-section config wins when its type is "vision"; in every other case
    the VLM-section config is used (which may be None).
    """
    ocr_cfg = ctx.ocr_cfg
    if ocr_cfg and ocr_cfg.type == "vision":
        return ocr_cfg
    # The former check on recognition_mode returned the same value on both
    # branches, so it has been removed.
    return ctx.vlm_cfg
def _mark_status(
    screenshot_id: int,
    ocr: Optional[str] = None,
    ai: Optional[str] = None,
) -> None:
    """Short transaction: update the OCR and/or AI status of a screenshot.

    Used to flag a stage as running so progress is visible. Does nothing when
    neither status is given or when the row no longer exists.
    """
    changes = {}
    if ocr is not None:
        changes["ocr_status"] = ocr
    if ai is not None:
        changes["ai_status"] = ai
    if not changes:
        return

    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        for attribute, value in changes.items():
            setattr(shot, attribute, value)
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot's file is gone.

    Both statuses become failed and the AI summary carries a placeholder text.
    """
    failed = ProcessStatus.FAILED.value
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = failed
        shot.ai_status = failed
        _get_or_create_meta(session, screenshot_id).ai_summary = "(文件丢失)"
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: write OCR/VLM output back to the database.

    Always stores both statuses and the OCR text (empty text is stored as
    None). When a VLM result is present it also stores the AI fields, resolves
    the category and synchronizes tags and todos.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return

        shot.ocr_status = ocr_status
        shot.ai_status = ai_status

        meta = _get_or_create_meta(session, screenshot_id)
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return

        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(session.scalars(select(Category)).all())
        matched = _resolve_category(session, vlm_result.category, known_categories)
        if matched is not None:
            shot.category_id = matched.id

        _sync_tags(session, shot, vlm_result.tags)
        _sync_todos(session, shot, vlm_result.todos)
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store only the OCR text and status.

    AI-related fields are not touched. Empty text is stored as None.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = ocr_status
        _get_or_create_meta(session, screenshot_id).ocr_text = ocr_text or None
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Create OCR re-run jobs for screenshots with AI done but OCR failed.

    Screenshots that already have a pending or running OCR job are skipped so
    that no duplicate job is queued.

    Args:
        limit: Maximum number of candidate screenshots examined (lowest ids first).

    Returns:
        The number of jobs created.
    """
    from app.models.job import Job, JobKind, JobStatus

    active = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    with session_scope() as session:
        # Screenshot ids that already have an active OCR job.
        already_queued = set(
            session.scalars(
                select(Job.screenshot_id).where(
                    Job.kind == JobKind.OCR.value,
                    Job.status.in_(active),
                )
            ).all()
        )
        candidates = session.scalars(
            select(Screenshot)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
            )
            .order_by(Screenshot.id.asc())
            .limit(limit)
        ).all()

        count = 0
        for shot in candidates:
            if shot.id in already_queued:
                continue
            new_job = Job(
                screenshot_id=shot.id,
                kind=JobKind.OCR.value,
                status=JobStatus.PENDING.value,
            )
            session.add(new_job)
            already_queued.add(shot.id)
            count += 1
    return count
# ---------------- 内部辅助 ---------------- #
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Load and parse the provider config stored under ``key``.

    Returns None when nothing is stored or when the stored value cannot be
    turned into a ProviderConfig (the failure is logged as a warning).
    """
    stored = get_provider_config(session, key)
    if not stored:
        return None
    try:
        parsed = ProviderConfig(**stored)
    except Exception as exc:  # noqa: BLE001
        logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
        return None
    return parsed
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
    """Call ``builder(cfg)`` without letting a construction error escape.

    Returns None when ``cfg`` is None or when the builder raises; in the latter
    case a warning tagged with ``label`` is logged.
    """
    if cfg is None:
        return None
    try:
        provider = builder(cfg)
    except Exception as exc:  # noqa: BLE001
        logger.warning("%s Provider 构造失败: %s", label, exc)
        return None
    return provider
def _is_sensitive(session, image_path: Path) -> bool:
    """Tell whether the file lies inside any watch folder flagged as sensitive."""
    flagged_dirs = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_dirs)
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the meta row of a screenshot, creating and flushing it if absent."""
    found = session.get(ScreenshotMeta, screenshot_id)
    if found is not None:
        return found
    created = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(created)
    session.flush()
    return created
def ensure_default_categories() -> None:
    """Public helper: seed the default categories in a transaction of its own."""
    with session_scope() as db:
        _ensure_default_categories(db)
def _ensure_default_categories(session) -> list[Category]:
    """Insert the default categories when the table is empty.

    Returns:
        The categories present after the call.
    """
    current = list(session.scalars(select(Category)).all())
    if current:
        return current
    for spec in DEFAULT_CATEGORIES:
        session.add(Category(**spec))
    session.flush()
    # Re-query so the caller sees the freshly inserted rows.
    return list(session.scalars(select(Category)).all())
def _resolve_category(
    session,
    name: str | None,
    categories: list[Category],
) -> Optional[Category]:
    """Map an AI-proposed category name onto an existing or new Category.

    Matching order:
      1. exact name match,
      2. substring match in either direction (first hit wins),
      3. otherwise a new category is created, flushed and appended to
         ``categories``.

    Args:
        session: ORM session used to persist a newly created category.
        name: Category name proposed by the model; may be None or blank.
        categories: Known categories; extended in place when one is created.

    Returns:
        The matched or created category, or None when ``name`` is None, empty
        or whitespace-only.
    """
    normalized = (name or "").strip()
    if not normalized:
        # A blank name would otherwise substring-match every category.
        return None

    # Exact matches take priority over fuzzy ones, so "A B" is not captured
    # by an earlier category that is merely called "A".
    for candidate in categories:
        if candidate.name == normalized:
            return candidate

    for candidate in categories:
        # Skip empty category names: "" is a substring of everything.
        if candidate.name and (
            candidate.name in normalized or normalized in candidate.name
        ):
            return candidate

    new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
    session.add(new_cat)
    session.flush()
    categories.append(new_cat)
    return new_cat
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided names.

    Tags recognized as EXIF location tags are kept and listed first. Names are
    stripped, cut to 64 characters and de-duplicated; missing Tag rows are
    created and flushed.
    """
    preserved = [
        tag for tag in (screenshot.tags or []) if is_exif_location_tag(tag.name)
    ]
    taken: set[str] = {tag.name for tag in preserved}
    final_tags: list[Tag] = list(preserved)

    for candidate in tag_names:
        label = (candidate or "").strip()[:64]
        if not label or label in taken:
            continue
        taken.add(label)
        tag = session.scalar(select(Tag).where(Tag.name == label))
        if tag is None:
            tag = Tag(name=label)
            session.add(tag)
            session.flush()
        final_tags.append(tag)

    screenshot.tags = final_tags
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Replace the screenshot's unfinished todos with the AI output.

    Todos whose status is done or dropped are kept; all others are deleted
    before the new items are inserted as pending. Items without a title are
    ignored.
    """
    settled = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    for old in list(screenshot.todos):
        if old.status not in settled:
            session.delete(old)
    session.flush()

    for entry in todos:
        heading = (entry.get("title") or "").strip()
        if not heading:
            continue
        fresh = Todo(
            screenshot_id=screenshot.id,
            title=heading[:512],
            note=(entry.get("note") or "")[:2000] or None,
            kind=(entry.get("kind") or "待办")[:32],
            status=TodoStatus.PENDING.value,
        )
        session.add(fresh)
    session.flush()
+107
View File
@@ -0,0 +1,107 @@
"""从图片 EXIF 提取拍摄时间与 GPS 地点标签。"""
from __future__ import annotations
from datetime import datetime
from fractions import Fraction
from pathlib import Path
from typing import Optional
from PIL import ExifTags, Image
from app.core.logger import get_logger
logger = get_logger(__name__)
# EXIF 地点类标签前缀,重分析时保留不被 AI 覆盖
EXIF_TAG_PREFIX = "地点:"
def _ratio_to_float(value) -> float:
    """Convert an EXIF rational to float.

    A 2-tuple is read as (numerator, denominator), with a zero denominator
    giving 0.0. Anything else, Fraction included, goes through ``float()``.
    """
    is_pair = isinstance(value, tuple) and len(value) == 2
    if not is_pair:
        return float(value)
    numerator, denominator = value
    if not denominator:
        return 0.0
    return float(numerator) / float(denominator)
def _dms_to_decimal(dms: tuple, ref: str) -> Optional[float]:
    """Convert degrees/minutes/seconds to decimal degrees.

    The result is negated for the "S" and "W" hemispheres and rounded to six
    decimals. Returns None when ``dms`` cannot be unpacked or converted.
    """
    try:
        degrees, minutes, seconds = dms
        magnitude = (
            _ratio_to_float(degrees)
            + _ratio_to_float(minutes) / 60
            + _ratio_to_float(seconds) / 3600
        )
    except (TypeError, ValueError, ZeroDivisionError):
        return None
    signed = -magnitude if ref in ("S", "W") else magnitude
    return round(signed, 6)
def extract_image_metadata(path: Path) -> tuple[Optional[datetime], list[str]]:
    """Read EXIF data and return ``(capture time, location tags)``.

    Best effort: any error while reading is logged at debug level and
    whatever was collected up to that point is returned.

    Returns:
        The parsed capture time (or None) and a list of tag names, each
        starting with EXIF_TAG_PREFIX.
    """
    captured: Optional[datetime] = None
    location_tags: list[str] = []
    try:
        with Image.open(path) as img:
            exif = img.getexif()
            if not exif:
                return None, []
            # Capture time: DateTimeOriginal first, then the fallbacks.
            for key in (36867, 36868, 306):  # DateTimeOriginal / DateTimeDigitized / DateTime
                raw = exif.get(key)
                if raw:
                    captured = _parse_exif_datetime(str(raw))
                    if captured:
                        break
            # GPS -> location tag. Keys 2/1 are passed as latitude value/ref
            # and keys 4/3 as longitude value/ref.
            gps_ifd = exif.get_ifd(ExifTags.IFD.GPSInfo) if hasattr(exif, "get_ifd") else None
            if gps_ifd:
                lat = _dms_to_decimal(
                    gps_ifd.get(2),
                    gps_ifd.get(1, "N"),
                )
                lon = _dms_to_decimal(
                    gps_ifd.get(4),
                    gps_ifd.get(3, "E"),
                )
                if lat is not None and lon is not None:
                    location_tags.append(f"{EXIF_TAG_PREFIX}{lat},{lon}")
            # Some devices write a readable place name into text fields;
            # only ImageDescription and XPComment are inspected here.
            for key, val in exif.items():
                tag_name = ExifTags.TAGS.get(key, "")
                if tag_name in ("ImageDescription", "XPComment") and val:
                    # NOTE(review): str() of a bytes value yields its repr - confirm
                    # how the installed Pillow returns XPComment.
                    text = str(val).strip()[:64]
                    if text and _looks_like_place(text):
                        location_tags.append(f"{EXIF_TAG_PREFIX}{text}")
    except Exception as exc:  # noqa: BLE001
        logger.debug("读取 EXIF 失败 %s: %s", path.name, exc)
    return captured, location_tags
def _parse_exif_datetime(raw: str) -> Optional[datetime]:
    """Parse an EXIF timestamp string.

    Accepts ``YYYY:MM:DD HH:MM:SS`` and ``YYYY-MM-DD HH:MM:SS`` (surrounding
    whitespace ignored); returns None for anything else.
    """
    candidate = raw.strip()
    for pattern in ("%Y:%m:%d %H:%M:%S", "%Y-%m-%d %H:%M:%S"):
        try:
            parsed = datetime.strptime(candidate, pattern)
        except ValueError:
            continue
        return parsed
    return None
def _looks_like_place(text: str) -> bool:
    """Rough check whether a string looks like a place name.

    True when the text contains a common address keyword or at least one CJK
    unified ideograph (U+4E00..U+9FFF).
    """
    # The keyword literals had degraded to empty strings, and "" is a substring
    # of every string, which made this check accept any input.
    # NOTE(review): the single-character keywords below are a reconstruction of
    # typical Chinese address suffixes - confirm against the intended list.
    keywords = ("省", "市", "区", "县", "镇", "村", "路", "街", "号", "GPS")
    # `k and ...` keeps the check safe even if an empty keyword slips in again.
    if any(k and k in text for k in keywords):
        return True
    return any("\u4e00" <= c <= "\u9fff" for c in text)
def is_exif_location_tag(name: str) -> bool:
    """Tell whether a tag name carries the EXIF location prefix."""
    prefix = EXIF_TAG_PREFIX
    return name[: len(prefix)] == prefix
+147
View File
@@ -0,0 +1,147 @@
"""将磁盘上的截图文件入库 + 排队分析。"""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Iterable, Optional
from PIL import Image
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.path_utils import (
is_accessible_dir,
is_accessible_file,
path_from_storage,
path_to_storage,
)
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.exif_utils import extract_image_metadata
from app.services.thumbnail import file_hash, generate_thumbnail, is_supported
logger = get_logger(__name__)
def ingest_path(session: Session, path: Path) -> Optional[Screenshot]:
    """Register one image file and queue it for analysis.

    Returns:
        The Screenshot row (new, or the existing row with the same content
        hash, whose path is refreshed), or None when the file is not
        accessible, not a supported image, or cannot be read for hashing.
    """
    if not (is_accessible_file(path) and path.is_file()):
        return None
    if not is_supported(path):
        return None

    stored_path = path_to_storage(path)
    try:
        digest = file_hash(path)
    except OSError as exc:
        logger.warning("无法读取文件 %s: %s", path, exc)
        return None

    duplicate = session.scalar(select(Screenshot).where(Screenshot.file_hash == digest))
    if duplicate:
        # Same content under a new name/location: just refresh the path.
        if duplicate.path != stored_path:
            duplicate.path = stored_path
            session.flush()
        return duplicate

    try:
        with Image.open(path) as img:
            width, height = img.size
    except Exception as exc:  # noqa: BLE001
        logger.warning("无法读取图片尺寸 %s: %s", path, exc)
        width, height = 0, 0

    file_stat = path.stat()
    # The file's mtime is the fallback; an EXIF capture time overrides it.
    captured_at = datetime.fromtimestamp(file_stat.st_mtime)
    exif_time, location_tags = extract_image_metadata(path)
    if exif_time is not None:
        captured_at = exif_time

    try:
        thumb_path = generate_thumbnail(path).as_posix()
    except Exception as exc:  # noqa: BLE001
        logger.warning("生成缩略图失败 %s: %s", path, exc)
        thumb_path = None

    pending = ProcessStatus.PENDING.value
    shot = Screenshot(
        path=stored_path,
        file_hash=digest,
        width=width,
        height=height,
        size=file_stat.st_size,
        captured_at=captured_at,
        thumb_path=thumb_path,
        ocr_status=pending,
        ai_status=pending,
    )
    session.add(shot)
    session.flush()

    if location_tags:
        _attach_location_tags(session, shot, location_tags)

    session.add(
        Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
    )
    logger.info("入库 #%d %s", shot.id, path.name)
    return shot
def _attach_location_tags(session: Session, shot: Screenshot, tag_names: list[str]) -> None:
    """Attach the EXIF location tags to a freshly ingested screenshot.

    Names are stripped and cut to 64 characters; blank names are skipped and
    missing Tag rows are created and flushed. The screenshot's tag list is
    replaced by the result.
    """
    collected: list[Tag] = []
    for candidate in tag_names:
        label = (candidate or "").strip()[:64]
        if not label:
            continue
        tag = session.scalar(select(Tag).where(Tag.name == label))
        if tag is None:
            tag = Tag(name=label)
            session.add(tag)
            session.flush()
        collected.append(tag)
    shot.tags = collected
def ingest_directory(
    session: Session,
    root: Path | str,
    recursive: bool = True,
) -> tuple[int, int]:
    """Walk a directory and ingest every supported image file.

    A file counts as added when no row had its stored path before the call and
    ingestion returned a row; every other supported file counts as skipped.
    The session is committed every 50 counted files and once at the end.

    Returns:
        ``(added, skipped)``; ``(0, 0)`` when the directory is not accessible.
    """
    base = path_from_storage(str(root)) if isinstance(root, str) else root
    if not is_accessible_dir(base):
        return 0, 0

    entries: Iterable[Path] = base.rglob("*") if recursive else base.iterdir()

    added = 0
    skipped = 0
    for candidate in entries:
        if not candidate.is_file():
            continue
        if not is_supported(candidate):
            continue

        stored = path_to_storage(candidate)
        known_id = session.scalar(
            select(Screenshot.id).where(Screenshot.path == stored)
        )
        if ingest_path(session, candidate) is None:
            skipped += 1
            continue

        if known_id is None:
            added += 1
        else:
            skipped += 1

        # Commit in batches to avoid one huge transaction.
        if (added + skipped) % 50 == 0:
            session.commit()

    session.commit()
    return added, skipped
+207
View File
@@ -0,0 +1,207 @@
"""Provider 连通性测试:OCR / 视觉 AI。"""
from __future__ import annotations
import asyncio
import base64
import time
from typing import Any
import httpx
from app.providers import build_ocr_provider
from app.schemas.common import ProviderConfig
class ProviderTestError(Exception):
    """A provider test failed; the exception message is meant to be shown to the user."""
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Check connectivity of an OCR or VLM provider.

    Never raises: every failure is reported in the returned dict.

    Returns:
        ``{"ok", "message", "detail", "latency_ms"}`` where latency is the
        elapsed wall time of the test in milliseconds.
    """
    t0 = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - t0) * 1000)

    try:
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")
        if key == KEY_OCR:
            message, detail = await _test_ocr(cfg)
        elif key == KEY_VLM:
            message, detail = await _test_vlm(cfg)
        else:
            raise ProviderTestError(f"未知配置键: {key}")
    except ProviderTestError as exc:
        # Expected failure carrying a user-readable message.
        latency = _elapsed_ms()
        return {"ok": False, "message": str(exc), "detail": None, "latency_ms": latency}
    except Exception as exc:  # noqa: BLE001
        latency = _elapsed_ms()
        return {
            "ok": False,
            "message": f"测试失败: {exc}",
            "detail": repr(exc),
            "latency_ms": latency,
        }

    latency = _elapsed_ms()
    return {"ok": True, "message": message, "detail": detail, "latency_ms": latency}
KEY_OCR = "ocr_provider"
KEY_VLM = "vlm_provider"
# 1x1 白图,用于 HTTP OCR / 视觉测试
_TINY_PNG_B64 = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch the OCR connectivity test by provider type.

    Raises:
        ProviderTestError: for an unsupported OCR type.
    """
    engine_tests = (
        ("tesseract", _test_tesseract),
        ("paddleocr", _test_paddle),
        ("http", _test_http_ocr),
    )
    for type_name, runner in engine_tests:
        if cfg.type == type_name:
            return await runner(cfg)
    if cfg.type == "vision":
        return await _test_openai_compat(cfg, label="视觉 OCR")
    raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Run the VLM connectivity test; all supported types share one probe.

    Raises:
        ProviderTestError: for an unsupported VLM type.
    """
    compatible_types = (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    )
    if cfg.type not in compatible_types:
        raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
    return await _test_openai_compat(cfg, label="视觉 AI")
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that Tesseract is usable and report its version.

    Note that a configured ``cmd`` is written to pytesseract's module-level
    ``tesseract_cmd`` as a side effect.

    Raises:
        ProviderTestError: when the provider cannot be constructed.
    """
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _probe_version() -> str:
        import pytesseract

        if cfg.extra.get("cmd"):
            pytesseract.pytesseract.tesseract_cmd = cfg.extra["cmd"]
        return str(pytesseract.get_tesseract_version())

    # The version lookup is delegated to a worker thread.
    version = await asyncio.to_thread(_probe_version)
    summary = f"Tesseract 可用,版本 {version}"
    detail = f"lang={cfg.extra.get('lang', 'chi_sim+eng')}"
    return summary, detail
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Verify that PaddleOCR is importable and a provider can be built.

    Raises:
        ProviderTestError: when the module is not installed or the provider
            cannot be constructed.
    """

    def _ensure_installed() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc
        return "PaddleOCR 模块已安装"

    install_note = await asyncio.to_thread(_ensure_installed)
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", install_note
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Send a tiny PNG through the HTTP OCR provider to prove it is reachable.

    Raises:
        ProviderTestError: when no URL is configured or the provider cannot
            be constructed.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    # The provider works on file paths, so the sample goes to a temp file first.
    from pathlib import Path
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(base64.b64decode(_TINY_PNG_B64))
        sample = Path(handle.name)

    try:
        recognized = await provider.recognize(sample)
    finally:
        # Best-effort cleanup; a leftover temp file is not worth failing for.
        try:
            sample.unlink(missing_ok=True)
        except OSError:
            pass

    preview = (recognized or "").strip()[:80] or "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {preview}"
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
    """Probe an OpenAI-compatible endpoint.

    First tries ``GET {base_url}/models``; a 200 answer with a JSON body is
    enough to report success. Otherwise falls back to a minimal
    ``POST {base_url}/chat/completions`` call.

    Args:
        cfg: Provider settings (base_url, api_key, model, extra["timeout"]).
        label: Human-readable provider name used in the returned messages.

    Returns:
        ``(message, detail)`` describing the successful probe.

    Raises:
        ProviderTestError: when the chat probe gets an HTTP error status, the
            server cannot be reached, or the chat response has an unexpected
            shape.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    models_url = f"{base_url}/models"
    chat_url = f"{base_url}/chat/completions"

    async with httpx.AsyncClient(timeout=timeout) as client:
        # 1) Try /models (Ollama and other OpenAI-compatible servers).
        try:
            resp = await client.get(models_url, headers=headers)
            if resp.status_code == 200:
                ids = _extract_model_ids(resp.json())
                if model and ids and model not in ids:
                    return (
                        f"{label} 服务可达",
                        f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
                    )
                return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"
        except (httpx.HTTPError, ValueError):
            # Transport error, or a 200 whose body is not JSON (e.g. an HTML
            # landing page): not conclusive, so fall through to the chat probe.
            pass

        # 2) Minimal chat liveness probe.
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "请只回复 OK"}],
            "max_tokens": 16,
            "temperature": 0,
        }
        try:
            resp = await client.post(chat_url, json=payload, headers=headers)
            resp.raise_for_status()
            content = resp.json()["choices"][0]["message"]["content"]
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:200]
            raise ProviderTestError(
                f"API 返回 {exc.response.status_code}: {body}"
            ) from exc
        except httpx.HTTPError as exc:
            raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # 2xx answer that is not a chat-completions document.
            raise ProviderTestError(f"响应格式异常: {exc!r}") from exc
        return f"{label} 对话成功", f"模型 {model} 回复: {str(content).strip()[:60]}"
def _extract_model_ids(data: Any) -> list[str]:
    """Pull the model ids out of a ``/models`` response.

    Looks at the ``data`` list, or ``models`` when ``data`` is missing or
    empty. Dict entries contribute their ``id``, ``name`` or ``model`` value
    (first non-empty one, stringified); string entries are taken as they are.
    Any other shape yields an empty list.
    """
    if not isinstance(data, dict):
        return []
    entries = data.get("data") or data.get("models") or []
    if not isinstance(entries, list):
        return []

    found: list[str] = []
    for entry in entries:
        if isinstance(entry, str):
            found.append(entry)
            continue
        if not isinstance(entry, dict):
            continue
        model_id = entry.get("id") or entry.get("name") or entry.get("model")
        if model_id:
            found.append(str(model_id))
    return found
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
    """Fill a blank api_key from the stored config before running a test.

    Returns a new ProviderConfig; ``cfg`` itself is not modified.
    """
    data = cfg.model_dump()
    key_missing = not data.get("api_key")
    if key_missing and isinstance(existing, dict):
        data["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**data)
+67
View File
@@ -0,0 +1,67 @@
"""截图列表搜索:FTS + 子串模糊(兼容中文标签/标题)。"""
from __future__ import annotations
from sqlalchemy import or_, select, text
from sqlalchemy.orm import Session
from app.models.meta import ScreenshotMeta
from app.models.screenshot import Screenshot
from app.models.tag import Tag
def fts_query_string(raw: str) -> str:
    """Turn user input into an FTS5 query of quoted prefix terms.

    The input is split on whitespace; double quotes are removed from each word
    and every remaining word becomes ``"word"*``. Input without any word is
    returned unchanged.
    """
    words = raw.replace("\n", " ").split()
    if not words:
        return raw
    unquoted = (word.replace('"', "").strip() for word in words)
    return " ".join(f'"{word}"*' for word in unquoted if word)
def collect_search_ids(session: Session, q: str, *, limit: int = 5000) -> set[int]:
    """Combine FTS5 and substring (LIKE) search into one set of screenshot ids.

    Args:
        session: Open ORM session.
        q: Raw user query; surrounding whitespace is ignored.
        limit: Maximum number of rows fetched per sub-query.

    Returns:
        Ids matched by the full-text index, by a substring of the OCR/AI text
        fields, or by a substring of a tag name. Empty for a blank query.
    """
    import logging

    q = q.strip()
    if not q:
        return set()

    ids: set[int] = set()
    # Treat the user's text literally: without escaping, "%" and "_" typed by
    # the user would act as LIKE wildcards.
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    like = f"%{escaped}%"

    # 1) FTS5 full-text index. Best effort: the substring searches below still
    #    run when the FTS query cannot be executed.
    try:
        fts_sql = text(
            "SELECT rowid FROM screenshots_fts WHERE screenshots_fts MATCH :q LIMIT :lim"
        )
        rows = session.execute(fts_sql, {"q": fts_query_string(q), "lim": limit}).fetchall()
        ids.update(int(row[0]) for row in rows)
    except Exception as exc:  # noqa: BLE001
        logging.getLogger(__name__).debug("FTS search failed, using LIKE only: %s", exc)

    # 2) Substring match on OCR/AI text, so a short query also hits longer words.
    meta_ids = session.scalars(
        select(ScreenshotMeta.screenshot_id).where(
            or_(
                ScreenshotMeta.ocr_text.ilike(like, escape="\\"),
                ScreenshotMeta.ai_title.ilike(like, escape="\\"),
                ScreenshotMeta.ai_summary.ilike(like, escape="\\"),
                ScreenshotMeta.ai_suggestion.ilike(like, escape="\\"),
            )
        ).limit(limit)
    ).all()
    ids.update(int(i) for i in meta_ids)

    # 3) Substring match on tag names.
    tag_ids = session.scalars(
        select(Screenshot.id)
        .join(Screenshot.tags)
        .where(Tag.name.ilike(like, escape="\\"))
        .limit(limit)
    ).all()
    ids.update(int(i) for i in tag_ids)
    return ids
+50
View File
@@ -0,0 +1,50 @@
"""读取/写入键值设置。"""
from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy.orm import Session
from app.models.setting import Setting
def get_setting(session: Session, key: str, default: Any = None) -> Any:
    """Read a setting and JSON-decode its value.

    Returns ``default`` when the key does not exist or the stored text is not
    valid JSON.
    """
    record = session.get(Setting, key)
    if record is not None:
        try:
            return json.loads(record.value_json)
        except json.JSONDecodeError:
            pass
    return default
def set_setting(session: Session, key: str, value: Any) -> None:
    """Store a setting as JSON text (insert or update), then flush the session."""
    existing = session.get(Setting, key)
    encoded = json.dumps(value, ensure_ascii=False)
    if existing is not None:
        existing.value_json = encoded
    else:
        session.add(Setting(key=key, value_json=encoded))
    session.flush()
def all_settings(session: Session) -> dict[str, Any]:
    """Return every setting as ``{key: value}``, for debugging and export.

    Values are JSON-decoded; text that is not valid JSON is returned as is.
    """

    def _decode(raw_text):
        try:
            return json.loads(raw_text)
        except json.JSONDecodeError:
            return raw_text

    return {row.key: _decode(row.value_json) for row in session.query(Setting).all()}
def get_provider_config(session: Session, key: str) -> Optional[dict[str, Any]]:
    """Read an OCR/VLM provider config; anything that is not a dict yields None."""
    stored = get_setting(session, key, None)
    return stored if isinstance(stored, dict) else None
+48
View File
@@ -0,0 +1,48 @@
"""生成并缓存缩略图。"""
from __future__ import annotations
import hashlib
from pathlib import Path
from PIL import Image
from app.core.config import settings
from app.core.path_utils import path_to_storage
SUPPORTED_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tif", ".tiff"}
def is_supported(path: Path) -> bool:
    """Tell whether the file extension (case-insensitive) is a supported image format."""
    extension = path.suffix.lower()
    return extension in SUPPORTED_EXTS
def generate_thumbnail(image_path: Path, max_side: int | None = None) -> Path:
    """Create a WebP thumbnail in the cache directory and return its path.

    The cache key hashes the stored path, the source mtime and the target size,
    so a changed source file gets a new thumbnail. An existing cache file is
    returned without regenerating it.

    Args:
        image_path: Source image.
        max_side: Bounding-box side in pixels; defaults to ``settings.thumb_size``.
    """
    side = max_side or settings.thumb_size
    mtime_ns = image_path.stat().st_mtime_ns
    cache_key = f"{path_to_storage(image_path)}|{mtime_ns}|{side}"
    digest = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    target = settings.thumb_dir / f"{digest}.webp"

    if not target.exists():
        with Image.open(image_path) as source:
            rgb = source.convert("RGB")
            rgb.thumbnail((side, side), Image.LANCZOS)
            rgb.save(target, format="WEBP", quality=80)
    return target
def file_hash(image_path: Path, chunk: int = 1024 * 1024) -> str:
    """Return the SHA-256 hex digest of a file, used as the de-duplication key.

    The file is read in pieces of ``chunk`` bytes so large files are never
    loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(image_path, "rb") as stream:
        for piece in iter(lambda: stream.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()
+122
View File
@@ -0,0 +1,122 @@
"""watchdog 监听被关注的目录。
中文路径与 OneDrive 同步盘下 NTFS 事件偶发不稳,因此默认使用 PollingObserver。
"""
from __future__ import annotations
import asyncio
import threading
from pathlib import Path
from sqlalchemy import select
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers.polling import PollingObserver
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_dir, path_from_storage
from app.models.watch_folder import WatchFolder
from app.services.ingest import ingest_path
from app.services.thumbnail import is_supported
logger = get_logger(__name__)
class _ScreenshotEventHandler(FileSystemEventHandler):
    """Watchdog handler: new or moved-in image file -> ingest -> wake the analyzer.

    The callbacks run on the observer's thread, so the asyncio side is reached
    through ``asyncio.run_coroutine_threadsafe``.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """Remember the target event loop and the coroutine function to schedule on it."""
        self._loop = loop
        self._notify = notify

    def on_created(self, event: FileSystemEvent) -> None:
        """Handle a newly created file; directory events are ignored."""
        if event.is_directory:
            return
        self._handle(Path(event.src_path))

    def on_moved(self, event: FileSystemEvent) -> None:
        """Handle a moved/renamed file using its destination path."""
        if event.is_directory:
            return
        self._handle(Path(getattr(event, "dest_path", event.src_path)))

    def _handle(self, path: Path) -> None:
        """Ingest one supported image file and notify the worker on success.

        Errors are logged instead of raised: an exception escaping a handler
        would propagate into the observer's dispatch thread.
        """
        if not is_supported(path):
            return
        # Wait until writing has finished (capture tools often create an
        # empty file first and fill it afterwards).
        try:
            self._wait_file_ready(path)
        except FileNotFoundError:
            # The file vanished again; nothing to ingest.
            return
        except OSError as exc:
            logger.warning("等待文件就绪失败 %s: %s", path, exc)
            return

        try:
            with session_scope() as session:
                ingested = ingest_path(session, path) is not None
        except Exception:  # noqa: BLE001
            logger.exception("入库失败 %s", path)
            return

        # Notify only after the ingest transaction has been closed, so the
        # worker can already see the new job row.
        if ingested:
            asyncio.run_coroutine_threadsafe(self._notify(), self._loop)

    @staticmethod
    def _wait_file_ready(path: Path, retries: int = 10, interval: float = 0.3) -> None:
        """Poll until the file size is non-zero and unchanged between two checks.

        Returns silently when the size has not settled after ``retries`` polls.

        Raises:
            FileNotFoundError: when the file does not exist at poll time.
        """
        import time

        last = -1
        for _ in range(retries):
            if not path.exists():
                raise FileNotFoundError(path)
            size = path.stat().st_size
            if size > 0 and size == last:
                return
            last = size
            time.sleep(interval)
class WatcherService:
    """Manage the lifecycle of the observer that watches all enabled folders."""

    def __init__(self) -> None:
        """Create an idle service; nothing is watched until start() is called."""
        self._observer: PollingObserver | None = None
        # Guards start/stop so the observer is never swapped concurrently.
        self._lock = threading.Lock()
        # Remembered from start() so reload() can restart with the same arguments.
        self._loop: asyncio.AbstractEventLoop | None = None
        self._notify_cb = None

    def start(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """(Re)start watching the enabled folders stored in the database.

        Any running observer is stopped first. Folders that are missing or not
        accessible are skipped with a warning.

        Args:
            loop: Event loop on which ``notify`` is scheduled.
            notify: Callable handed to the event handler; it is called without
                arguments after a successful ingest.
        """
        with self._lock:
            self._loop = loop
            self._notify_cb = notify
            self._stop_locked()
            self._observer = PollingObserver(timeout=2.0)
            handler = _ScreenshotEventHandler(loop, notify)
            with session_scope() as session:
                folders = session.scalars(
                    select(WatchFolder).where(WatchFolder.enabled.is_(True))
                ).all()
                # Copy plain values out so they stay usable after the session closes.
                paths = [(f.path, f.recursive) for f in folders]
            for path, recursive in paths:
                p = path_from_storage(path)
                if not is_accessible_dir(p):
                    logger.warning("监听目录不存在或不可访问,跳过: %s", path)
                    continue
                logger.info("开始监听 %s (recursive=%s)", path, recursive)
                self._observer.schedule(handler, str(p), recursive=recursive)
            self._observer.start()

    def reload(self) -> None:
        """Restart watching after the folder list changed; no-op before the first start()."""
        if self._loop is None or self._notify_cb is None:
            return
        self.start(self._loop, self._notify_cb)

    def stop(self) -> None:
        """Stop the observer, if one is running."""
        with self._lock:
            self._stop_locked()

    def _stop_locked(self) -> None:
        """Stop and discard the observer; the caller must hold ``self._lock``."""
        if self._observer is not None:
            try:
                self._observer.stop()
                # Wait at most 3 seconds for the observer thread to finish.
                self._observer.join(timeout=3)
            finally:
                self._observer = None
watcher_service = WatcherService()
+237
View File
@@ -0,0 +1,237 @@
"""异步任务调度器:从 jobs 表取任务并并发执行。
事务规则:
- 调度循环只用短事务 claim 任务、汇总状态。
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
绝不在 worker 这一层包裹长事务。
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func, or_, select
from app.core.config import settings
from app.core.db import session_scope
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
logger = get_logger(__name__)
class AnalyzeWorker:
    """Single background worker that drains pending rows of the jobs table."""

    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None
        self._event = asyncio.Event()
        self._stop = False
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        self._inflight: int = 0
        self._lock = asyncio.Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to the per-job tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._job_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Start the main loop."""
        # Jobs still marked RUNNING were interrupted by a previous run;
        # put them back in the queue.
        with session_scope() as session:
            running = session.scalars(
                select(Job).where(Job.status == JobStatus.RUNNING.value)
            ).all()
            for job in running:
                job.status = JobStatus.PENDING.value
        self._stop = False
        self._loop = asyncio.get_running_loop()
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Ask the main loop to exit, cancel it and wait for it to finish."""
        self._stop = True
        self._event.set()
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def notify(self) -> None:
        """Tell the worker that new jobs may be available.

        Must be called from the event-loop thread.
        """
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from another thread (FastAPI BackgroundTasks / watcher thread).

        ``asyncio.Event.set()`` is not thread-safe, so the call is scheduled
        onto the event-loop thread with ``loop.call_soon_threadsafe``.
        Does nothing when no loop has been recorded yet or it is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        loop.call_soon_threadsafe(self._event.set)

    async def status(self) -> dict[str, int]:
        """Return queue counters for the API.

        The result holds one count per ``JobStatus`` value, plus
        ``"inflight"`` (jobs currently being processed by this worker) and
        ``"has_more"`` (1 if at least one PENDING job exists, else 0).
        """
        with session_scope() as session:
            pending = session.scalar(
                select(Job.id)
                .where(Job.status == JobStatus.PENDING.value)
                .limit(1)
            )
            rows = session.execute(
                select(Job.status, func.count())
                .group_by(Job.status)
            ).all()
        # Start from zero for every status so statuses without rows still appear.
        counts = {st.value: 0 for st in JobStatus}
        for status, cnt in rows:
            counts[status] = int(cnt)
        counts["inflight"] = self._inflight
        counts["has_more"] = 1 if pending is not None else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Put RUNNING jobs back to PENDING.

        Args:
            minutes: Only jobs started more than this many minutes ago are
                reset (values below 1 are treated as 1). Ignored when
                ``reset_all`` is true.
            reset_all: Reset every RUNNING job regardless of its start time.

        Returns:
            The number of jobs that were reset.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(q).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            self.notify()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs and clear their retry bookkeeping.

        Args:
            job_ids: Restrict the retry to these job ids. ``None`` or an
                empty list re-queues every FAILED job.

        Returns:
            The number of jobs that were re-queued.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                q = q.where(Job.id.in_(job_ids))
            failed = session.scalars(q).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            self.notify()
        return count

    async def _run(self) -> None:
        """Main loop: claim jobs one at a time and run them concurrently."""
        idle_rounds = 0
        while not self._stop:
            job = self._claim_one()
            if job is None:
                idle_rounds += 1
                # While idle, periodically reset zombie RUNNING rows so the
                # database does not keep showing running jobs when nothing
                # is in flight.
                if idle_rounds >= 3 and self._inflight == 0:
                    idle_rounds = 0
                    if self.reset_stale_running(minutes=5):
                        continue
                self._event.clear()
                try:
                    # The timeout makes the loop poll again even if no
                    # notification arrives.
                    await asyncio.wait_for(self._event.wait(), timeout=10)
                except asyncio.TimeoutError:
                    pass
                continue
            idle_rounds = 0
            # Wait for a free concurrency slot; _process releases it.
            await self._semaphore.acquire()
            async with self._lock:
                self._inflight += 1
            task = asyncio.create_task(
                self._process(job["id"], job["screenshot_id"], job["kind"])
            )
            # Keep the task referenced until it completes.
            self._job_tasks.add(task)
            task.add_done_callback(self._job_tasks.discard)

    def _claim_one(self) -> Optional[dict]:
        """Short transaction: claim one PENDING job and mark it RUNNING.

        Jobs are picked in the order FULL, then VLM, then any other kind,
        and by ascending id within a kind. Jobs whose retry count has
        reached ``settings.max_retries`` are not picked.

        Returns:
            A dict with ``id``, ``screenshot_id`` and ``kind``, or ``None``
            when there is nothing to claim.
        """
        with session_scope() as session:
            job = session.scalar(
                select(Job)
                .where(
                    Job.status == JobStatus.PENDING.value,
                    or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
                )
                .order_by(
                    case(
                        (Job.kind == JobKind.FULL.value, 0),
                        (Job.kind == JobKind.VLM.value, 1),
                        else_=2,
                    ),
                    Job.id.asc(),
                )
                .limit(1)
            )
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            # Return plain values so the caller does not depend on the session.
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one job; every database write happens in a short transaction.

        OCR jobs run ``analyze_ocr_only_by_id``; every other kind runs
        ``analyze_screenshot_by_id``. The concurrency slot and the in-flight
        counter are always released, and the main loop is woken afterwards.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
                self._finish(job_id, success=True, kind=kind)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                self._finish(job_id, success=False, error=str(exc), kind=kind)
        finally:
            self._semaphore.release()
            async with self._lock:
                self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Short transaction: record the outcome of a job.

        On success the job becomes DONE. On failure the retry counter is
        incremented; the job becomes FAILED once it reaches
        ``settings.max_retries`` and goes back to PENDING otherwise. The
        error text is truncated to 1000 characters.

        Args:
            job_id: Id of the job to update; a missing job is ignored.
            success: Whether the job completed without raising.
            error: Error text stored on failure.
            kind: Job kind, used to decide whether the screenshot's
                ``ai_status`` is marked as failed.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR re-run does not affect ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()
# Module-level AnalyzeWorker instance, created when this module is imported.
worker = AnalyzeWorker()