"""Screenshot endpoints: list / detail / random / re-analyze / file streaming."""

from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy import and_, func, or_, select, text
from sqlalchemy.orm import Session, selectinload

from app.api.deps import db_session
from app.core.path_utils import is_accessible_file, path_from_storage
from app.models.category import Category
from app.models.job import Job, JobKind, JobStatus
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.search_utils import collect_search_ids
from app.schemas.screenshot import (
    CategoryOut,
    ScreenshotBrief,
    ScreenshotDetail,
    ScreenshotListResp,
    ScreenshotUpdate,
    TagOut,
    TodoBrief,
)
from app.services.worker import worker

router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])

# Job states treated as "already queued" when deciding whether a new job of
# the same kind needs to be enqueued for a screenshot.
_ACTIVE_JOB_STATUSES = (JobStatus.PENDING.value, JobStatus.RUNNING.value)


def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
    """Convert a ``Screenshot`` ORM row into a ``ScreenshotBrief``.

    ``cat_map`` maps category id -> Category. A ``category_id`` that is not in
    the map yields ``category=None``. ``thumb_url`` is only set when the row
    has a ``thumb_path``.
    """
    return ScreenshotBrief(
        id=shot.id,
        path=shot.path,
        width=shot.width,
        height=shot.height,
        captured_at=shot.captured_at,
        thumb_url=f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None,
        ai_title=(shot.meta.ai_title if shot.meta else None),
        ai_status=shot.ai_status,
        ocr_status=shot.ocr_status,
        is_favorite=bool(shot.is_favorite),
        category=(
            CategoryOut.model_validate(cat_map[shot.category_id])
            if shot.category_id and shot.category_id in cat_map
            else None
        ),
        tags=[TagOut.model_validate(t) for t in (shot.tags or [])],
    )


def _category_map(session: Session) -> dict[int, Category]:
    """Load every category and index it by id."""
    return {c.id: c for c in session.scalars(select(Category)).all()}


@router.get("", response_model=ScreenshotListResp)
def list_screenshots(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="OCR+AI 全文搜索关键词"),
    category_id: Optional[int] = Query(None),
    tag: Optional[str] = Query(None),
    date_from: Optional[datetime] = Query(None),
    date_to: Optional[datetime] = Query(None),
    favorite: Optional[bool] = Query(None),
    status: Optional[str] = Query(None, description="ai_status 过滤"),
    show_hidden: bool = Query(False),
    sort: str = Query("captured_desc"),
    page: int = Query(1, ge=1),
    size: int = Query(40, ge=1, le=200),
) -> ScreenshotListResp:
    """Main paginated list query.

    Filters:

    * ``date_from`` / ``date_to`` bound ``captured_at``; both bounds are
      inclusive.
    * ``category_id``, ``status`` and ``ai_status`` match exactly.
    * ``tag`` is a case-insensitive substring match on any of the
      screenshot's tags.
    * ``favorite=true`` keeps only favorites. ``false`` applies no filter.
    * ``q`` is a full-text term resolved through ``collect_search_ids``.
    * Hidden screenshots are excluded unless ``show_hidden`` is true.

    Ordering is delegated to ``_apply_sort``. ``total`` is the number of
    distinct matching screenshots before pagination.
    """
    filters = []
    if not show_hidden:
        filters.append(Screenshot.is_hidden == 0)
    if category_id is not None:
        filters.append(Screenshot.category_id == category_id)
    if date_from is not None:
        filters.append(Screenshot.captured_at >= date_from)
    if date_to is not None:
        filters.append(Screenshot.captured_at <= date_to)
    if favorite is True:
        filters.append(Screenshot.is_favorite == 1)
    if status:
        filters.append(Screenshot.ai_status == status)
    if tag:
        # Use an EXISTS subquery rather than a JOIN. A screenshot with several
        # matching tags must still count once. A JOIN would duplicate it,
        # inflating `total` and skewing the LIMIT/OFFSET window.
        filters.append(Screenshot.tags.any(Tag.name.ilike(f"%{tag}%")))
    if q:
        ids = collect_search_ids(session, q)
        if not ids:
            return ScreenshotListResp(items=[], total=0, page=page, size=size)
        filters.append(Screenshot.id.in_(ids))

    stmt = select(Screenshot).options(
        selectinload(Screenshot.meta),
        selectinload(Screenshot.tags),
    )
    if filters:
        stmt = stmt.where(and_(*filters))

    # Count before ordering: ORDER BY never changes the total.
    total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0

    stmt = _apply_sort(stmt, sort)
    rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()
    cat_map = _category_map(session)
    items = [_to_brief(r, cat_map) for r in rows]
    return ScreenshotListResp(items=items, total=int(total), page=page, size=size)


