5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
257 lines
8.7 KiB
Python
257 lines
8.7 KiB
Python
"""监听目录的增删改、手动导入、分析队列。"""
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
from app.api.deps import db_session
|
|
from app.core.db import session_scope
|
|
from app.core.path_utils import (
|
|
count_files_sample,
|
|
is_accessible_dir,
|
|
normalize_user_path,
|
|
)
|
|
from app.models.job import Job, JobStatus
|
|
from app.models.screenshot import ProcessStatus, Screenshot
|
|
from app.models.watch_folder import WatchFolder
|
|
from app.schemas.common import WatchFolderIn, WatchFolderOut
|
|
from app.schemas.job import JobListResp, JobOut, JobRetryIn
|
|
from app.services.analyze import enqueue_ocr_jobs
|
|
from app.services.ingest import ingest_directory
|
|
from app.services.watcher import watcher_service
|
|
from app.services.worker import worker
|
|
|
|
|
|
# Every endpoint in this module is mounted under /api/watch and tagged
# "watch" for the generated OpenAPI docs.
router = APIRouter(tags=["watch"], prefix="/api/watch")
|
|
|
|
|
|
def _validate_folder_path(raw: str) -> str:
    """Normalize a user-supplied folder path and make sure it is a usable directory.

    UNC network paths are accepted as well (normalization is delegated to
    ``normalize_user_path``, accessibility to ``is_accessible_dir``).

    Args:
        raw: The path exactly as entered by the user.

    Returns:
        The normalized path string.

    Raises:
        HTTPException(400): the normalized path is empty, or it is not an
            accessible directory.
    """
    candidate = normalize_user_path(raw)
    if not candidate:
        raise HTTPException(400, "路径不能为空")
    if is_accessible_dir(candidate):
        return candidate
    detail = (
        f"目录不存在或无法访问: {candidate}。"
        "请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹"
    )
    raise HTTPException(400, detail)
|
|
|
|
|
|
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
    """Return every watched folder, ordered by primary key."""
    stmt = select(WatchFolder).order_by(WatchFolder.id)
    folders = session.scalars(stmt).all()
    return list(map(WatchFolderOut.model_validate, folders))
|
|
|
|
|
|
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
    payload: WatchFolderIn,
    background: BackgroundTasks,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Register a new watched folder and schedule an initial scan.

    The path is validated and normalized first. A path that is already
    registered is rejected with HTTP 400. After the row is committed, the
    watcher is reloaded and a background task ingests the folder.

    Raises:
        HTTPException(400): invalid/inaccessible path, or path already watched.
    """
    path = _validate_folder_path(payload.path)

    already = session.scalar(select(WatchFolder).where(WatchFolder.path == path))
    if already is not None:
        raise HTTPException(400, "目录已存在")

    new_folder = WatchFolder(
        path=path,
        enabled=payload.enabled,
        recursive=payload.recursive,
        is_sensitive=payload.is_sensitive,
    )
    session.add(new_folder)
    session.commit()
    session.refresh(new_folder)

    watcher_service.reload()
    # FastAPI runs background tasks after the response has been sent.
    background.add_task(_scan_folder, path, payload.recursive)
    return WatchFolderOut.model_validate(new_folder)
|
|
|
|
|
|
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
    folder_id: int,
    payload: WatchFolderIn,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Update a watched folder's path and flags, then reload the watcher.

    Args:
        folder_id: Primary key of the folder to update.
        payload: New path and flags; every field is overwritten.

    Returns:
        The updated folder.

    Raises:
        HTTPException(404): no folder with ``folder_id`` exists.
        HTTPException(400): the path is invalid/inaccessible, or another
            watched folder already uses the same normalized path.
    """
    folder = session.get(WatchFolder, folder_id)
    if folder is None:
        raise HTTPException(404, "folder not found")

    normalized = _validate_folder_path(payload.path)

    # Same duplicate rule as add_folder, but excluding this row so that saving
    # an unchanged path still works. Without it, a rename could collide with
    # another watched folder.
    duplicate_id = session.scalar(
        select(WatchFolder.id)
        .where(WatchFolder.path == normalized, WatchFolder.id != folder_id)
        .limit(1)
    )
    if duplicate_id is not None:
        raise HTTPException(400, "目录已存在")

    folder.path = normalized
    folder.enabled = payload.enabled
    folder.recursive = payload.recursive
    folder.is_sensitive = payload.is_sensitive
    session.commit()
    session.refresh(folder)

    watcher_service.reload()
    return WatchFolderOut.model_validate(folder)
