Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,256 @@
|
||||
"""监听目录的增删改、手动导入、分析队列。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.core.db import session_scope
|
||||
from app.core.path_utils import (
|
||||
count_files_sample,
|
||||
is_accessible_dir,
|
||||
normalize_user_path,
|
||||
)
|
||||
from app.models.job import Job, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.schemas.common import WatchFolderIn, WatchFolderOut
|
||||
from app.schemas.job import JobListResp, JobOut, JobRetryIn
|
||||
from app.services.analyze import enqueue_ocr_jobs
|
||||
from app.services.ingest import ingest_directory
|
||||
from app.services.watcher import watcher_service
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
# Router for this module: every route below is registered under the
# "/api/watch" prefix and tagged "watch".
router = APIRouter(prefix="/api/watch", tags=["watch"])
|
||||
|
||||
|
||||
def _validate_folder_path(raw: str) -> str:
    """Normalize a user-supplied watch-folder path and check it is usable.

    Args:
        raw: Path text as submitted by the client (local or UNC network path).

    Returns:
        The value produced by ``normalize_user_path`` for ``raw``.

    Raises:
        HTTPException: 400 when the normalized path is empty, or when
            ``is_accessible_dir`` rejects it.
    """
    cleaned = normalize_user_path(raw)
    if not cleaned:
        raise HTTPException(400, "路径不能为空")
    if is_accessible_dir(cleaned):
        return cleaned
    # The message is only assembled on the failure path.
    detail = (
        f"目录不存在或无法访问: {cleaned}。"
        "请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹"
    )
    raise HTTPException(400, detail)
|
||||
|
||||
|
||||
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
    """Return all registered watch folders, ordered by ascending id."""
    query = select(WatchFolder).order_by(WatchFolder.id)
    folders = session.scalars(query).all()
    return list(map(WatchFolderOut.model_validate, folders))
|
||||
|
||||
|
||||
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
    payload: WatchFolderIn,
    background: BackgroundTasks,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Register a new watch folder and schedule one initial scan of it.

    Args:
        payload: Folder path plus its enabled / recursive / is_sensitive flags.
        background: FastAPI background-task queue used to schedule the scan.
        session: Request-scoped database session.

    Returns:
        The stored folder, serialized as ``WatchFolderOut``.

    Raises:
        HTTPException: 400 when the path fails ``_validate_folder_path`` or a
            folder with the same normalized path is already registered.
    """
    folder_path = _validate_folder_path(payload.path)

    duplicate_query = select(WatchFolder).where(WatchFolder.path == folder_path)
    if session.scalar(duplicate_query) is not None:
        raise HTTPException(400, "目录已存在")

    new_folder = WatchFolder(
        path=folder_path,
        enabled=payload.enabled,
        recursive=payload.recursive,
        is_sensitive=payload.is_sensitive,
    )
    session.add(new_folder)
    session.commit()
    session.refresh(new_folder)

    # Let the watcher pick up the new row, then ingest existing files
    # in the background.
    watcher_service.reload()
    background.add_task(_scan_folder, folder_path, payload.recursive)
    return WatchFolderOut.model_validate(new_folder)
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
    folder_id: int,
    payload: WatchFolderIn,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Overwrite the path and flags of an existing watch folder.

    Args:
        folder_id: Primary key of the folder to update.
        payload: New path plus enabled / recursive / is_sensitive flags.
        session: Request-scoped database session.

    Returns:
        The updated folder, serialized as ``WatchFolderOut``.

    Raises:
        HTTPException: 404 when no folder has ``folder_id``; 400 when the new
            path fails ``_validate_folder_path`` or is already used by a
            different folder.
    """
    folder = session.get(WatchFolder, folder_id)
    if folder is None:
        raise HTTPException(404, "folder not found")

    normalized = _validate_folder_path(payload.path)

    # Mirror the duplicate check done by add_folder, ignoring the row being
    # edited so that re-saving a folder with its own path still succeeds.
    clash = session.scalar(
        select(WatchFolder).where(
            WatchFolder.path == normalized,
            WatchFolder.id != folder_id,
        )
    )
    if clash is not None:
        raise HTTPException(400, "目录已存在")

    folder.path = normalized
    folder.enabled = payload.enabled
    folder.recursive = payload.recursive
    folder.is_sensitive = payload.is_sensitive
    session.commit()
    session.refresh(folder)

    watcher_service.reload()
    return WatchFolderOut.model_validate(folder)
|
||||
|
||||
|
||||
@router.delete("/folders/{folder_id}")
def delete_folder(folder_id: int, session: Session = Depends(db_session)) -> dict:
    """Delete a watch folder and reload the watcher service.

    Args:
        folder_id: Primary key of the folder to remove.
        session: Request-scoped database session.

    Returns:
        ``{"ok": True}`` once the row has been deleted and committed.

    Raises:
        HTTPException: 404 when no folder has ``folder_id``.
    """
    target = session.get(WatchFolder, folder_id)
    if target is None:
        raise HTTPException(404, "folder not found")

    session.delete(target)
    session.commit()
    watcher_service.reload()

    response = {"ok": True}
    return response
|
||||
|
||||
|
||||
@router.post("/import")
def import_now(payload: WatchFolderIn, background: BackgroundTasks) -> dict:
    """Schedule a one-off background scan of a directory.

    The directory does not have to be registered as a watch folder; only
    ``payload.path`` and ``payload.recursive`` are used.

    Raises:
        HTTPException: 400 when the path fails ``_validate_folder_path``.
    """
    scan_root = _validate_folder_path(payload.path)
    background.add_task(_scan_folder, scan_root, payload.recursive)
    reply = {"ok": True, "message": "已在后台扫描"}
    return reply
|
||||
|
||||
|
||||
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
    """Check that a directory (local or UNC) is accessible and sample it.

    Returns:
        A dict with the normalized path, the count and sample entries reported
        by ``count_files_sample(..., limit=3)``, and a human-readable message.

    Raises:
        HTTPException: 400 when the path fails ``_validate_folder_path``.
    """
    checked_path = _validate_folder_path(payload.path)
    image_count, sample_files = count_files_sample(checked_path, limit=3)
    summary = f"目录可访问,抽样发现约 {image_count}+ 张图片"
    return {
        "ok": True,
        "path": checked_path,
        "sample_image_count": image_count,
        "samples": sample_files,
        "message": summary,
    }
|
||||
|
||||
|
||||
def _scan_folder(path: str, recursive: bool) -> None:
    """Background task: ingest a directory into the database, then wake the worker.

    Synchronous BackgroundTasks callables run in the thread pool, so the
    worker has to be woken through its thread-safe entry point; calling
    ``asyncio.Event.set()`` directly from this thread would be racy.

    Args:
        path: Already-normalized directory to scan.
        recursive: Forwarded to ``ingest_directory`` as ``recursive=``.
    """
    with session_scope() as db:
        ingest_directory(db, path, recursive=recursive)
    # Notify only after the session scope has exited.
    # NOTE(review): presumably session_scope commits on exit, which would make
    # the ingested rows visible to the worker by now — confirm.
    worker.notify_threadsafe()
|
||||
|
||||
|
||||
@router.get("/queue")
async def queue_status() -> dict:
    """Report worker queue status, extended with two OCR counters.

    Returns:
        The dict returned by ``worker.status()`` with two extra keys:
        ``ocr_retryable`` (screenshots whose OCR failed while AI is done) and
        ``ocr_pending`` (jobs of kind ``"ocr"`` still pending).
    """
    counts = await worker.status()

    retryable_query = select(func.count(Screenshot.id)).where(
        Screenshot.ocr_status == ProcessStatus.FAILED.value,
        Screenshot.ai_status == ProcessStatus.DONE.value,
    )
    pending_query = select(func.count(Job.id)).where(
        Job.kind == "ocr",
        Job.status == JobStatus.PENDING.value,
    )

    # NOTE(review): these are synchronous DB calls inside an async handler.
    with session_scope() as db:
        counts["ocr_retryable"] = db.scalar(retryable_query) or 0
        counts["ocr_pending"] = db.scalar(pending_query) or 0
    return counts
|
||||
|
||||
|
||||
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
    """Convert a Job ORM row into ``JobOut``, adding screenshot summary fields.

    Args:
        job: The job row to serialize.
        shot: The screenshot the job refers to, or ``None`` if unavailable;
            all screenshot-derived fields are then ``None``.

    Returns:
        A ``JobOut`` combining the job's own columns with the screenshot's
        thumbnail URL, path, AI title and AI/OCR statuses.
    """
    thumb_url = path = ai_title = ai_status = ocr_status = None
    if shot:
        # A thumbnail URL is only exposed when a thumbnail file is recorded.
        if shot.thumb_path:
            thumb_url = f"/api/screenshots/{shot.id}/thumb"
        path = shot.path
        if shot.meta:
            ai_title = shot.meta.ai_title
        ai_status = shot.ai_status
        ocr_status = shot.ocr_status

    return JobOut(
        id=job.id,
        screenshot_id=job.screenshot_id,
        kind=job.kind,
        status=job.status,
        retries=job.retries or 0,
        last_error=job.last_error,
        created_at=job.created_at,
        started_at=job.started_at,
        finished_at=job.finished_at,
        thumb_url=thumb_url,
        path=path,
        ai_title=ai_title,
        ai_status=ai_status,
        ocr_status=ocr_status,
    )
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
    status: Optional[str] = Query(None, description="pending|running|done|failed"),
    kind: Optional[str] = Query(None, description="full|ocr|vlm"),
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=200),
    session: Session = Depends(db_session),
) -> JobListResp:
    """List analysis jobs page by page, newest (highest id) first.

    Args:
        status: Optional job-status filter; must be a ``JobStatus`` value.
        kind: Optional job-kind filter, matched as-is against ``Job.kind``.
        page: 1-based page number.
        size: Page size (1-200).
        session: Request-scoped database session.

    Returns:
        ``JobListResp`` with the page items, the total number of jobs matching
        the filters, and the echoed ``page`` / ``size``.

    Raises:
        HTTPException: 400 when ``status`` is not a valid ``JobStatus`` value.
    """
    if status and status not in {member.value for member in JobStatus}:
        raise HTTPException(400, f"无效 status: {status}")

    conditions = []
    if status:
        conditions.append(Job.status == status)
    if kind:
        conditions.append(Job.kind == kind)

    job_query = select(Job)
    for condition in conditions:
        job_query = job_query.where(condition)

    # Count over the filtered query before ordering / pagination is applied.
    count_query = select(func.count()).select_from(job_query.subquery())
    total = session.scalar(count_query) or 0

    skip = (page - 1) * size
    page_query = job_query.order_by(Job.id.desc()).offset(skip).limit(size)
    jobs = session.scalars(page_query).all()

    # Fetch the screenshots for this page in one query, with their metadata
    # eagerly loaded, instead of loading them job by job.
    shots_by_id: dict[int, Screenshot] = {}
    wanted_ids = [job.screenshot_id for job in jobs]
    if wanted_ids:
        shot_query = (
            select(Screenshot)
            .where(Screenshot.id.in_(wanted_ids))
            .options(selectinload(Screenshot.meta))
        )
        for shot in session.scalars(shot_query).all():
            shots_by_id[shot.id] = shot

    return JobListResp(
        items=[_job_to_out(job, shots_by_id.get(job.screenshot_id)) for job in jobs],
        total=total,
        page=page,
        size=size,
    )
|
||||
|
||||
|
||||
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
    """Re-queue failed jobs: all of them, or only the ids given in the body.

    Args:
        payload: Optional request body; when absent, ``None`` is passed to
            ``worker.retry_failed`` instead of a job-id list.

    Returns:
        ``{"ok": True, "count": <value returned by worker.retry_failed>}``.
    """
    job_ids = None
    if payload:
        job_ids = payload.job_ids
    requeued = worker.retry_failed(job_ids)
    return {"ok": True, "count": requeued}
|
||||
|
||||
|
||||
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
    minutes: int = Query(5, ge=1, le=1440),
    reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
    """Reset zombie RUNNING jobs left behind by a crashed or unfinished worker.

    Args:
        minutes: Staleness threshold forwarded to ``worker.reset_stale_running``.
        reset_all: Forwarded as ``reset_all=``; per the query description,
            true resets every RUNNING job.

    Returns:
        ``{"ok": True, "count": <value returned by worker.reset_stale_running>}``.
    """
    reset_count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
    result = {"ok": True, "count": reset_count}
    return result
|
||||
|
||||
|
||||
@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
    """Create OCR re-run jobs for screenshots whose AI step succeeded but OCR failed.

    Args:
        limit: Maximum number of jobs to create in this call (1-5000),
            forwarded to ``enqueue_ocr_jobs``.

    Returns:
        ``{"ok": True, "count": <value returned by enqueue_ocr_jobs>}``.
    """
    count = enqueue_ocr_jobs(limit=limit)
    if count:
        # This is a synchronous route, so FastAPI runs it in the thread pool
        # rather than on the event loop. Wake the worker through its
        # thread-safe entry point, as _scan_folder does, instead of the plain
        # notify().
        worker.notify_threadsafe()
    return {"ok": True, "count": count}
|
||||
Reference in New Issue
Block a user