def _apply_sort(stmt, sort: str):
    """Order ``stmt`` by capture time, import time, AI title or file size.

    ``sort`` has the form ``<field>_<asc|desc>``, where field is one of
    ``captured``, ``imported``, ``title`` or ``size``. Any other value falls
    back to ``captured_desc``. Title ordering outer-joins ``ScreenshotMeta``
    and puts NULL titles last.

    ``Screenshot.id`` is appended as a tie-breaker in the same direction.
    Without it, rows with equal sort keys can move between pages under
    LIMIT/OFFSET pagination.
    """
    plain_columns = {
        "captured": Screenshot.captured_at,
        "imported": Screenshot.imported_at,
        "size": Screenshot.size,
    }
    field, _, direction = sort.rpartition("_")
    if direction not in ("asc", "desc") or (
        field != "title" and field not in plain_columns
    ):
        field, direction = "captured", "desc"

    ascending = direction == "asc"
    tiebreak = Screenshot.id.asc() if ascending else Screenshot.id.desc()

    if field == "title":
        title = ScreenshotMeta.ai_title
        key = (title.asc() if ascending else title.desc()).nulls_last()
        return stmt.outerjoin(ScreenshotMeta).order_by(key, tiebreak)

    column = plain_columns[field]
    key = column.asc() if ascending else column.desc()
    return stmt.order_by(key, tiebreak)


@router.get("/random", response_model=list[ScreenshotBrief])
def random_screenshots(
    session: Session = Depends(db_session),
    n: int = Query(1, ge=1, le=20),
    category_id: Optional[int] = Query(None),
) -> list[ScreenshotBrief]:
    """Return up to ``n`` random non-hidden screenshots.

    The result can optionally be restricted to ``category_id``.
    """
    stmt = select(Screenshot).options(
        selectinload(Screenshot.meta),
        selectinload(Screenshot.tags),
    ).where(Screenshot.is_hidden == 0)
    if category_id is not None:
        stmt = stmt.where(Screenshot.category_id == category_id)
    stmt = stmt.order_by(func.random()).limit(n)
    rows = session.scalars(stmt).unique().all()
    cat_map = _category_map(session)
    return [_to_brief(r, cat_map) for r in rows]


