Files
congsh 5c028d7952 Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 15:45:50 +08:00

368 lines
13 KiB
Python

"""截图列表 / 详情 / 随机 / 重新分析 / 文件流。"""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy import and_, func, or_, select, text
from sqlalchemy.orm import Session, selectinload
from app.api.deps import db_session
from app.core.path_utils import is_accessible_file, path_from_storage
from app.models.category import Category
from app.models.job import Job, JobKind, JobStatus
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.search_utils import collect_search_ids
from app.schemas.screenshot import (
CategoryOut,
ScreenshotBrief,
ScreenshotDetail,
ScreenshotListResp,
ScreenshotUpdate,
TagOut,
TodoBrief,
)
from app.services.worker import worker
# Router for the screenshot endpoints below; every route path is relative to
# the "/api/screenshots" prefix.
router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])
def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
    """Convert a ``Screenshot`` ORM row into a ``ScreenshotBrief`` schema object.

    Args:
        shot: The screenshot row; its ``meta`` and ``tags`` relationships are read.
        cat_map: Category id -> ``Category`` lookup used to resolve the category.

    Returns:
        The brief representation. ``category`` is ``None`` when the screenshot
        has no category id or the id is missing from ``cat_map``; ``thumb_url``
        is ``None`` when no thumbnail path is recorded.
    """
    meta = shot.meta
    ai_title = meta.ai_title if meta else None

    thumb_url = None
    if shot.thumb_path:
        thumb_url = f"/api/screenshots/{shot.id}/thumb"

    category = None
    cat_id = shot.category_id
    if cat_id and cat_id in cat_map:
        category = CategoryOut.model_validate(cat_map[cat_id])

    tags = list(map(TagOut.model_validate, shot.tags or []))

    return ScreenshotBrief(
        id=shot.id,
        path=shot.path,
        width=shot.width,
        height=shot.height,
        captured_at=shot.captured_at,
        thumb_url=thumb_url,
        ai_title=ai_title,
        ai_status=shot.ai_status,
        ocr_status=shot.ocr_status,
        is_favorite=bool(shot.is_favorite),
        category=category,
        tags=tags,
    )
def _category_map(session: Session) -> dict[int, Category]:
    """Load every ``Category`` row and index the rows by their ``id``."""
    by_id: dict[int, Category] = {}
    for category in session.scalars(select(Category)).all():
        by_id[category.id] = category
    return by_id
@router.get("", response_model=ScreenshotListResp)
def list_screenshots(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="OCR+AI 全文搜索关键词"),
    category_id: Optional[int] = Query(None),
    tag: Optional[str] = Query(None),
    date_from: Optional[datetime] = Query(None),
    date_to: Optional[datetime] = Query(None),
    favorite: Optional[bool] = Query(None),
    status: Optional[str] = Query(None, description="ai_status 过滤"),
    show_hidden: bool = Query(False),
    sort: str = Query("captured_desc"),
    page: int = Query(1, ge=1),
    size: int = Query(40, ge=1, le=200),
) -> ScreenshotListResp:
    """Paginated main listing with optional filters.

    Args:
        session: Database session.
        q: Search keyword; matching ids come from ``collect_search_ids``.
        category_id: Keep only screenshots in this category.
        tag: Case-insensitive substring that at least one tag name must contain.
        date_from: Inclusive lower bound on ``captured_at``.
        date_to: Inclusive upper bound on ``captured_at``.
        favorite: Only ``True`` filters (favorites only); ``False`` and ``None``
            apply no favorite filter.
        status: Exact match on ``ai_status``.
        show_hidden: When false, rows with ``is_hidden != 0`` are excluded.
        sort: Ordering key understood by ``_apply_sort``.
        page: 1-based page number.
        size: Page size.

    Returns:
        The requested page of items together with the total match count.
    """
    stmt = select(Screenshot).options(
        selectinload(Screenshot.meta),
        selectinload(Screenshot.tags),
    )

    filters = []
    if not show_hidden:
        filters.append(Screenshot.is_hidden == 0)
    if category_id is not None:
        filters.append(Screenshot.category_id == category_id)
    if date_from is not None:
        filters.append(Screenshot.captured_at >= date_from)
    if date_to is not None:
        filters.append(Screenshot.captured_at <= date_to)
    if favorite is True:
        filters.append(Screenshot.is_favorite == 1)
    if status:
        filters.append(Screenshot.ai_status == status)
    if tag:
        # EXISTS instead of JOIN: a screenshot with several matching tags must
        # still be exactly one row, otherwise ``total`` is inflated and
        # OFFSET/LIMIT pages over duplicates (short pages after de-duplication).
        filters.append(Screenshot.tags.any(Tag.name.ilike(f"%{tag}%")))
    if q:
        ids = collect_search_ids(session, q)
        if not ids:
            # No search hit: skip the database round trips entirely.
            return ScreenshotListResp(items=[], total=0, page=page, size=size)
        filters.append(Screenshot.id.in_(ids))

    if filters:
        stmt = stmt.where(and_(*filters))

    stmt = _apply_sort(stmt, sort)

    total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
    rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()

    cat_map = _category_map(session)
    items = [_to_brief(row, cat_map) for row in rows]
    return ScreenshotListResp(items=items, total=int(total), page=page, size=size)
