Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
"""API 通用依赖。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.db import SessionLocal
|
||||
|
||||
|
||||
def db_session() -> Iterator[Session]:
|
||||
"""每请求一个会话。"""
|
||||
session = SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
@@ -0,0 +1,367 @@
|
||||
"""截图列表 / 详情 / 随机 / 重新分析 / 文件流。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import and_, func, or_, select, text
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.core.path_utils import is_accessible_file, path_from_storage
|
||||
from app.models.category import Category
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.services.search_utils import collect_search_ids
|
||||
from app.schemas.screenshot import (
|
||||
CategoryOut,
|
||||
ScreenshotBrief,
|
||||
ScreenshotDetail,
|
||||
ScreenshotListResp,
|
||||
ScreenshotUpdate,
|
||||
TagOut,
|
||||
TodoBrief,
|
||||
)
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])
|
||||
|
||||
|
||||
def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
|
||||
"""ORM -> ScreenshotBrief。"""
|
||||
return ScreenshotBrief(
|
||||
id=shot.id,
|
||||
path=shot.path,
|
||||
width=shot.width,
|
||||
height=shot.height,
|
||||
captured_at=shot.captured_at,
|
||||
thumb_url=f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None,
|
||||
ai_title=(shot.meta.ai_title if shot.meta else None),
|
||||
ai_status=shot.ai_status,
|
||||
ocr_status=shot.ocr_status,
|
||||
is_favorite=bool(shot.is_favorite),
|
||||
category=(
|
||||
CategoryOut.model_validate(cat_map[shot.category_id])
|
||||
if shot.category_id and shot.category_id in cat_map
|
||||
else None
|
||||
),
|
||||
tags=[TagOut.model_validate(t) for t in (shot.tags or [])],
|
||||
)
|
||||
|
||||
|
||||
def _category_map(session: Session) -> dict[int, Category]:
|
||||
return {c.id: c for c in session.scalars(select(Category)).all()}
|
||||
|
||||
|
||||
@router.get("", response_model=ScreenshotListResp)
|
||||
def list_screenshots(
|
||||
session: Session = Depends(db_session),
|
||||
q: Optional[str] = Query(None, description="OCR+AI 全文搜索关键词"),
|
||||
category_id: Optional[int] = Query(None),
|
||||
tag: Optional[str] = Query(None),
|
||||
date_from: Optional[datetime] = Query(None),
|
||||
date_to: Optional[datetime] = Query(None),
|
||||
favorite: Optional[bool] = Query(None),
|
||||
status: Optional[str] = Query(None, description="ai_status 过滤"),
|
||||
show_hidden: bool = Query(False),
|
||||
sort: str = Query("captured_desc"),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(40, ge=1, le=200),
|
||||
) -> ScreenshotListResp:
|
||||
"""主列表查询:支持时间/分类/标签/收藏/状态/搜索词。"""
|
||||
stmt = select(Screenshot).options(
|
||||
selectinload(Screenshot.meta),
|
||||
selectinload(Screenshot.tags),
|
||||
)
|
||||
filters = []
|
||||
if not show_hidden:
|
||||
filters.append(Screenshot.is_hidden == 0)
|
||||
if category_id is not None:
|
||||
filters.append(Screenshot.category_id == category_id)
|
||||
if date_from is not None:
|
||||
filters.append(Screenshot.captured_at >= date_from)
|
||||
if date_to is not None:
|
||||
filters.append(Screenshot.captured_at <= date_to)
|
||||
if favorite is True:
|
||||
filters.append(Screenshot.is_favorite == 1)
|
||||
if status:
|
||||
filters.append(Screenshot.ai_status == status)
|
||||
if tag:
|
||||
stmt = stmt.join(Screenshot.tags).where(Tag.name.ilike(f"%{tag}%"))
|
||||
|
||||
if q:
|
||||
ids = collect_search_ids(session, q)
|
||||
if not ids:
|
||||
return ScreenshotListResp(items=[], total=0, page=page, size=size)
|
||||
filters.append(Screenshot.id.in_(ids))
|
||||
|
||||
if filters:
|
||||
stmt = stmt.where(and_(*filters))
|
||||
|
||||
# 排序
|
||||
stmt = _apply_sort(stmt, sort)
|
||||
|
||||
total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
|
||||
rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()
|
||||
|
||||
cat_map = _category_map(session)
|
||||
items = [_to_brief(r, cat_map) for r in rows]
|
||||
return ScreenshotListResp(items=items, total=int(total), page=page, size=size)
|
||||
|
||||
|
||||
def _apply_sort(stmt, sort: str):
|
||||
"""列表排序:时间 / 导入 / 标题 / 文件大小。"""
|
||||
if sort == "captured_asc":
|
||||
return stmt.order_by(Screenshot.captured_at.asc())
|
||||
if sort == "imported_desc":
|
||||
return stmt.order_by(Screenshot.imported_at.desc())
|
||||
if sort == "imported_asc":
|
||||
return stmt.order_by(Screenshot.imported_at.asc())
|
||||
if sort == "title_asc":
|
||||
return stmt.outerjoin(ScreenshotMeta).order_by(
|
||||
ScreenshotMeta.ai_title.asc().nulls_last()
|
||||
)
|
||||
if sort == "title_desc":
|
||||
return stmt.outerjoin(ScreenshotMeta).order_by(
|
||||
ScreenshotMeta.ai_title.desc().nulls_last()
|
||||
)
|
||||
if sort == "size_desc":
|
||||
return stmt.order_by(Screenshot.size.desc())
|
||||
if sort == "size_asc":
|
||||
return stmt.order_by(Screenshot.size.asc())
|
||||
return stmt.order_by(Screenshot.captured_at.desc())
|
||||
|
||||
|
||||
@router.get("/random", response_model=list[ScreenshotBrief])
|
||||
def random_screenshots(
|
||||
session: Session = Depends(db_session),
|
||||
n: int = Query(1, ge=1, le=20),
|
||||
category_id: Optional[int] = Query(None),
|
||||
) -> list[ScreenshotBrief]:
|
||||
"""随机展示。"""
|
||||
stmt = select(Screenshot).options(
|
||||
selectinload(Screenshot.meta),
|
||||
selectinload(Screenshot.tags),
|
||||
).where(Screenshot.is_hidden == 0)
|
||||
if category_id is not None:
|
||||
stmt = stmt.where(Screenshot.category_id == category_id)
|
||||
stmt = stmt.order_by(func.random()).limit(n)
|
||||
rows = session.scalars(stmt).unique().all()
|
||||
cat_map = _category_map(session)
|
||||
return [_to_brief(r, cat_map) for r in rows]
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
def stats(session: Session = Depends(db_session)) -> dict:
|
||||
"""汇总统计:总数、状态分布、按分类、按月份。"""
|
||||
total = session.scalar(select(func.count(Screenshot.id))) or 0
|
||||
by_status = {
|
||||
st.value: session.scalar(
|
||||
select(func.count(Screenshot.id)).where(Screenshot.ai_status == st.value)
|
||||
)
|
||||
or 0
|
||||
for st in ProcessStatus
|
||||
}
|
||||
by_category_rows = session.execute(
|
||||
select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
|
||||
.join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
|
||||
.group_by(Category.id)
|
||||
.order_by(func.count(Screenshot.id).desc())
|
||||
).all()
|
||||
by_category = [
|
||||
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
|
||||
for r in by_category_rows
|
||||
]
|
||||
by_month_rows = session.execute(
|
||||
text(
|
||||
"SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
|
||||
"FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
|
||||
)
|
||||
).all()
|
||||
by_month = [{"month": r[0], "count": int(r[1])} for r in by_month_rows]
|
||||
return {
|
||||
"total": int(total),
|
||||
"by_status": by_status,
|
||||
"by_category": by_category,
|
||||
"by_month": by_month,
|
||||
"queue": _queue_summary(session),
|
||||
}
|
||||
|
||||
|
||||
def _queue_summary(session: Session) -> dict:
|
||||
"""汇总 jobs 队列状态。"""
|
||||
out: dict[str, int] = {}
|
||||
for st in JobStatus:
|
||||
out[st.value] = (
|
||||
session.scalar(select(func.count(Job.id)).where(Job.status == st.value)) or 0
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
|
||||
def get_screenshot(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ScreenshotDetail:
|
||||
"""单张详情。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
cat_map = _category_map(session)
|
||||
brief = _to_brief(shot, cat_map)
|
||||
meta = shot.meta
|
||||
todos = [TodoBrief.model_validate(t) for t in shot.todos]
|
||||
return ScreenshotDetail(
|
||||
**brief.model_dump(),
|
||||
file_url=f"/api/screenshots/{shot.id}/file",
|
||||
size=shot.size,
|
||||
ocr_text=(meta.ocr_text if meta else None),
|
||||
ai_summary=(meta.ai_summary if meta else None),
|
||||
ai_suggestion=(meta.ai_suggestion if meta else None),
|
||||
todos=todos,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
|
||||
def update_screenshot(
|
||||
screenshot_id: int,
|
||||
payload: ScreenshotUpdate,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ScreenshotDetail:
|
||||
"""前端编辑:分类、收藏、隐藏、标签。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
# 用 model_fields_set 区分「未传字段」与「显式传入 null」
|
||||
# 这样前端 PATCH {"category_id": null} 可以真正清空分类
|
||||
fields = payload.model_fields_set
|
||||
if "category_id" in fields:
|
||||
if payload.category_id is not None:
|
||||
cat = session.get(Category, payload.category_id)
|
||||
if cat is None:
|
||||
raise HTTPException(400, "category not found")
|
||||
shot.category_id = payload.category_id
|
||||
if "is_favorite" in fields and payload.is_favorite is not None:
|
||||
shot.is_favorite = 1 if payload.is_favorite else 0
|
||||
if "is_hidden" in fields and payload.is_hidden is not None:
|
||||
shot.is_hidden = 1 if payload.is_hidden else 0
|
||||
if "tags" in fields and payload.tags is not None:
|
||||
tag_objs = []
|
||||
for name in payload.tags:
|
||||
name = (name or "").strip()[:64]
|
||||
if not name:
|
||||
continue
|
||||
tag = session.scalar(select(Tag).where(Tag.name == name))
|
||||
if tag is None:
|
||||
tag = Tag(name=name)
|
||||
session.add(tag)
|
||||
session.flush()
|
||||
tag_objs.append(tag)
|
||||
shot.tags = tag_objs
|
||||
session.commit()
|
||||
session.refresh(shot)
|
||||
return get_screenshot(screenshot_id, session)
|
||||
|
||||
|
||||
@router.post("/{screenshot_id}/reanalyze")
|
||||
def reanalyze(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
"""加入队列重新分析。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
shot.ai_status = ProcessStatus.PENDING.value
|
||||
shot.ocr_status = ProcessStatus.PENDING.value
|
||||
job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
|
||||
session.add(job)
|
||||
session.commit()
|
||||
# 同步路由跑在线程池,必须 threadsafe 唤醒事件循环
|
||||
worker.notify_threadsafe()
|
||||
return {"ok": True, "job_id": job.id}
|
||||
|
||||
|
||||
@router.post("/{screenshot_id}/reocr")
|
||||
def reocr(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
"""仅补跑 OCR,不重新调用 AI 分析。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
active = session.scalar(
|
||||
select(Job.id).where(
|
||||
Job.screenshot_id == shot.id,
|
||||
Job.kind == JobKind.OCR.value,
|
||||
Job.status.in_((JobStatus.PENDING.value, JobStatus.RUNNING.value)),
|
||||
)
|
||||
)
|
||||
if active is not None:
|
||||
return {"ok": True, "job_id": active, "message": "已有 OCR 任务在队列中"}
|
||||
shot.ocr_status = ProcessStatus.PENDING.value
|
||||
job = Job(
|
||||
screenshot_id=shot.id,
|
||||
kind=JobKind.OCR.value,
|
||||
status=JobStatus.PENDING.value,
|
||||
)
|
||||
session.add(job)
|
||||
session.commit()
|
||||
worker.notify_threadsafe()
|
||||
return {"ok": True, "job_id": job.id}
|
||||
|
||||
|
||||
@router.delete("/{screenshot_id}")
|
||||
def delete_screenshot(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
"""删除记录(不删除原始文件)。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
session.delete(shot)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}/file")
|
||||
def get_file(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> FileResponse:
|
||||
"""原图文件流。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
p = path_from_storage(shot.path)
|
||||
if not is_accessible_file(p):
|
||||
raise HTTPException(404, "file missing")
|
||||
return FileResponse(str(p))
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}/thumb")
|
||||
def get_thumb(
|
||||
screenshot_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> FileResponse:
|
||||
"""缩略图流。"""
|
||||
shot = session.get(Screenshot, screenshot_id)
|
||||
if shot is None:
|
||||
raise HTTPException(404, "Screenshot not found")
|
||||
if shot.thumb_path:
|
||||
p = Path(shot.thumb_path)
|
||||
if p.exists():
|
||||
return FileResponse(str(p), media_type="image/webp")
|
||||
# 兜底:返回原图
|
||||
p = path_from_storage(shot.path)
|
||||
if is_accessible_file(p):
|
||||
return FileResponse(str(p))
|
||||
raise HTTPException(404, "thumb missing")
|
||||
@@ -0,0 +1,220 @@
|
||||
"""设置接口:Provider 配置、分类、Tag。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.models.category import Category
|
||||
from app.models.setting import (
|
||||
DEFAULT_RECOGNITION_MODE,
|
||||
KEY_OCR_PROVIDER,
|
||||
KEY_RECOGNITION_MODE,
|
||||
KEY_VLM_PROVIDER,
|
||||
)
|
||||
from app.models.screenshot import Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.providers import RECOGNITION_MODES
|
||||
from app.schemas.common import (
|
||||
CategoryIn,
|
||||
ProviderConfig,
|
||||
ProviderConfigOut,
|
||||
ProviderTestResult,
|
||||
RecognitionModeIn,
|
||||
)
|
||||
from app.services.provider_test import merge_provider_api_key, test_provider_config
|
||||
from app.services.settings_store import all_settings, get_setting, set_setting
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_all(session: Session = Depends(db_session)) -> dict:
|
||||
"""返回所有非敏感设置。api_key 字段做脱敏。"""
|
||||
raw = all_settings(session)
|
||||
for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
cfg = raw.get(key)
|
||||
if isinstance(cfg, dict) and cfg.get("api_key"):
|
||||
cfg["api_key_mask"] = _mask(cfg["api_key"])
|
||||
cfg["api_key"] = ""
|
||||
return raw
|
||||
|
||||
|
||||
def _mask(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) <= 6:
|
||||
return "*" * len(value)
|
||||
return value[:3] + "*" * (len(value) - 6) + value[-3:]
|
||||
|
||||
|
||||
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
|
||||
def get_provider(
|
||||
key: str,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ProviderConfigOut | None:
|
||||
"""读取 Provider 配置:api_key 明文不外传,只给一个掩码用于 UI 提示。"""
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
raw = get_setting(session, key, None)
|
||||
if not raw:
|
||||
return None
|
||||
mask = _mask(raw.get("api_key", "") or "")
|
||||
return ProviderConfigOut(
|
||||
type=raw.get("type", ""),
|
||||
base_url=raw.get("base_url"),
|
||||
api_key="",
|
||||
api_key_mask=mask or None,
|
||||
model=raw.get("model"),
|
||||
extra=raw.get("extra", {}) or {},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/providers/{key}")
|
||||
def put_provider(
|
||||
key: str,
|
||||
cfg: ProviderConfig,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
# 如果客户端没有传新的 api_key(空字符串),保留旧值
|
||||
existing = get_setting(session, key, None)
|
||||
payload = cfg.model_dump()
|
||||
if (not payload.get("api_key")) and isinstance(existing, dict):
|
||||
payload["api_key"] = existing.get("api_key", "")
|
||||
set_setting(session, key, payload)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
|
||||
async def test_provider(
|
||||
key: str,
|
||||
cfg: ProviderConfig,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ProviderTestResult:
|
||||
"""测试 OCR / 视觉 AI Provider 连通性(使用当前表单配置,api_key 可留空沿用已保存值)。"""
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
existing = get_setting(session, key, None)
|
||||
merged = merge_provider_api_key(cfg, existing if isinstance(existing, dict) else None)
|
||||
result = await test_provider_config(key, merged)
|
||||
return ProviderTestResult(**result)
|
||||
|
||||
|
||||
@router.get("/recognition-mode")
|
||||
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
|
||||
"""读取文字识别策略:ocr / vision / hybrid。"""
|
||||
mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
|
||||
if mode not in RECOGNITION_MODES:
|
||||
mode = DEFAULT_RECOGNITION_MODE
|
||||
return {"mode": mode, "options": list(RECOGNITION_MODES)}
|
||||
|
||||
|
||||
@router.put("/recognition-mode")
|
||||
def put_recognition_mode(
|
||||
payload: RecognitionModeIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
"""保存文字识别策略。"""
|
||||
if payload.mode not in RECOGNITION_MODES:
|
||||
raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
|
||||
set_setting(session, KEY_RECOGNITION_MODE, payload.mode)
|
||||
session.commit()
|
||||
return {"ok": True, "mode": payload.mode}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
|
||||
rows = session.scalars(select(Category).order_by(Category.id)).all()
|
||||
return [
|
||||
{"id": c.id, "name": c.name, "color": c.color, "prompt_hint": c.prompt_hint}
|
||||
for c in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/categories")
|
||||
def create_category(
|
||||
payload: CategoryIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
exists = session.scalar(select(Category).where(Category.name == payload.name))
|
||||
if exists is not None:
|
||||
raise HTTPException(400, "category exists")
|
||||
cat = Category(name=payload.name, color=payload.color, prompt_hint=payload.prompt_hint)
|
||||
session.add(cat)
|
||||
session.commit()
|
||||
session.refresh(cat)
|
||||
return {"id": cat.id}
|
||||
|
||||
|
||||
@router.patch("/categories/{cat_id}")
|
||||
def update_category(
|
||||
cat_id: int,
|
||||
payload: CategoryIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
cat = session.get(Category, cat_id)
|
||||
if cat is None:
|
||||
raise HTTPException(404, "category not found")
|
||||
cat.name = payload.name
|
||||
cat.color = payload.color
|
||||
cat.prompt_hint = payload.prompt_hint
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/categories/{cat_id}")
|
||||
def delete_category(
|
||||
cat_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
cat = session.get(Category, cat_id)
|
||||
if cat is None:
|
||||
raise HTTPException(404, "category not found")
|
||||
session.delete(cat)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/tags")
|
||||
def list_tags(
|
||||
session: Session = Depends(db_session),
|
||||
q: Optional[str] = Query(None, description="标签名关键词"),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(200, ge=1, le=500),
|
||||
sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
|
||||
) -> dict:
|
||||
"""标签列表(含使用次数),支持搜索与分页。"""
|
||||
base = select(Tag.id)
|
||||
if q:
|
||||
base = base.where(Tag.name.ilike(f"%{q.strip()}%"))
|
||||
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
|
||||
|
||||
stmt = (
|
||||
select(Tag.id, Tag.name, Tag.color, func.count(Screenshot.id))
|
||||
.join(Tag.screenshots, isouter=True)
|
||||
.group_by(Tag.id)
|
||||
)
|
||||
if q:
|
||||
stmt = stmt.where(Tag.name.ilike(f"%{q.strip()}%"))
|
||||
|
||||
if sort == "count_asc":
|
||||
stmt = stmt.order_by(func.count(Screenshot.id).asc())
|
||||
elif sort == "name_asc":
|
||||
stmt = stmt.order_by(Tag.name.asc())
|
||||
elif sort == "name_desc":
|
||||
stmt = stmt.order_by(Tag.name.desc())
|
||||
else:
|
||||
stmt = stmt.order_by(func.count(Screenshot.id).desc())
|
||||
|
||||
rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
|
||||
items = [
|
||||
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)} for r in rows
|
||||
]
|
||||
return {"items": items, "total": int(total), "page": page, "size": size}
|
||||
@@ -0,0 +1,106 @@
|
||||
"""待办清单接口。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.models.todo import Todo, TodoStatus
|
||||
from app.schemas.common import TodoUpdate
|
||||
from app.schemas.screenshot import TodoBrief, TodoListResp
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/todos", tags=["todos"])
|
||||
|
||||
|
||||
def _todo_filters(
|
||||
status: Optional[str],
|
||||
kind: Optional[str],
|
||||
q: Optional[str],
|
||||
) -> list:
|
||||
"""构建待办筛选条件。"""
|
||||
filters = []
|
||||
if status:
|
||||
filters.append(Todo.status == status)
|
||||
if kind:
|
||||
filters.append(Todo.kind == kind)
|
||||
if q:
|
||||
like = f"%{q.strip()}%"
|
||||
filters.append(or_(Todo.title.ilike(like), Todo.note.ilike(like)))
|
||||
return filters
|
||||
|
||||
|
||||
@router.get("", response_model=TodoListResp)
|
||||
def list_todos(
|
||||
session: Session = Depends(db_session),
|
||||
status: Optional[str] = Query(None),
|
||||
kind: Optional[str] = Query(None),
|
||||
q: Optional[str] = Query(None, description="标题/备注关键词"),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=200),
|
||||
) -> TodoListResp:
|
||||
"""按状态/类型/关键词分页查询。"""
|
||||
filters = _todo_filters(status, kind, q)
|
||||
base = select(Todo)
|
||||
if filters:
|
||||
base = base.where(and_(*filters))
|
||||
|
||||
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
|
||||
rows = session.scalars(
|
||||
base.order_by(Todo.created_at.desc()).offset((page - 1) * size).limit(size)
|
||||
).all()
|
||||
return TodoListResp(
|
||||
items=[TodoBrief.model_validate(r) for r in rows],
|
||||
total=int(total),
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
def summary(session: Session = Depends(db_session)) -> dict:
|
||||
"""各状态待办数量。"""
|
||||
return {
|
||||
st.value: session.scalar(select(func.count(Todo.id)).where(Todo.status == st.value)) or 0
|
||||
for st in TodoStatus
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{todo_id}", response_model=TodoBrief)
|
||||
def update_todo(
|
||||
todo_id: int,
|
||||
payload: TodoUpdate,
|
||||
session: Session = Depends(db_session),
|
||||
) -> TodoBrief:
|
||||
"""更新状态/标题/备注。"""
|
||||
todo = session.get(Todo, todo_id)
|
||||
if todo is None:
|
||||
raise HTTPException(404, "Todo not found")
|
||||
if payload.status is not None:
|
||||
todo.status = payload.status
|
||||
if payload.status == TodoStatus.DONE.value:
|
||||
todo.completed_at = datetime.utcnow()
|
||||
if payload.title is not None:
|
||||
todo.title = payload.title
|
||||
if payload.note is not None:
|
||||
todo.note = payload.note
|
||||
session.commit()
|
||||
session.refresh(todo)
|
||||
return TodoBrief.model_validate(todo)
|
||||
|
||||
|
||||
@router.delete("/{todo_id}")
|
||||
def delete_todo(
|
||||
todo_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
todo = session.get(Todo, todo_id)
|
||||
if todo is None:
|
||||
raise HTTPException(404, "Todo not found")
|
||||
session.delete(todo)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
@@ -0,0 +1,256 @@
|
||||
"""监听目录的增删改、手动导入、分析队列。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.core.db import session_scope
|
||||
from app.core.path_utils import (
|
||||
count_files_sample,
|
||||
is_accessible_dir,
|
||||
normalize_user_path,
|
||||
)
|
||||
from app.models.job import Job, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.schemas.common import WatchFolderIn, WatchFolderOut
|
||||
from app.schemas.job import JobListResp, JobOut, JobRetryIn
|
||||
from app.services.analyze import enqueue_ocr_jobs
|
||||
from app.services.ingest import ingest_directory
|
||||
from app.services.watcher import watcher_service
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/watch", tags=["watch"])
|
||||
|
||||
|
||||
def _validate_folder_path(raw: str) -> str:
|
||||
"""校验并规范化监听目录路径(含 UNC 网络路径)。"""
|
||||
normalized = normalize_user_path(raw)
|
||||
if not normalized:
|
||||
raise HTTPException(400, "路径不能为空")
|
||||
if not is_accessible_dir(normalized):
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"目录不存在或无法访问: {normalized}。"
|
||||
"请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹",
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
@router.get("/folders", response_model=list[WatchFolderOut])
|
||||
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
|
||||
rows = session.scalars(select(WatchFolder).order_by(WatchFolder.id)).all()
|
||||
return [WatchFolderOut.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("/folders", response_model=WatchFolderOut)
|
||||
def add_folder(
|
||||
payload: WatchFolderIn,
|
||||
background: BackgroundTasks,
|
||||
session: Session = Depends(db_session),
|
||||
) -> WatchFolderOut:
|
||||
"""新增监听目录,自动触发一次扫描入库。"""
|
||||
normalized = _validate_folder_path(payload.path)
|
||||
exists = session.scalar(select(WatchFolder).where(WatchFolder.path == normalized))
|
||||
if exists is not None:
|
||||
raise HTTPException(400, "目录已存在")
|
||||
folder = WatchFolder(
|
||||
path=normalized,
|
||||
enabled=payload.enabled,
|
||||
recursive=payload.recursive,
|
||||
is_sensitive=payload.is_sensitive,
|
||||
)
|
||||
session.add(folder)
|
||||
session.commit()
|
||||
session.refresh(folder)
|
||||
watcher_service.reload()
|
||||
background.add_task(_scan_folder, normalized, payload.recursive)
|
||||
return WatchFolderOut.model_validate(folder)
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
|
||||
def update_folder(
|
||||
folder_id: int,
|
||||
payload: WatchFolderIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> WatchFolderOut:
|
||||
folder = session.get(WatchFolder, folder_id)
|
||||
if folder is None:
|
||||
raise HTTPException(404, "folder not found")
|
||||
normalized = _validate_folder_path(payload.path)
|
||||
folder.path = normalized
|
||||
folder.enabled = payload.enabled
|
||||
folder.recursive = payload.recursive
|
||||
folder.is_sensitive = payload.is_sensitive
|
||||
session.commit()
|
||||
session.refresh(folder)
|
||||
watcher_service.reload()
|
||||
return WatchFolderOut.model_validate(folder)
|
||||
|
||||
|
||||
@router.delete("/folders/{folder_id}")
|
||||
def delete_folder(
|
||||
folder_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
folder = session.get(WatchFolder, folder_id)
|
||||
if folder is None:
|
||||
raise HTTPException(404, "folder not found")
|
||||
session.delete(folder)
|
||||
session.commit()
|
||||
watcher_service.reload()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def import_now(
|
||||
payload: WatchFolderIn,
|
||||
background: BackgroundTasks,
|
||||
) -> dict:
|
||||
"""手动触发一次目录扫描(不一定要登记为监听)。"""
|
||||
normalized = _validate_folder_path(payload.path)
|
||||
background.add_task(_scan_folder, normalized, payload.recursive)
|
||||
return {"ok": True, "message": "已在后台扫描"}
|
||||
|
||||
|
||||
@router.post("/validate-path")
|
||||
def validate_path(payload: WatchFolderIn) -> dict:
|
||||
"""测试目录是否可访问(含 UNC 网络路径),返回抽样文件数。"""
|
||||
normalized = _validate_folder_path(payload.path)
|
||||
total, samples = count_files_sample(normalized, limit=3)
|
||||
return {
|
||||
"ok": True,
|
||||
"path": normalized,
|
||||
"sample_image_count": total,
|
||||
"samples": samples,
|
||||
"message": f"目录可访问,抽样发现约 {total}+ 张图片",
|
||||
}
|
||||
|
||||
|
||||
def _scan_folder(path: str, recursive: bool) -> None:
|
||||
"""后台任务:扫描目录入库,再通知 worker。
|
||||
|
||||
BackgroundTasks 的同步函数运行在线程池中,必须用 threadsafe 入口
|
||||
唤醒事件循环,否则 asyncio.Event.set() 会有竞态。
|
||||
"""
|
||||
with session_scope() as session:
|
||||
ingest_directory(session, path, recursive=recursive)
|
||||
worker.notify_threadsafe()
|
||||
|
||||
|
||||
@router.get("/queue")
|
||||
async def queue_status() -> dict:
|
||||
"""读取 worker 队列状态。"""
|
||||
counts = await worker.status()
|
||||
with session_scope() as session:
|
||||
counts["ocr_retryable"] = (
|
||||
session.scalar(
|
||||
select(func.count(Screenshot.id)).where(
|
||||
Screenshot.ocr_status == ProcessStatus.FAILED.value,
|
||||
Screenshot.ai_status == ProcessStatus.DONE.value,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
counts["ocr_pending"] = (
|
||||
session.scalar(
|
||||
select(func.count(Job.id)).where(
|
||||
Job.kind == "ocr",
|
||||
Job.status == JobStatus.PENDING.value,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
return counts
|
||||
|
||||
|
||||
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
|
||||
"""ORM -> JobOut,附带截图摘要字段。"""
|
||||
return JobOut(
|
||||
id=job.id,
|
||||
screenshot_id=job.screenshot_id,
|
||||
kind=job.kind,
|
||||
status=job.status,
|
||||
retries=job.retries or 0,
|
||||
last_error=job.last_error,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
finished_at=job.finished_at,
|
||||
thumb_url=(
|
||||
f"/api/screenshots/{shot.id}/thumb" if shot and shot.thumb_path else None
|
||||
),
|
||||
path=shot.path if shot else None,
|
||||
ai_title=(shot.meta.ai_title if shot and shot.meta else None),
|
||||
ai_status=shot.ai_status if shot else None,
|
||||
ocr_status=shot.ocr_status if shot else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=JobListResp)
|
||||
def list_jobs(
|
||||
status: Optional[str] = Query(None, description="pending|running|done|failed"),
|
||||
kind: Optional[str] = Query(None, description="full|ocr|vlm"),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=200),
|
||||
session: Session = Depends(db_session),
|
||||
) -> JobListResp:
|
||||
"""分页列出分析任务,默认按 id 倒序(最新的在前)。"""
|
||||
if status and status not in {s.value for s in JobStatus}:
|
||||
raise HTTPException(400, f"无效 status: {status}")
|
||||
|
||||
base = select(Job)
|
||||
if status:
|
||||
base = base.where(Job.status == status)
|
||||
if kind:
|
||||
base = base.where(Job.kind == kind)
|
||||
|
||||
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
|
||||
jobs = session.scalars(
|
||||
base.order_by(Job.id.desc()).offset((page - 1) * size).limit(size)
|
||||
).all()
|
||||
|
||||
shot_ids = [j.screenshot_id for j in jobs]
|
||||
shots: dict[int, Screenshot] = {}
|
||||
if shot_ids:
|
||||
rows = session.scalars(
|
||||
select(Screenshot)
|
||||
.where(Screenshot.id.in_(shot_ids))
|
||||
.options(selectinload(Screenshot.meta))
|
||||
).all()
|
||||
shots = {s.id: s for s in rows}
|
||||
|
||||
items = [_job_to_out(j, shots.get(j.screenshot_id)) for j in jobs]
|
||||
return JobListResp(items=items, total=total, page=page, size=size)
|
||||
|
||||
|
||||
@router.post("/jobs/retry-failed")
|
||||
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
|
||||
"""将全部或指定 failed 任务重新排队。"""
|
||||
job_ids = payload.job_ids if payload else None
|
||||
count = worker.retry_failed(job_ids)
|
||||
return {"ok": True, "count": count}
|
||||
|
||||
|
||||
@router.post("/jobs/reset-stale")
|
||||
def reset_stale_jobs(
|
||||
minutes: int = Query(5, ge=1, le=1440),
|
||||
reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
|
||||
) -> dict:
|
||||
"""复位僵尸 RUNNING 任务(worker 崩溃或未正常 finish 时)。"""
|
||||
count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
|
||||
return {"ok": True, "count": count}
|
||||
|
||||
|
||||
@router.post("/jobs/enqueue-ocr-failed")
|
||||
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
|
||||
"""为 AI 已成功但 OCR 失败的截图批量创建 OCR 补跑任务。"""
|
||||
count = enqueue_ocr_jobs(limit=limit)
|
||||
if count:
|
||||
worker.notify()
|
||||
return {"ok": True, "count": count}
|
||||
Reference in New Issue
Block a user