@router.get("/stats")
def stats(session: Session = Depends(db_session)) -> dict:
    """Summary statistics for the library and the job queue.

    Returns the following keys:

    * ``total``: count of all screenshots.
    * ``by_status``: count per ``ProcessStatus`` value of ``ai_status``.
    * ``by_category``: count per category, largest first.
    * ``by_month``: counts per ``YYYY-MM`` of ``captured_at``, newest first,
      at most 36 months.
    * ``queue``: job counts per ``JobStatus``.

    Only ``by_month`` excludes hidden screenshots.
    """
    total = session.scalar(select(func.count(Screenshot.id))) or 0
    by_status = {
        st.value: session.scalar(
            select(func.count(Screenshot.id)).where(Screenshot.ai_status == st.value)
        )
        or 0
        for st in ProcessStatus
    }
    by_category_rows = session.execute(
        select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
        .join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
        .group_by(Category.id)
        .order_by(func.count(Screenshot.id).desc())
    ).all()
    by_category = [
        {"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
        for r in by_category_rows
    ]
    # strftime is SQLite-specific SQL.
    by_month_rows = session.execute(
        text(
            "SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
            "FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
        )
    ).all()
    by_month = [{"month": r[0], "count": int(r[1])} for r in by_month_rows]
    return {
        "total": int(total),
        "by_status": by_status,
        "by_category": by_category,
        "by_month": by_month,
        "queue": _queue_summary(session),
    }


def _queue_summary(session: Session) -> dict:
    """Count jobs per ``JobStatus`` value; statuses with no jobs map to 0."""
    out: dict[str, int] = {}
    for st in JobStatus:
        out[st.value] = (
            session.scalar(select(func.count(Job.id)).where(Job.status == st.value)) or 0
        )
    return out


@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
def get_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Return the full detail of one screenshot.

    The detail includes OCR text, AI fields and todos.

    Raises:
        HTTPException: 404 if the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    cat_map = _category_map(session)
    brief = _to_brief(shot, cat_map)
    meta = shot.meta
    todos = [TodoBrief.model_validate(t) for t in shot.todos]
    return ScreenshotDetail(
        **brief.model_dump(),
        file_url=f"/api/screenshots/{shot.id}/file",
        size=shot.size,
        ocr_text=(meta.ocr_text if meta else None),
        ai_summary=(meta.ai_summary if meta else None),
        ai_suggestion=(meta.ai_suggestion if meta else None),
        todos=todos,
    )


@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
def update_screenshot(
    screenshot_id: int,
    payload: ScreenshotUpdate,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Apply a front-end edit and return the refreshed detail.

    Editable fields are category, favorite, hidden and tags.

    Tag names are stripped and truncated to 64 characters. Empty names and
    repeated names are dropped. Unknown names create new ``Tag`` rows.
    The ``tags`` list replaces the screenshot's tags wholesale.

    Raises:
        HTTPException: 404 if the screenshot does not exist; 400 if
            ``category_id`` refers to a missing category.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    # `model_fields_set` distinguishes "field omitted" from "explicit null".
    # This lets a PATCH of {"category_id": null} actually clear the category.
    fields = payload.model_fields_set
    if "category_id" in fields:
        if payload.category_id is not None:
            cat = session.get(Category, payload.category_id)
            if cat is None:
                raise HTTPException(400, "category not found")
        shot.category_id = payload.category_id
    if "is_favorite" in fields and payload.is_favorite is not None:
        shot.is_favorite = 1 if payload.is_favorite else 0
    if "is_hidden" in fields and payload.is_hidden is not None:
        shot.is_hidden = 1 if payload.is_hidden else 0
    if "tags" in fields and payload.tags is not None:
        seen: set[str] = set()
        tag_objs: list[Tag] = []
        for raw in payload.tags:
            name = (raw or "").strip()[:64]
            # Skip blanks and duplicates. Otherwise the same Tag would be
            # associated twice, creating a duplicate association row.
            if not name or name in seen:
                continue
            seen.add(name)
            tag = session.scalar(select(Tag).where(Tag.name == name))
            if tag is None:
                tag = Tag(name=name)
                session.add(tag)
                session.flush()
            tag_objs.append(tag)
        shot.tags = tag_objs
    session.commit()
    session.refresh(shot)
    return get_screenshot(screenshot_id, session)


def _active_job_id(session: Session, screenshot_id: int, kind: str) -> Optional[int]:
    """Return the id of a pending/running job, if any.

    Only jobs of ``kind`` for ``screenshot_id`` are considered; returns None
    when there is no such job.
    """
    return session.scalar(
        select(Job.id).where(
            Job.screenshot_id == screenshot_id,
            Job.kind == kind,
            Job.status.in_(_ACTIVE_JOB_STATUSES),
        )
    )


@router.post("/{screenshot_id}/reanalyze")
def reanalyze(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Queue a FULL re-analysis job.

    The screenshot's OCR and AI statuses are reset to PENDING. If a FULL job
    for this screenshot is already pending or running, its id is returned
    instead of enqueueing a duplicate. The statuses are left untouched in
    that case, mirroring ``reocr``.

    Raises:
        HTTPException: 404 if the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    active = _active_job_id(session, shot.id, JobKind.FULL.value)
    if active is not None:
        return {"ok": True, "job_id": active, "message": "已有分析任务在队列中"}
    shot.ai_status = ProcessStatus.PENDING.value
    shot.ocr_status = ProcessStatus.PENDING.value
    job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
    session.add(job)
    session.commit()
    # Sync routes run in a thread pool, so the worker must be woken
    # thread-safely on its event loop.
    worker.notify_threadsafe()
    return {"ok": True, "job_id": job.id}


@router.post("/{screenshot_id}/reocr")
def reocr(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Queue an OCR-only job; AI analysis is not re-run.

    If an OCR job for this screenshot is already pending or running, its id
    is returned instead of enqueueing a duplicate.

    Raises:
        HTTPException: 404 if the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    active = _active_job_id(session, shot.id, JobKind.OCR.value)
    if active is not None:
        return {"ok": True, "job_id": active, "message": "已有 OCR 任务在队列中"}
    shot.ocr_status = ProcessStatus.PENDING.value
    job = Job(
        screenshot_id=shot.id,
        kind=JobKind.OCR.value,
        status=JobStatus.PENDING.value,
    )
    session.add(job)
    session.commit()
    # Same thread-pool → event-loop wakeup as in `reanalyze`.
    worker.notify_threadsafe()
    return {"ok": True, "job_id": job.id}


@router.delete("/{screenshot_id}")
def delete_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete the database record only; the original image file is not removed.

    Raises:
        HTTPException: 404 if the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    session.delete(shot)
    session.commit()
    return {"ok": True}


@router.get("/{screenshot_id}/file")
def get_file(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the original image file.

    The stored path is resolved via ``path_from_storage``.

    Raises:
        HTTPException: 404 if the record or the file is missing.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    p = path_from_storage(shot.path)
    if not is_accessible_file(p):
        raise HTTPException(404, "file missing")
    return FileResponse(str(p))


@router.get("/{screenshot_id}/thumb")
def get_thumb(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the WebP thumbnail, falling back to the original image.

    Raises:
        HTTPException: 404 if the record, thumbnail and original are all
            unavailable.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")
    if shot.thumb_path:
        p = Path(shot.thumb_path)
        # is_file() rather than exists(): a directory cannot be served.
        if p.is_file():
            return FileResponse(str(p), media_type="image/webp")
    # Fallback: serve the original image instead.
    p = path_from_storage(shot.path)
    if is_accessible_file(p):
        return FileResponse(str(p))
    raise HTTPException(404, "thumb missing")