Files
SnapAndAnaly/backend/app/api/watch.py
T

257 lines
8.7 KiB
Python
Raw Normal View History

"""监听目录的增删改、手动导入、分析队列。"""
from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session, selectinload

from app.api.deps import db_session
from app.core.db import session_scope
from app.core.path_utils import (
    count_files_sample,
    is_accessible_dir,
    normalize_user_path,
)
from app.models.job import Job, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.watch_folder import WatchFolder
from app.schemas.common import WatchFolderIn, WatchFolderOut
from app.schemas.job import JobListResp, JobOut, JobRetryIn
from app.services.analyze import enqueue_ocr_jobs
from app.services.ingest import ingest_directory
from app.services.watcher import watcher_service
from app.services.worker import worker
# Router for watch-folder CRUD, manual imports and the analysis job queue;
# every route below is mounted under /api/watch.
router = APIRouter(
    prefix="/api/watch",
    tags=["watch"],
)
def _validate_folder_path(raw: str) -> str:
    """Validate and normalize a user-supplied watch-folder path.

    Local paths and UNC network paths are both accepted; normalization is
    delegated to ``normalize_user_path`` and the existence/permission check
    to ``is_accessible_dir``.

    Args:
        raw: The path string as entered by the user.

    Returns:
        The normalized path.

    Raises:
        HTTPException: 400 if the normalized path is empty, or if it is not
            an accessible directory.
    """
    normalized = normalize_user_path(raw)
    if not normalized:
        raise HTTPException(400, "路径不能为空")
    if not is_accessible_dir(normalized):
        # The two literals below are implicitly concatenated; the explicit
        # ", " keeps the path from running straight into the hint text.
        raise HTTPException(
            400,
            f"目录不存在或无法访问: {normalized}, "
            "请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹",
        )
    return normalized
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
    """Return every configured watch folder, ordered by ascending id."""
    stmt = select(WatchFolder).order_by(WatchFolder.id)
    folders = session.scalars(stmt).all()
    return list(map(WatchFolderOut.model_validate, folders))
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
    payload: WatchFolderIn,
    background: BackgroundTasks,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Register a new watch folder and schedule a one-off background import scan.

    Raises:
        HTTPException: 400 if the path is empty/inaccessible, or if a folder
            with the same normalized path is already registered.
    """
    path = _validate_folder_path(payload.path)
    duplicate = session.scalar(select(WatchFolder).where(WatchFolder.path == path))
    if duplicate is not None:
        raise HTTPException(400, "目录已存在")

    new_folder = WatchFolder(
        path=path,
        enabled=payload.enabled,
        recursive=payload.recursive,
        is_sensitive=payload.is_sensitive,
    )
    session.add(new_folder)
    session.commit()
    session.refresh(new_folder)

    # Reload the watcher service (presumably so it re-reads the folder list),
    # then ingest the new folder's existing files after the response is sent.
    watcher_service.reload()
    background.add_task(_scan_folder, path, payload.recursive)
    return WatchFolderOut.model_validate(new_folder)
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
    folder_id: int,
    payload: WatchFolderIn,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Update a watch folder's path and flags, then reload the watcher service.

    Raises:
        HTTPException: 404 if no folder has ``folder_id``; 400 if the new path
            is empty/inaccessible, or is already registered by another folder.
    """
    folder = session.get(WatchFolder, folder_id)
    if folder is None:
        raise HTTPException(404, "folder not found")
    normalized = _validate_folder_path(payload.path)
    # Apply the same duplicate-path rule as folder creation. Without it, moving
    # a folder onto a path another row already watches would create a duplicate
    # entry (or a 500 from the DB if the column is unique).
    conflict_id = session.scalar(
        select(WatchFolder.id).where(
            WatchFolder.path == normalized,
            WatchFolder.id != folder_id,
        )
    )
    if conflict_id is not None:
        raise HTTPException(400, "目录已存在")
    folder.path = normalized
    folder.enabled = payload.enabled
    folder.recursive = payload.recursive
    folder.is_sensitive = payload.is_sensitive
    session.commit()
    session.refresh(folder)
    watcher_service.reload()
    return WatchFolderOut.model_validate(folder)
