Initial commit: snapAna 截图智能整理工具

包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
wjl
2026-05-27 15:45:50 +08:00
commit 5c028d7952
76 changed files with 10467 additions and 0 deletions
View File
+17
View File
@@ -0,0 +1,17 @@
"""API 通用依赖。"""
from __future__ import annotations
from typing import Iterator
from sqlalchemy.orm import Session
from app.core.db import SessionLocal
def db_session() -> Iterator[Session]:
"""每请求一个会话。"""
session = SessionLocal()
try:
yield session
finally:
session.close()
+367
View File
@@ -0,0 +1,367 @@
"""截图列表 / 详情 / 随机 / 重新分析 / 文件流。"""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy import and_, func, or_, select, text
from sqlalchemy.orm import Session, selectinload
from app.api.deps import db_session
from app.core.path_utils import is_accessible_file, path_from_storage
from app.models.category import Category
from app.models.job import Job, JobKind, JobStatus
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.search_utils import collect_search_ids
from app.schemas.screenshot import (
CategoryOut,
ScreenshotBrief,
ScreenshotDetail,
ScreenshotListResp,
ScreenshotUpdate,
TagOut,
TodoBrief,
)
from app.services.worker import worker
router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])
def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
"""ORM -> ScreenshotBrief。"""
return ScreenshotBrief(
id=shot.id,
path=shot.path,
width=shot.width,
height=shot.height,
captured_at=shot.captured_at,
thumb_url=f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None,
ai_title=(shot.meta.ai_title if shot.meta else None),
ai_status=shot.ai_status,
ocr_status=shot.ocr_status,
is_favorite=bool(shot.is_favorite),
category=(
CategoryOut.model_validate(cat_map[shot.category_id])
if shot.category_id and shot.category_id in cat_map
else None
),
tags=[TagOut.model_validate(t) for t in (shot.tags or [])],
)
def _category_map(session: Session) -> dict[int, Category]:
return {c.id: c for c in session.scalars(select(Category)).all()}
@router.get("", response_model=ScreenshotListResp)
def list_screenshots(
session: Session = Depends(db_session),
q: Optional[str] = Query(None, description="OCR+AI 全文搜索关键词"),
category_id: Optional[int] = Query(None),
tag: Optional[str] = Query(None),
date_from: Optional[datetime] = Query(None),
date_to: Optional[datetime] = Query(None),
favorite: Optional[bool] = Query(None),
status: Optional[str] = Query(None, description="ai_status 过滤"),
show_hidden: bool = Query(False),
sort: str = Query("captured_desc"),
page: int = Query(1, ge=1),
size: int = Query(40, ge=1, le=200),
) -> ScreenshotListResp:
"""主列表查询:支持时间/分类/标签/收藏/状态/搜索词。"""
stmt = select(Screenshot).options(
selectinload(Screenshot.meta),
selectinload(Screenshot.tags),
)
filters = []
if not show_hidden:
filters.append(Screenshot.is_hidden == 0)
if category_id is not None:
filters.append(Screenshot.category_id == category_id)
if date_from is not None:
filters.append(Screenshot.captured_at >= date_from)
if date_to is not None:
filters.append(Screenshot.captured_at <= date_to)
if favorite is True:
filters.append(Screenshot.is_favorite == 1)
if status:
filters.append(Screenshot.ai_status == status)
if tag:
stmt = stmt.join(Screenshot.tags).where(Tag.name.ilike(f"%{tag}%"))
if q:
ids = collect_search_ids(session, q)
if not ids:
return ScreenshotListResp(items=[], total=0, page=page, size=size)
filters.append(Screenshot.id.in_(ids))
if filters:
stmt = stmt.where(and_(*filters))
# 排序
stmt = _apply_sort(stmt, sort)
total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()
cat_map = _category_map(session)
items = [_to_brief(r, cat_map) for r in rows]
return ScreenshotListResp(items=items, total=int(total), page=page, size=size)
def _apply_sort(stmt, sort: str):
"""列表排序:时间 / 导入 / 标题 / 文件大小。"""
if sort == "captured_asc":
return stmt.order_by(Screenshot.captured_at.asc())
if sort == "imported_desc":
return stmt.order_by(Screenshot.imported_at.desc())
if sort == "imported_asc":
return stmt.order_by(Screenshot.imported_at.asc())
if sort == "title_asc":
return stmt.outerjoin(ScreenshotMeta).order_by(
ScreenshotMeta.ai_title.asc().nulls_last()
)
if sort == "title_desc":
return stmt.outerjoin(ScreenshotMeta).order_by(
ScreenshotMeta.ai_title.desc().nulls_last()
)
if sort == "size_desc":
return stmt.order_by(Screenshot.size.desc())
if sort == "size_asc":
return stmt.order_by(Screenshot.size.asc())
return stmt.order_by(Screenshot.captured_at.desc())
@router.get("/random", response_model=list[ScreenshotBrief])
def random_screenshots(
session: Session = Depends(db_session),
n: int = Query(1, ge=1, le=20),
category_id: Optional[int] = Query(None),
) -> list[ScreenshotBrief]:
"""随机展示。"""
stmt = select(Screenshot).options(
selectinload(Screenshot.meta),
selectinload(Screenshot.tags),
).where(Screenshot.is_hidden == 0)
if category_id is not None:
stmt = stmt.where(Screenshot.category_id == category_id)
stmt = stmt.order_by(func.random()).limit(n)
rows = session.scalars(stmt).unique().all()
cat_map = _category_map(session)
return [_to_brief(r, cat_map) for r in rows]
@router.get("/stats")
def stats(session: Session = Depends(db_session)) -> dict:
"""汇总统计:总数、状态分布、按分类、按月份。"""
total = session.scalar(select(func.count(Screenshot.id))) or 0
by_status = {
st.value: session.scalar(
select(func.count(Screenshot.id)).where(Screenshot.ai_status == st.value)
)
or 0
for st in ProcessStatus
}
by_category_rows = session.execute(
select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
.join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
.group_by(Category.id)
.order_by(func.count(Screenshot.id).desc())
).all()
by_category = [
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
for r in by_category_rows
]
by_month_rows = session.execute(
text(
"SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
"FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
)
).all()
by_month = [{"month": r[0], "count": int(r[1])} for r in by_month_rows]
return {
"total": int(total),
"by_status": by_status,
"by_category": by_category,
"by_month": by_month,
"queue": _queue_summary(session),
}
def _queue_summary(session: Session) -> dict:
"""汇总 jobs 队列状态。"""
out: dict[str, int] = {}
for st in JobStatus:
out[st.value] = (
session.scalar(select(func.count(Job.id)).where(Job.status == st.value)) or 0
)
return out
@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
def get_screenshot(
screenshot_id: int,
session: Session = Depends(db_session),
) -> ScreenshotDetail:
"""单张详情。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
cat_map = _category_map(session)
brief = _to_brief(shot, cat_map)
meta = shot.meta
todos = [TodoBrief.model_validate(t) for t in shot.todos]
return ScreenshotDetail(
**brief.model_dump(),
file_url=f"/api/screenshots/{shot.id}/file",
size=shot.size,
ocr_text=(meta.ocr_text if meta else None),
ai_summary=(meta.ai_summary if meta else None),
ai_suggestion=(meta.ai_suggestion if meta else None),
todos=todos,
)
@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
def update_screenshot(
screenshot_id: int,
payload: ScreenshotUpdate,
session: Session = Depends(db_session),
) -> ScreenshotDetail:
"""前端编辑:分类、收藏、隐藏、标签。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
# 用 model_fields_set 区分「未传字段」与「显式传入 null」
# 这样前端 PATCH {"category_id": null} 可以真正清空分类
fields = payload.model_fields_set
if "category_id" in fields:
if payload.category_id is not None:
cat = session.get(Category, payload.category_id)
if cat is None:
raise HTTPException(400, "category not found")
shot.category_id = payload.category_id
if "is_favorite" in fields and payload.is_favorite is not None:
shot.is_favorite = 1 if payload.is_favorite else 0
if "is_hidden" in fields and payload.is_hidden is not None:
shot.is_hidden = 1 if payload.is_hidden else 0
if "tags" in fields and payload.tags is not None:
tag_objs = []
for name in payload.tags:
name = (name or "").strip()[:64]
if not name:
continue
tag = session.scalar(select(Tag).where(Tag.name == name))
if tag is None:
tag = Tag(name=name)
session.add(tag)
session.flush()
tag_objs.append(tag)
shot.tags = tag_objs
session.commit()
session.refresh(shot)
return get_screenshot(screenshot_id, session)
@router.post("/{screenshot_id}/reanalyze")
def reanalyze(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""加入队列重新分析。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
shot.ai_status = ProcessStatus.PENDING.value
shot.ocr_status = ProcessStatus.PENDING.value
job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
session.add(job)
session.commit()
# 同步路由跑在线程池,必须 threadsafe 唤醒事件循环
worker.notify_threadsafe()
return {"ok": True, "job_id": job.id}
@router.post("/{screenshot_id}/reocr")
def reocr(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""仅补跑 OCR,不重新调用 AI 分析。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
active = session.scalar(
select(Job.id).where(
Job.screenshot_id == shot.id,
Job.kind == JobKind.OCR.value,
Job.status.in_((JobStatus.PENDING.value, JobStatus.RUNNING.value)),
)
)
if active is not None:
return {"ok": True, "job_id": active, "message": "已有 OCR 任务在队列中"}
shot.ocr_status = ProcessStatus.PENDING.value
job = Job(
screenshot_id=shot.id,
kind=JobKind.OCR.value,
status=JobStatus.PENDING.value,
)
session.add(job)
session.commit()
worker.notify_threadsafe()
return {"ok": True, "job_id": job.id}
@router.delete("/{screenshot_id}")
def delete_screenshot(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""删除记录(不删除原始文件)。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
session.delete(shot)
session.commit()
return {"ok": True}
@router.get("/{screenshot_id}/file")
def get_file(
screenshot_id: int,
session: Session = Depends(db_session),
) -> FileResponse:
"""原图文件流。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
p = path_from_storage(shot.path)
if not is_accessible_file(p):
raise HTTPException(404, "file missing")
return FileResponse(str(p))
@router.get("/{screenshot_id}/thumb")
def get_thumb(
screenshot_id: int,
session: Session = Depends(db_session),
) -> FileResponse:
"""缩略图流。"""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
if shot.thumb_path:
p = Path(shot.thumb_path)
if p.exists():
return FileResponse(str(p), media_type="image/webp")
# 兜底:返回原图
p = path_from_storage(shot.path)
if is_accessible_file(p):
return FileResponse(str(p))
raise HTTPException(404, "thumb missing")
+220
View File
@@ -0,0 +1,220 @@
"""设置接口:Provider 配置、分类、Tag。"""
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.api.deps import db_session
from app.models.category import Category
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.screenshot import Screenshot
from app.models.tag import Tag
from app.providers import RECOGNITION_MODES
from app.schemas.common import (
CategoryIn,
ProviderConfig,
ProviderConfigOut,
ProviderTestResult,
RecognitionModeIn,
)
from app.services.provider_test import merge_provider_api_key, test_provider_config
from app.services.settings_store import all_settings, get_setting, set_setting
router = APIRouter(prefix="/api/settings", tags=["settings"])
@router.get("")
def get_all(session: Session = Depends(db_session)) -> dict:
"""返回所有非敏感设置。api_key 字段做脱敏。"""
raw = all_settings(session)
for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
cfg = raw.get(key)
if isinstance(cfg, dict) and cfg.get("api_key"):
cfg["api_key_mask"] = _mask(cfg["api_key"])
cfg["api_key"] = ""
return raw
def _mask(value: str) -> str:
if not value:
return ""
if len(value) <= 6:
return "*" * len(value)
return value[:3] + "*" * (len(value) - 6) + value[-3:]
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
def get_provider(
key: str,
session: Session = Depends(db_session),
) -> ProviderConfigOut | None:
"""读取 Provider 配置:api_key 明文不外传,只给一个掩码用于 UI 提示。"""
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
raw = get_setting(session, key, None)
if not raw:
return None
mask = _mask(raw.get("api_key", "") or "")
return ProviderConfigOut(
type=raw.get("type", ""),
base_url=raw.get("base_url"),
api_key="",
api_key_mask=mask or None,
model=raw.get("model"),
extra=raw.get("extra", {}) or {},
)
@router.put("/providers/{key}")
def put_provider(
key: str,
cfg: ProviderConfig,
session: Session = Depends(db_session),
) -> dict:
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
# 如果客户端没有传新的 api_key(空字符串),保留旧值
existing = get_setting(session, key, None)
payload = cfg.model_dump()
if (not payload.get("api_key")) and isinstance(existing, dict):
payload["api_key"] = existing.get("api_key", "")
set_setting(session, key, payload)
session.commit()
return {"ok": True}
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
key: str,
cfg: ProviderConfig,
session: Session = Depends(db_session),
) -> ProviderTestResult:
"""测试 OCR / 视觉 AI Provider 连通性(使用当前表单配置,api_key 可留空沿用已保存值)。"""
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
existing = get_setting(session, key, None)
merged = merge_provider_api_key(cfg, existing if isinstance(existing, dict) else None)
result = await test_provider_config(key, merged)
return ProviderTestResult(**result)
@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
"""读取文字识别策略:ocr / vision / hybrid。"""
mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
if mode not in RECOGNITION_MODES:
mode = DEFAULT_RECOGNITION_MODE
return {"mode": mode, "options": list(RECOGNITION_MODES)}
@router.put("/recognition-mode")
def put_recognition_mode(
payload: RecognitionModeIn,
session: Session = Depends(db_session),
) -> dict:
"""保存文字识别策略。"""
if payload.mode not in RECOGNITION_MODES:
raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
set_setting(session, KEY_RECOGNITION_MODE, payload.mode)
session.commit()
return {"ok": True, "mode": payload.mode}
@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
rows = session.scalars(select(Category).order_by(Category.id)).all()
return [
{"id": c.id, "name": c.name, "color": c.color, "prompt_hint": c.prompt_hint}
for c in rows
]
@router.post("/categories")
def create_category(
payload: CategoryIn,
session: Session = Depends(db_session),
) -> dict:
exists = session.scalar(select(Category).where(Category.name == payload.name))
if exists is not None:
raise HTTPException(400, "category exists")
cat = Category(name=payload.name, color=payload.color, prompt_hint=payload.prompt_hint)
session.add(cat)
session.commit()
session.refresh(cat)
return {"id": cat.id}
@router.patch("/categories/{cat_id}")
def update_category(
cat_id: int,
payload: CategoryIn,
session: Session = Depends(db_session),
) -> dict:
cat = session.get(Category, cat_id)
if cat is None:
raise HTTPException(404, "category not found")
cat.name = payload.name
cat.color = payload.color
cat.prompt_hint = payload.prompt_hint
session.commit()
return {"ok": True}
@router.delete("/categories/{cat_id}")
def delete_category(
cat_id: int,
session: Session = Depends(db_session),
) -> dict:
cat = session.get(Category, cat_id)
if cat is None:
raise HTTPException(404, "category not found")
session.delete(cat)
session.commit()
return {"ok": True}
@router.get("/tags")
def list_tags(
session: Session = Depends(db_session),
q: Optional[str] = Query(None, description="标签名关键词"),
page: int = Query(1, ge=1),
size: int = Query(200, ge=1, le=500),
sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
) -> dict:
"""标签列表(含使用次数),支持搜索与分页。"""
base = select(Tag.id)
if q:
base = base.where(Tag.name.ilike(f"%{q.strip()}%"))
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
stmt = (
select(Tag.id, Tag.name, Tag.color, func.count(Screenshot.id))
.join(Tag.screenshots, isouter=True)
.group_by(Tag.id)
)
if q:
stmt = stmt.where(Tag.name.ilike(f"%{q.strip()}%"))
if sort == "count_asc":
stmt = stmt.order_by(func.count(Screenshot.id).asc())
elif sort == "name_asc":
stmt = stmt.order_by(Tag.name.asc())
elif sort == "name_desc":
stmt = stmt.order_by(Tag.name.desc())
else:
stmt = stmt.order_by(func.count(Screenshot.id).desc())
rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
items = [
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)} for r in rows
]
return {"items": items, "total": int(total), "page": page, "size": size}
+106
View File
@@ -0,0 +1,106 @@
"""待办清单接口。"""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from app.api.deps import db_session
from app.models.todo import Todo, TodoStatus
from app.schemas.common import TodoUpdate
from app.schemas.screenshot import TodoBrief, TodoListResp
router = APIRouter(prefix="/api/todos", tags=["todos"])
def _todo_filters(
status: Optional[str],
kind: Optional[str],
q: Optional[str],
) -> list:
"""构建待办筛选条件。"""
filters = []
if status:
filters.append(Todo.status == status)
if kind:
filters.append(Todo.kind == kind)
if q:
like = f"%{q.strip()}%"
filters.append(or_(Todo.title.ilike(like), Todo.note.ilike(like)))
return filters
@router.get("", response_model=TodoListResp)
def list_todos(
session: Session = Depends(db_session),
status: Optional[str] = Query(None),
kind: Optional[str] = Query(None),
q: Optional[str] = Query(None, description="标题/备注关键词"),
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=200),
) -> TodoListResp:
"""按状态/类型/关键词分页查询。"""
filters = _todo_filters(status, kind, q)
base = select(Todo)
if filters:
base = base.where(and_(*filters))
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
rows = session.scalars(
base.order_by(Todo.created_at.desc()).offset((page - 1) * size).limit(size)
).all()
return TodoListResp(
items=[TodoBrief.model_validate(r) for r in rows],
total=int(total),
page=page,
size=size,
)
@router.get("/summary")
def summary(session: Session = Depends(db_session)) -> dict:
"""各状态待办数量。"""
return {
st.value: session.scalar(select(func.count(Todo.id)).where(Todo.status == st.value)) or 0
for st in TodoStatus
}
@router.patch("/{todo_id}", response_model=TodoBrief)
def update_todo(
todo_id: int,
payload: TodoUpdate,
session: Session = Depends(db_session),
) -> TodoBrief:
"""更新状态/标题/备注。"""
todo = session.get(Todo, todo_id)
if todo is None:
raise HTTPException(404, "Todo not found")
if payload.status is not None:
todo.status = payload.status
if payload.status == TodoStatus.DONE.value:
todo.completed_at = datetime.utcnow()
if payload.title is not None:
todo.title = payload.title
if payload.note is not None:
todo.note = payload.note
session.commit()
session.refresh(todo)
return TodoBrief.model_validate(todo)
@router.delete("/{todo_id}")
def delete_todo(
todo_id: int,
session: Session = Depends(db_session),
) -> dict:
todo = session.get(Todo, todo_id)
if todo is None:
raise HTTPException(404, "Todo not found")
session.delete(todo)
session.commit()
return {"ok": True}
+256
View File
@@ -0,0 +1,256 @@
"""监听目录的增删改、手动导入、分析队列。"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import db_session
from app.core.db import session_scope
from app.core.path_utils import (
count_files_sample,
is_accessible_dir,
normalize_user_path,
)
from app.models.job import Job, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.watch_folder import WatchFolder
from app.schemas.common import WatchFolderIn, WatchFolderOut
from app.schemas.job import JobListResp, JobOut, JobRetryIn
from app.services.analyze import enqueue_ocr_jobs
from app.services.ingest import ingest_directory
from app.services.watcher import watcher_service
from app.services.worker import worker
router = APIRouter(prefix="/api/watch", tags=["watch"])
def _validate_folder_path(raw: str) -> str:
"""校验并规范化监听目录路径(含 UNC 网络路径)。"""
normalized = normalize_user_path(raw)
if not normalized:
raise HTTPException(400, "路径不能为空")
if not is_accessible_dir(normalized):
raise HTTPException(
400,
f"目录不存在或无法访问: {normalized}"
"请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹",
)
return normalized
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
rows = session.scalars(select(WatchFolder).order_by(WatchFolder.id)).all()
return [WatchFolderOut.model_validate(r) for r in rows]
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
payload: WatchFolderIn,
background: BackgroundTasks,
session: Session = Depends(db_session),
) -> WatchFolderOut:
"""新增监听目录,自动触发一次扫描入库。"""
normalized = _validate_folder_path(payload.path)
exists = session.scalar(select(WatchFolder).where(WatchFolder.path == normalized))
if exists is not None:
raise HTTPException(400, "目录已存在")
folder = WatchFolder(
path=normalized,
enabled=payload.enabled,
recursive=payload.recursive,
is_sensitive=payload.is_sensitive,
)
session.add(folder)
session.commit()
session.refresh(folder)
watcher_service.reload()
background.add_task(_scan_folder, normalized, payload.recursive)
return WatchFolderOut.model_validate(folder)
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
folder_id: int,
payload: WatchFolderIn,
session: Session = Depends(db_session),
) -> WatchFolderOut:
folder = session.get(WatchFolder, folder_id)
if folder is None:
raise HTTPException(404, "folder not found")
normalized = _validate_folder_path(payload.path)
folder.path = normalized
folder.enabled = payload.enabled
folder.recursive = payload.recursive
folder.is_sensitive = payload.is_sensitive
session.commit()
session.refresh(folder)
watcher_service.reload()
return WatchFolderOut.model_validate(folder)
@router.delete("/folders/{folder_id}")
def delete_folder(
folder_id: int,
session: Session = Depends(db_session),
) -> dict:
folder = session.get(WatchFolder, folder_id)
if folder is None:
raise HTTPException(404, "folder not found")
session.delete(folder)
session.commit()
watcher_service.reload()
return {"ok": True}
@router.post("/import")
def import_now(
payload: WatchFolderIn,
background: BackgroundTasks,
) -> dict:
"""手动触发一次目录扫描(不一定要登记为监听)。"""
normalized = _validate_folder_path(payload.path)
background.add_task(_scan_folder, normalized, payload.recursive)
return {"ok": True, "message": "已在后台扫描"}
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
"""测试目录是否可访问(含 UNC 网络路径),返回抽样文件数。"""
normalized = _validate_folder_path(payload.path)
total, samples = count_files_sample(normalized, limit=3)
return {
"ok": True,
"path": normalized,
"sample_image_count": total,
"samples": samples,
"message": f"目录可访问,抽样发现约 {total}+ 张图片",
}
def _scan_folder(path: str, recursive: bool) -> None:
"""后台任务:扫描目录入库,再通知 worker。
BackgroundTasks 的同步函数运行在线程池中,必须用 threadsafe 入口
唤醒事件循环,否则 asyncio.Event.set() 会有竞态。
"""
with session_scope() as session:
ingest_directory(session, path, recursive=recursive)
worker.notify_threadsafe()
@router.get("/queue")
async def queue_status() -> dict:
"""读取 worker 队列状态。"""
counts = await worker.status()
with session_scope() as session:
counts["ocr_retryable"] = (
session.scalar(
select(func.count(Screenshot.id)).where(
Screenshot.ocr_status == ProcessStatus.FAILED.value,
Screenshot.ai_status == ProcessStatus.DONE.value,
)
)
or 0
)
counts["ocr_pending"] = (
session.scalar(
select(func.count(Job.id)).where(
Job.kind == "ocr",
Job.status == JobStatus.PENDING.value,
)
)
or 0
)
return counts
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
"""ORM -> JobOut,附带截图摘要字段。"""
return JobOut(
id=job.id,
screenshot_id=job.screenshot_id,
kind=job.kind,
status=job.status,
retries=job.retries or 0,
last_error=job.last_error,
created_at=job.created_at,
started_at=job.started_at,
finished_at=job.finished_at,
thumb_url=(
f"/api/screenshots/{shot.id}/thumb" if shot and shot.thumb_path else None
),
path=shot.path if shot else None,
ai_title=(shot.meta.ai_title if shot and shot.meta else None),
ai_status=shot.ai_status if shot else None,
ocr_status=shot.ocr_status if shot else None,
)
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
status: Optional[str] = Query(None, description="pending|running|done|failed"),
kind: Optional[str] = Query(None, description="full|ocr|vlm"),
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=200),
session: Session = Depends(db_session),
) -> JobListResp:
"""分页列出分析任务,默认按 id 倒序(最新的在前)。"""
if status and status not in {s.value for s in JobStatus}:
raise HTTPException(400, f"无效 status: {status}")
base = select(Job)
if status:
base = base.where(Job.status == status)
if kind:
base = base.where(Job.kind == kind)
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
jobs = session.scalars(
base.order_by(Job.id.desc()).offset((page - 1) * size).limit(size)
).all()
shot_ids = [j.screenshot_id for j in jobs]
shots: dict[int, Screenshot] = {}
if shot_ids:
rows = session.scalars(
select(Screenshot)
.where(Screenshot.id.in_(shot_ids))
.options(selectinload(Screenshot.meta))
).all()
shots = {s.id: s for s in rows}
items = [_job_to_out(j, shots.get(j.screenshot_id)) for j in jobs]
return JobListResp(items=items, total=total, page=page, size=size)
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
"""将全部或指定 failed 任务重新排队。"""
job_ids = payload.job_ids if payload else None
count = worker.retry_failed(job_ids)
return {"ok": True, "count": count}
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
minutes: int = Query(5, ge=1, le=1440),
reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
"""复位僵尸 RUNNING 任务(worker 崩溃或未正常 finish 时)。"""
count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
return {"ok": True, "count": count}
@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
"""为 AI 已成功但 OCR 失败的截图批量创建 OCR 补跑任务。"""
count = enqueue_ocr_jobs(limit=limit)
if count:
worker.notify()
return {"ok": True, "count": count}