|
|
|
|
|
|
@router.delete("/folders/{folder_id}")
def delete_folder(
    folder_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Remove a watched folder and reload the watcher.

    Raises:
        HTTPException(404): no folder with ``folder_id`` exists.
    """
    target = session.get(WatchFolder, folder_id)
    if target is None:
        raise HTTPException(404, "folder not found")

    session.delete(target)
    session.commit()
    watcher_service.reload()
    return {"ok": True}
|
|
|
|
|
|
@router.post("/import")
def import_now(
    payload: WatchFolderIn,
    background: BackgroundTasks,
) -> dict:
    """Scan a directory once in the background without registering it as watched.

    Only ``payload.path`` and ``payload.recursive`` are used.

    Raises:
        HTTPException(400): the path is empty or not an accessible directory.
    """
    target = _validate_folder_path(payload.path)
    background.add_task(_scan_folder, target, payload.recursive)
    response = {"ok": True, "message": "已在后台扫描"}
    return response
|
|
|
|
|
|
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
    """Check that a directory (UNC paths included) is accessible, and sample it.

    The response contains:
    - the normalized path;
    - the ``(count, samples)`` pair from ``count_files_sample(..., limit=3)``;
    - a human-readable message.

    Raises:
        HTTPException(400): the path is empty or not an accessible directory.
    """
    checked = _validate_folder_path(payload.path)
    found, examples = count_files_sample(checked, limit=3)
    return dict(
        ok=True,
        path=checked,
        sample_image_count=found,
        samples=examples,
        message=f"目录可访问,抽样发现约 {found}+ 张图片",
    )
|
|
|
|
|
|
def _scan_folder(path: str, recursive: bool) -> None:
    """Background task: ingest ``path`` into the database, then wake the worker.

    Synchronous background tasks run in a thread pool. The worker must
    therefore be woken through its thread-safe entry point; calling
    ``asyncio.Event.set()`` from this thread would race with the event loop.

    Args:
        path: Already-validated, normalized directory path.
        recursive: Forwarded to ``ingest_directory``.
    """
    with session_scope() as session:
        ingest_directory(session, path, recursive=recursive)
    # Deliberately outside the ``with``: the session scope has been exited
    # (its transaction presumably committed -- see session_scope) before the
    # worker is woken. Otherwise the worker could poll before new rows are visible.
    worker.notify_threadsafe()
|
|
|
|
|
|
@router.get("/queue")
async def queue_status() -> dict:
    """Return the worker's queue status, extended with OCR backlog counters.

    The keys from ``worker.status()`` are returned together with
    ``ocr_retryable`` and ``ocr_pending`` (see ``_ocr_queue_counts``).

    The database queries are synchronous SQLAlchemy calls. They are executed
    in a worker thread so this ``async`` endpoint does not block the event loop.
    """
    import asyncio

    # Copy so that the worker's own status mapping is never mutated here
    # (in case status() hands out an internal dict).
    counts = dict(await worker.status())
    counts.update(await asyncio.to_thread(_ocr_queue_counts))
    return counts


def _ocr_queue_counts() -> dict[str, int]:
    """Count OCR work recorded in the database (blocking; run off the event loop).

    Returns:
        A dict with two keys:
        - ``ocr_retryable``: screenshots whose ``ai_status`` is DONE but whose
          ``ocr_status`` is FAILED.
        - ``ocr_pending``: jobs of kind ``"ocr"`` whose status is PENDING.
    """
    with session_scope() as session:
        retryable = (
            session.scalar(
                select(func.count(Screenshot.id)).where(
                    Screenshot.ocr_status == ProcessStatus.FAILED.value,
                    Screenshot.ai_status == ProcessStatus.DONE.value,
                )
            )
            or 0
        )
        pending = (
            session.scalar(
                select(func.count(Job.id)).where(
                    Job.kind == "ocr",
                    Job.status == JobStatus.PENDING.value,
                )
            )
            or 0
        )
    return {"ocr_retryable": retryable, "ocr_pending": pending}
|
|
|
|
|
|
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
    """Convert a ``Job`` row into ``JobOut``, enriched with screenshot summary fields.

    Screenshot-derived fields are None when ``shot`` is missing. ``thumb_url``
    is also None when the shot has no ``thumb_path``, and ``ai_title`` is None
    when the shot has no ``meta`` row.
    """
    if not shot:
        thumb_url = path = ai_title = ai_status = ocr_status = None
    else:
        thumb_url = f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None
        path = shot.path
        ai_title = shot.meta.ai_title if shot.meta else None
        ai_status = shot.ai_status
        ocr_status = shot.ocr_status

    return JobOut(
        id=job.id,
        screenshot_id=job.screenshot_id,
        kind=job.kind,
        status=job.status,
        retries=job.retries or 0,
        last_error=job.last_error,
        created_at=job.created_at,
        started_at=job.started_at,
        finished_at=job.finished_at,
        thumb_url=thumb_url,
        path=path,
        ai_title=ai_title,
        ai_status=ai_status,
        ocr_status=ocr_status,
    )
|
|
|
|
|
|
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
    status: Optional[str] = Query(None, description="pending|running|done|failed"),
    kind: Optional[str] = Query(None, description="full|ocr|vlm"),
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=200),
    session: Session = Depends(db_session),
) -> JobListResp:
    """List analysis jobs page by page, newest (highest id) first.

    Args:
        status: Optional filter; must be a ``JobStatus`` value.
        kind: Optional filter on ``Job.kind`` (not validated).
        page: 1-based page number.
        size: Page size.

    Raises:
        HTTPException(400): ``status`` is not a valid ``JobStatus`` value.
    """
    valid_statuses = {member.value for member in JobStatus}
    if status and status not in valid_statuses:
        raise HTTPException(400, f"无效 status: {status}")

    query = select(Job)
    if status:
        query = query.where(Job.status == status)
    if kind:
        query = query.where(Job.kind == kind)

    # Count over the filtered query before pagination is applied.
    total = session.scalar(select(func.count()).select_from(query.subquery())) or 0
    offset = (page - 1) * size
    page_jobs = session.scalars(
        query.order_by(Job.id.desc()).offset(offset).limit(size)
    ).all()

    # Load this page's screenshots (with meta) in one query instead of per job.
    wanted_ids = [job.screenshot_id for job in page_jobs]
    shot_by_id: dict[int, Screenshot] = {}
    if wanted_ids:
        loaded = session.scalars(
            select(Screenshot)
            .where(Screenshot.id.in_(wanted_ids))
            .options(selectinload(Screenshot.meta))
        ).all()
        shot_by_id = {shot.id: shot for shot in loaded}

    return JobListResp(
        items=[_job_to_out(job, shot_by_id.get(job.screenshot_id)) for job in page_jobs],
        total=total,
        page=page,
        size=size,
    )
|
|
|
|
|
|
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
    """Re-queue failed jobs: all of them, or the ones listed in the payload.

    ``payload.job_ids`` is forwarded to ``worker.retry_failed``. Without a
    payload, None is passed.
    """
    ids = None
    if payload:
        ids = payload.job_ids
    requeued = worker.retry_failed(ids)
    return {"ok": True, "count": requeued}
|
|
|
|
|
|
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
    minutes: int = Query(5, ge=1, le=1440),
    reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
    """Reset zombie RUNNING jobs left behind by a crashed or unfinished worker.

    Delegates to ``worker.reset_stale_running``. ``minutes`` is presumably the
    staleness threshold, and ``reset_all=True`` resets every RUNNING job.
    """
    reset_count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
    return {"ok": True, "count": reset_count}
|
|
|
|
|
|
@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
    """Create OCR re-run jobs for screenshots whose AI step succeeded but OCR failed.

    ``limit`` is forwarded to ``enqueue_ocr_jobs`` (presumably capping how many
    jobs are created). The worker is woken only when at least one job was created.

    This is a synchronous endpoint, so FastAPI executes it in a thread pool
    rather than on the event loop. Waking the worker must therefore go through
    ``notify_threadsafe()``. Calling ``notify()`` from a pool thread would set
    the loop's asyncio primitives from a foreign thread and race with the loop.
    """
    created = enqueue_ocr_jobs(limit=limit)
    if created:
        worker.notify_threadsafe()
    return {"ok": True, "count": created}
|