def _apply_sort(stmt, sort: str):
    """Apply the list ordering selected by ``sort`` to ``stmt``.

    Supported keys are ``captured_*``, ``imported_*``, ``title_*`` and
    ``size_*`` with an ``asc`` or ``desc`` suffix. Any other value falls back
    to ``captured_desc``.

    Title ordering outer-joins ``ScreenshotMeta`` and places NULL titles last.

    ``Screenshot.id`` is appended as a secondary key in the same direction:
    the primary keys are not unique, and without a tie-breaker OFFSET/LIMIT
    pagination may repeat or skip rows whose primary key values are equal.

    Args:
        stmt: The select statement to order.
        sort: Ordering key.

    Returns:
        The statement with the join (title sorts only) and ORDER BY applied.
    """
    orderings = {
        "captured_desc": (Screenshot.captured_at, False),
        "captured_asc": (Screenshot.captured_at, True),
        "imported_desc": (Screenshot.imported_at, False),
        "imported_asc": (Screenshot.imported_at, True),
        "title_desc": (ScreenshotMeta.ai_title, False),
        "title_asc": (ScreenshotMeta.ai_title, True),
        "size_desc": (Screenshot.size, False),
        "size_asc": (Screenshot.size, True),
    }
    column, ascending = orderings.get(sort, orderings["captured_desc"])

    primary = column.asc() if ascending else column.desc()
    if sort in ("title_asc", "title_desc"):
        # The title lives on the meta row; outer join keeps screenshots
        # that have no meta yet.
        stmt = stmt.outerjoin(ScreenshotMeta)
        primary = primary.nulls_last()

    tie_breaker = Screenshot.id.asc() if ascending else Screenshot.id.desc()
    return stmt.order_by(primary, tie_breaker)
@router.get("/random", response_model=list[ScreenshotBrief])
def random_screenshots(
    session: Session = Depends(db_session),
    n: int = Query(1, ge=1, le=20),
    category_id: Optional[int] = Query(None),
) -> list[ScreenshotBrief]:
    """Return up to ``n`` randomly ordered, non-hidden screenshots.

    When ``category_id`` is given, only that category is sampled.
    """
    conditions = [Screenshot.is_hidden == 0]
    if category_id is not None:
        conditions.append(Screenshot.category_id == category_id)

    query = (
        select(Screenshot)
        .options(
            selectinload(Screenshot.meta),
            selectinload(Screenshot.tags),
        )
        .where(*conditions)
        .order_by(func.random())
        .limit(n)
    )
    picked = session.scalars(query).unique().all()

    categories = _category_map(session)
    return [_to_brief(shot, categories) for shot in picked]