@router.delete("/folders/{folder_id}")
def delete_folder(
    folder_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete the watch folder with ``folder_id`` and reload the watcher service.

    Raises:
        HTTPException: 404 if no folder has the given id.
    """
    target = session.get(WatchFolder, folder_id)
    if target is None:
        raise HTTPException(404, "folder not found")

    session.delete(target)
    session.commit()
    watcher_service.reload()
    return {"ok": True}
@router.post("/import")
def import_now(
    payload: WatchFolderIn,
    background: BackgroundTasks,
) -> dict:
    """Schedule a one-off background scan of a directory.

    The directory is validated but not registered as a watch folder.
    """
    target_path = _validate_folder_path(payload.path)
    background.add_task(_scan_folder, target_path, payload.recursive)
    return {"ok": True, "message": "已在后台扫描"}
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
    """Check that a directory (UNC paths included) is accessible and sample its files.

    Returns the normalized path, the count and up to three sample entries
    reported by ``count_files_sample``, and a human-readable message.
    """
    path = _validate_folder_path(payload.path)
    image_count, sample_files = count_files_sample(path, limit=3)
    result = {
        "ok": True,
        "path": path,
        "sample_image_count": image_count,
        "samples": sample_files,
    }
    result["message"] = f"目录可访问,抽样发现约 {image_count}+ 张图片"
    return result
def _scan_folder(path: str, recursive: bool) -> None:
    """Background task: ingest a directory into the database, then wake the worker.

    Synchronous BackgroundTasks callables run in a thread pool, so the event
    loop must be woken through the worker's thread-safe entry point; calling
    asyncio.Event.set() directly from this thread would race.
    """
    with session_scope() as db:
        ingest_directory(db, path, recursive=recursive)
    worker.notify_threadsafe()
def _count_ocr_backlog() -> dict[str, int]:
    """Count OCR backlog rows with a synchronous database session.

    This performs blocking DB I/O, so async callers should run it in a worker
    thread rather than directly on the event loop.

    Returns:
        ``ocr_retryable``: screenshots whose AI step is DONE but OCR FAILED.
        ``ocr_pending``: jobs of kind ``"ocr"`` still in PENDING status.
    """
    with session_scope() as session:
        retryable = (
            session.scalar(
                select(func.count(Screenshot.id)).where(
                    Screenshot.ocr_status == ProcessStatus.FAILED.value,
                    Screenshot.ai_status == ProcessStatus.DONE.value,
                )
            )
            or 0
        )
        pending = (
            session.scalar(
                select(func.count(Job.id)).where(
                    Job.kind == "ocr",
                    Job.status == JobStatus.PENDING.value,
                )
            )
            or 0
        )
    return {"ocr_retryable": retryable, "ocr_pending": pending}


@router.get("/queue")
async def queue_status() -> dict:
    """Report the worker queue status extended with OCR backlog counters.

    The dict returned by ``worker.status()`` gains ``ocr_retryable`` and
    ``ocr_pending`` keys.
    """
    counts = await worker.status()
    # session_scope() yields a synchronous SQLAlchemy session; querying it
    # directly in this coroutine would stall the event loop (and every other
    # request) for the duration of both queries, so offload it to a thread.
    counts.update(await asyncio.to_thread(_count_ocr_backlog))
    return counts
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
    """Build a JobOut from a Job row, enriched with summary fields of its screenshot.

    When ``shot`` is missing, every screenshot-derived field is None.
    """
    if not shot:
        thumb_url = path = ai_title = ai_status = ocr_status = None
    else:
        thumb_url = f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None
        path = shot.path
        ai_title = shot.meta.ai_title if shot.meta else None
        ai_status = shot.ai_status
        ocr_status = shot.ocr_status

    return JobOut(
        id=job.id,
        screenshot_id=job.screenshot_id,
        kind=job.kind,
        status=job.status,
        retries=job.retries or 0,
        last_error=job.last_error,
        created_at=job.created_at,
        started_at=job.started_at,
        finished_at=job.finished_at,
        thumb_url=thumb_url,
        path=path,
        ai_title=ai_title,
        ai_status=ai_status,
        ocr_status=ocr_status,
    )
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
    status: Optional[str] = Query(None, description="pending|running|done|failed"),
    kind: Optional[str] = Query(None, description="full|ocr|vlm"),
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=200),
    session: Session = Depends(db_session),
) -> JobListResp:
    """Page through analysis jobs, newest (highest id) first.

    Args:
        status: Optional JobStatus value to filter on.
        kind: Optional job kind to filter on (not validated).
        page: 1-based page number.
        size: Page size.

    Raises:
        HTTPException: 400 if ``status`` is not a JobStatus value.
    """
    valid_statuses = {s.value for s in JobStatus}
    if status and status not in valid_statuses:
        raise HTTPException(400, f"无效 status: {status}")

    query = select(Job)
    for column, value in ((Job.status, status), (Job.kind, kind)):
        if value:
            query = query.where(column == value)

    total = session.scalar(select(func.count()).select_from(query.subquery())) or 0

    offset = (page - 1) * size
    page_jobs = session.scalars(
        query.order_by(Job.id.desc()).offset(offset).limit(size)
    ).all()

    # Fetch the referenced screenshots (with their meta) in one query.
    shots_by_id: dict[int, Screenshot] = {}
    wanted_ids = [job.screenshot_id for job in page_jobs]
    if wanted_ids:
        shot_stmt = (
            select(Screenshot)
            .where(Screenshot.id.in_(wanted_ids))
            .options(selectinload(Screenshot.meta))
        )
        shots_by_id = {shot.id: shot for shot in session.scalars(shot_stmt).all()}

    items = [_job_to_out(job, shots_by_id.get(job.screenshot_id)) for job in page_jobs]
    return JobListResp(items=items, total=total, page=page, size=size)
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
    """Re-queue failed jobs, either all of them or only the ids in the payload.

    Without a payload, ``None`` is passed to ``worker.retry_failed``
    (presumably meaning "all failed jobs").
    """
    job_ids = None if not payload else payload.job_ids
    return {"ok": True, "count": worker.retry_failed(job_ids)}
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
    minutes: int = Query(5, ge=1, le=1440),
    reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
    """Reset zombie RUNNING jobs (left behind when the worker crashed or never finished them).

    Args:
        minutes: Staleness threshold forwarded to ``worker.reset_stale_running``.
        reset_all: When true, reset every RUNNING job.
    """
    reset_count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
    return {"ok": True, "count": reset_count}
@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
    """Create OCR re-run jobs for screenshots whose AI step succeeded but OCR failed.

    Args:
        limit: Maximum number of jobs to enqueue in this call.

    Returns:
        ``{"ok": True, "count": n}`` where ``n`` is the number of jobs enqueued.
    """
    count = enqueue_ocr_jobs(limit=limit)
    if count:
        # This is a sync `def` endpoint, so FastAPI runs it in a threadpool
        # thread. The worker must therefore be woken through its thread-safe
        # entry point; a plain notify() (asyncio.Event.set() off the loop
        # thread) can race and the wake-up may be lost.
        worker.notify_threadsafe()
    return {"ok": True, "count": count}