@router.get("/stats")
def stats(session: Session = Depends(db_session)) -> dict:
    """Aggregate statistics for the dashboard.

    Returns:
        A dict with:

        * ``total`` - number of screenshot rows (hidden rows included).
        * ``by_status`` - count per ``ProcessStatus`` value of ``ai_status``;
          statuses with no rows are reported as 0.
        * ``by_category`` - every category with its screenshot count, largest
          first (categories without screenshots are kept via the outer join).
        * ``by_month`` - non-hidden screenshots per ``YYYY-MM`` of
          ``captured_at``, newest first, at most 36 entries.
        * ``queue`` - job counts per status from ``_queue_summary``.
    """
    total = session.scalar(select(func.count(Screenshot.id))) or 0

    # One grouped query instead of one COUNT round trip per enum member.
    status_rows = session.execute(
        select(Screenshot.ai_status, func.count(Screenshot.id)).group_by(
            Screenshot.ai_status
        )
    ).all()
    status_counts = {value: count for value, count in status_rows}
    by_status = {
        st.value: int(status_counts.get(st.value) or 0) for st in ProcessStatus
    }

    by_category_rows = session.execute(
        select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
        .join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
        .group_by(Category.id)
        .order_by(func.count(Screenshot.id).desc())
    ).all()
    by_category = [
        {"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
        for r in by_category_rows
    ]

    # Raw SQL: relies on the database providing a strftime() SQL function.
    by_month_rows = session.execute(
        text(
            "SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
            "FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
        )
    ).all()
    by_month = [{"month": r[0], "count": int(r[1])} for r in by_month_rows]

    return {
        "total": int(total),
        "by_status": by_status,
        "by_category": by_category,
        "by_month": by_month,
        "queue": _queue_summary(session),
    }
def _queue_summary(session: Session) -> dict:
    """Count jobs per ``JobStatus`` value.

    Uses a single grouped query rather than one COUNT per status. Every
    ``JobStatus`` member appears in the result; statuses with no jobs map to 0.

    Returns:
        Mapping of status value -> number of jobs in that status.
    """
    rows = session.execute(
        select(Job.status, func.count(Job.id)).group_by(Job.status)
    ).all()
    counts = {value: count for value, count in rows}

    out: dict[str, int] = {st.value: int(counts.get(st.value) or 0) for st in JobStatus}
    return out
@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
def get_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Return the full detail view of one screenshot.

    Raises:
        HTTPException: 404 when no screenshot has the given id.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    brief = _to_brief(shot, _category_map(session))

    # Text fields live on the optional meta row; all three are None without it.
    meta = shot.meta
    if meta:
        ocr_text = meta.ocr_text
        ai_summary = meta.ai_summary
        ai_suggestion = meta.ai_suggestion
    else:
        ocr_text = ai_summary = ai_suggestion = None

    todo_items = list(map(TodoBrief.model_validate, shot.todos))

    return ScreenshotDetail(
        **brief.model_dump(),
        file_url=f"/api/screenshots/{shot.id}/file",
        size=shot.size,
        ocr_text=ocr_text,
        ai_summary=ai_summary,
        ai_suggestion=ai_suggestion,
        todos=todo_items,
    )
@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
def update_screenshot(
    screenshot_id: int,
    payload: ScreenshotUpdate,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Apply user edits: category, favorite flag, hidden flag and tags.

    Only fields present in the request body are touched. ``category_id`` may
    be sent as an explicit ``null`` to clear the category; for the other
    fields an explicit ``null`` is ignored.

    Tag names are stripped and truncated to 64 characters; empty names are
    skipped, missing tags are created, and repeated names are collapsed so the
    same tag is never attached twice.

    Raises:
        HTTPException: 404 when the screenshot does not exist, 400 when
            ``category_id`` refers to a category that does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    # model_fields_set distinguishes "field omitted" from "field sent as null",
    # so PATCH {"category_id": null} really clears the category.
    fields = payload.model_fields_set

    if "category_id" in fields:
        if payload.category_id is not None:
            cat = session.get(Category, payload.category_id)
            if cat is None:
                raise HTTPException(400, "category not found")
        shot.category_id = payload.category_id

    if "is_favorite" in fields and payload.is_favorite is not None:
        shot.is_favorite = 1 if payload.is_favorite else 0

    if "is_hidden" in fields and payload.is_hidden is not None:
        shot.is_hidden = 1 if payload.is_hidden else 0

    if "tags" in fields and payload.tags is not None:
        tag_objs = []
        for raw_name in payload.tags:
            name = (raw_name or "").strip()[:64]
            if not name:
                continue
            tag = session.scalar(select(Tag).where(Tag.name == name))
            if tag is None:
                tag = Tag(name=name)
                session.add(tag)
                # Flush so a repeated name later in the list finds this row.
                session.flush()
            # Repeated (or colliding after strip/truncate) names resolve to the
            # same Tag object; attaching it twice would duplicate the link row.
            if tag not in tag_objs:
                tag_objs.append(tag)
        shot.tags = tag_objs

    session.commit()
    session.refresh(shot)
    return get_screenshot(screenshot_id, session)
@router.post("/{screenshot_id}/reanalyze")
def reanalyze(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Queue a full re-analysis job for one screenshot.

    Resets both ``ai_status`` and ``ocr_status`` to pending, inserts a pending
    ``FULL`` job and wakes the worker.

    Raises:
        HTTPException: 404 when no screenshot has the given id.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    pending = ProcessStatus.PENDING.value
    shot.ai_status = pending
    shot.ocr_status = pending

    job = Job(
        screenshot_id=shot.id,
        kind=JobKind.FULL.value,
        status=JobStatus.PENDING.value,
    )
    session.add(job)
    session.commit()

    # Sync routes run in the thread pool, so the event loop must be woken
    # through the thread-safe entry point.
    worker.notify_threadsafe()

    response = {"ok": True, "job_id": job.id}
    return response
@router.post("/{screenshot_id}/reocr")
def reocr(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Queue an OCR-only job for one screenshot, without re-running AI analysis.

    If an OCR job for this screenshot is already pending or running, its id is
    returned and nothing new is queued.

    Raises:
        HTTPException: 404 when no screenshot has the given id.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    in_flight_states = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    existing_query = select(Job.id).where(
        Job.screenshot_id == shot.id,
        Job.kind == JobKind.OCR.value,
        Job.status.in_(in_flight_states),
    )
    existing_id = session.scalar(existing_query)
    if existing_id is not None:
        return {"ok": True, "job_id": existing_id, "message": "已有 OCR 任务在队列中"}

    shot.ocr_status = ProcessStatus.PENDING.value
    new_job = Job(
        screenshot_id=shot.id,
        kind=JobKind.OCR.value,
        status=JobStatus.PENDING.value,
    )
    session.add(new_job)
    session.commit()

    worker.notify_threadsafe()
    return {"ok": True, "job_id": new_job.id}
@router.delete("/{screenshot_id}")
def delete_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete the database record of a screenshot.

    Only the row is deleted through the session; this function performs no
    file-system operation on the original image.

    Raises:
        HTTPException: 404 when no screenshot has the given id.
    """
    target = session.get(Screenshot, screenshot_id)
    if target is not None:
        session.delete(target)
        session.commit()
        return {"ok": True}
    raise HTTPException(404, "Screenshot not found")
@router.get("/{screenshot_id}/file")
def get_file(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the original image file of a screenshot.

    Raises:
        HTTPException: 404 when the screenshot does not exist or its file
            does not pass ``is_accessible_file``.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    original = path_from_storage(shot.path)
    if is_accessible_file(original):
        return FileResponse(str(original))
    raise HTTPException(404, "file missing")
@router.get("/{screenshot_id}/thumb")
def get_thumb(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the thumbnail of a screenshot, falling back to the original image.

    The thumbnail is served as ``image/webp`` when ``thumb_path`` points at a
    regular file. Otherwise the original image is served if it passes
    ``is_accessible_file``.

    Raises:
        HTTPException: 404 when the screenshot does not exist, or when neither
            the thumbnail nor the original file can be served.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    if shot.thumb_path:
        thumb = Path(shot.thumb_path)
        # is_file() rather than exists(): a directory at this path would pass
        # exists() and then fail inside FileResponse instead of falling back.
        if thumb.is_file():
            return FileResponse(str(thumb), media_type="image/webp")

    # Fallback: serve the original image.
    original = path_from_storage(shot.path)
    if is_accessible_file(original):
        return FileResponse(str(original))
    raise HTTPException(404, "thumb missing")