Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
"""snapAna 截图分析后端应用包。"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,17 @@
|
||||
"""API 通用依赖。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.db import SessionLocal
|
||||
|
||||
|
||||
def db_session() -> Iterator[Session]:
    """Yield one database session per request.

    A session is created from ``SessionLocal`` and handed to the consumer;
    it is closed afterwards regardless of whether the consumer raised.
    """
    request_session = SessionLocal()
    try:
        yield request_session
    finally:
        # Always release the session, even on error paths.
        request_session.close()
|
||||
@@ -0,0 +1,367 @@
|
||||
"""截图列表 / 详情 / 随机 / 重新分析 / 文件流。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import and_, func, or_, select, text
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.core.path_utils import is_accessible_file, path_from_storage
|
||||
from app.models.category import Category
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.services.search_utils import collect_search_ids
|
||||
from app.schemas.screenshot import (
|
||||
CategoryOut,
|
||||
ScreenshotBrief,
|
||||
ScreenshotDetail,
|
||||
ScreenshotListResp,
|
||||
ScreenshotUpdate,
|
||||
TagOut,
|
||||
TodoBrief,
|
||||
)
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])
|
||||
|
||||
|
||||
def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
    """Convert a ``Screenshot`` ORM object into a ``ScreenshotBrief``.

    Args:
        shot: The ORM row to convert.
        cat_map: Mapping of category id to ``Category``; used to resolve the
            screenshot's category without an extra query.

    Returns:
        The brief schema. ``thumb_url`` is only set when the row has a
        ``thumb_path``; ``category`` is ``None`` when the row has no category
        id or the id is not present in ``cat_map``.
    """
    thumb_url = f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None
    title = shot.meta.ai_title if shot.meta else None

    category = None
    if shot.category_id and shot.category_id in cat_map:
        category = CategoryOut.model_validate(cat_map[shot.category_id])

    tag_models = [TagOut.model_validate(item) for item in (shot.tags or [])]

    return ScreenshotBrief(
        id=shot.id,
        path=shot.path,
        width=shot.width,
        height=shot.height,
        captured_at=shot.captured_at,
        thumb_url=thumb_url,
        ai_title=title,
        ai_status=shot.ai_status,
        ocr_status=shot.ocr_status,
        is_favorite=bool(shot.is_favorite),
        category=category,
        tags=tag_models,
    )
|
||||
|
||||
|
||||
def _category_map(session: Session) -> dict[int, Category]:
    """Load every category and index it by its primary key."""
    categories = session.scalars(select(Category)).all()
    return {category.id: category for category in categories}
|
||||
|
||||
|
||||
@router.get("", response_model=ScreenshotListResp)
def list_screenshots(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="OCR+AI 全文搜索关键词"),
    category_id: Optional[int] = Query(None),
    tag: Optional[str] = Query(None),
    date_from: Optional[datetime] = Query(None),
    date_to: Optional[datetime] = Query(None),
    favorite: Optional[bool] = Query(None),
    status: Optional[str] = Query(None, description="ai_status 过滤"),
    show_hidden: bool = Query(False),
    sort: str = Query("captured_desc"),
    page: int = Query(1, ge=1),
    size: int = Query(40, ge=1, le=200),
) -> ScreenshotListResp:
    """Main list query with time / category / tag / favorite / status / search filters.

    Args:
        session: Request-scoped database session.
        q: Full-text keyword; matching ids come from ``collect_search_ids``.
        category_id: Restrict to one category.
        tag: Case-insensitive substring matched against tag names.
        date_from: Inclusive lower bound on ``captured_at``.
        date_to: Inclusive upper bound on ``captured_at``.
        favorite: Only ``True`` filters (favorites only); ``False``/``None``
            apply no favorite filter.
        status: Exact match on ``ai_status``.
        show_hidden: When false, hidden screenshots are excluded.
        sort: Sort key understood by ``_apply_sort``.
        page: 1-based page number.
        size: Page size.

    Returns:
        One page of briefs plus the total number of matching screenshots.
    """
    filters = []
    if not show_hidden:
        filters.append(Screenshot.is_hidden == 0)
    if category_id is not None:
        filters.append(Screenshot.category_id == category_id)
    if date_from is not None:
        filters.append(Screenshot.captured_at >= date_from)
    if date_to is not None:
        filters.append(Screenshot.captured_at <= date_to)
    if favorite is True:
        filters.append(Screenshot.is_favorite == 1)
    if status:
        filters.append(Screenshot.ai_status == status)
    if tag:
        # EXISTS instead of JOIN: a screenshot whose several tags all match
        # the pattern must still appear (and be counted) exactly once,
        # otherwise `total` is inflated and pages come back short.
        filters.append(Screenshot.tags.any(Tag.name.ilike(f"%{tag}%")))

    if q:
        ids = collect_search_ids(session, q)
        if not ids:
            # No search hit at all: skip the database round trips.
            return ScreenshotListResp(items=[], total=0, page=page, size=size)
        filters.append(Screenshot.id.in_(ids))

    stmt = select(Screenshot).options(
        selectinload(Screenshot.meta),
        selectinload(Screenshot.tags),
    )
    if filters:
        stmt = stmt.where(and_(*filters))

    # Count before ordering; ORDER BY inside a counted subquery is wasted work.
    total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0

    stmt = _apply_sort(stmt, sort)
    rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()

    cat_map = _category_map(session)
    items = [_to_brief(r, cat_map) for r in rows]
    return ScreenshotListResp(items=items, total=int(total), page=page, size=size)
|
||||
|
||||
|
||||
def _apply_sort(stmt, sort: str):
    """Attach the ORDER BY matching ``sort`` to ``stmt``.

    Supported keys: ``captured_asc``, ``imported_desc``, ``imported_asc``,
    ``title_asc``, ``title_desc``, ``size_desc``, ``size_asc``. Any other
    value falls back to newest-captured-first. Title sorts outer-join
    ``ScreenshotMeta`` and place missing titles last.
    """
    if sort in ("title_asc", "title_desc"):
        title = ScreenshotMeta.ai_title
        direction = title.asc() if sort == "title_asc" else title.desc()
        return stmt.outerjoin(ScreenshotMeta).order_by(direction.nulls_last())

    # Plain column sorts: map the key to the ordering factory.
    column_orderings = {
        "captured_asc": Screenshot.captured_at.asc,
        "imported_desc": Screenshot.imported_at.desc,
        "imported_asc": Screenshot.imported_at.asc,
        "size_desc": Screenshot.size.desc,
        "size_asc": Screenshot.size.asc,
    }
    make_ordering = column_orderings.get(sort, Screenshot.captured_at.desc)
    return stmt.order_by(make_ordering())
|
||||
|
||||
|
||||
@router.get("/random", response_model=list[ScreenshotBrief])
def random_screenshots(
    session: Session = Depends(db_session),
    n: int = Query(1, ge=1, le=20),
    category_id: Optional[int] = Query(None),
) -> list[ScreenshotBrief]:
    """Return up to ``n`` randomly chosen, non-hidden screenshots.

    When ``category_id`` is given, only screenshots of that category are
    considered.
    """
    conditions = [Screenshot.is_hidden == 0]
    if category_id is not None:
        conditions.append(Screenshot.category_id == category_id)

    query = select(Screenshot).options(
        selectinload(Screenshot.meta),
        selectinload(Screenshot.tags),
    )
    for condition in conditions:
        query = query.where(condition)
    query = query.order_by(func.random()).limit(n)

    picked = session.scalars(query).unique().all()
    categories = _category_map(session)
    return [_to_brief(shot, categories) for shot in picked]
|
||||
|
||||
|
||||
@router.get("/stats")
def stats(session: Session = Depends(db_session)) -> dict:
    """Aggregate statistics: total, per-status, per-category and per-month counts.

    The response also carries the job-queue summary under ``queue``. The
    per-month breakdown only counts non-hidden screenshots and is limited to
    the 36 most recent months.
    """
    total = session.scalar(select(func.count(Screenshot.id))) or 0

    # One COUNT per known processing status, defaulting to 0.
    by_status = {}
    for st in ProcessStatus:
        status_count = session.scalar(
            select(func.count(Screenshot.id)).where(Screenshot.ai_status == st.value)
        )
        by_status[st.value] = status_count or 0

    # Outer join so categories without screenshots are listed with count 0.
    category_stmt = (
        select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
        .join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
        .group_by(Category.id)
        .order_by(func.count(Screenshot.id).desc())
    )
    by_category = []
    for row in session.execute(category_stmt).all():
        by_category.append(
            {"id": row[0], "name": row[1], "color": row[2], "count": int(row[3] or 0)}
        )

    month_sql = text(
        "SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
        "FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
    )
    by_month = []
    for row in session.execute(month_sql).all():
        by_month.append({"month": row[0], "count": int(row[1])})

    return {
        "total": int(total),
        "by_status": by_status,
        "by_category": by_category,
        "by_month": by_month,
        "queue": _queue_summary(session),
    }
|
||||
|
||||
|
||||
def _queue_summary(session: Session) -> dict:
    """Count jobs per ``JobStatus`` value (0 when a status has no jobs)."""
    return {
        job_status.value: (
            session.scalar(
                select(func.count(Job.id)).where(Job.status == job_status.value)
            )
            or 0
        )
        for job_status in JobStatus
    }
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
def get_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Return the detail view of a single screenshot.

    Raises:
        HTTPException: 404 when no screenshot has this id.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    brief = _to_brief(shot, _category_map(session))
    meta = shot.meta
    todo_items = [TodoBrief.model_validate(todo) for todo in shot.todos]

    # The detail schema is the brief plus file / OCR / AI fields.
    detail_fields = brief.model_dump()
    detail_fields.update(
        file_url=f"/api/screenshots/{shot.id}/file",
        size=shot.size,
        ocr_text=meta.ocr_text if meta else None,
        ai_summary=meta.ai_summary if meta else None,
        ai_suggestion=meta.ai_suggestion if meta else None,
        todos=todo_items,
    )
    return ScreenshotDetail(**detail_fields)
|
||||
|
||||
|
||||
@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
def update_screenshot(
    screenshot_id: int,
    payload: ScreenshotUpdate,
    session: Session = Depends(db_session),
) -> ScreenshotDetail:
    """Apply front-end edits: category, favorite, hidden flag and tags.

    Only fields explicitly present in the request body are touched.
    ``category_id: null`` clears the category; ``tags`` replaces the whole
    tag list (unknown tag names are created).

    Raises:
        HTTPException: 404 when the screenshot does not exist, 400 when the
            requested category does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    # model_fields_set distinguishes "field omitted" from "explicit null",
    # so PATCH {"category_id": null} really clears the category.
    fields = payload.model_fields_set

    if "category_id" in fields:
        if payload.category_id is not None:
            cat = session.get(Category, payload.category_id)
            if cat is None:
                raise HTTPException(400, "category not found")
        shot.category_id = payload.category_id

    if "is_favorite" in fields and payload.is_favorite is not None:
        shot.is_favorite = 1 if payload.is_favorite else 0

    if "is_hidden" in fields and payload.is_hidden is not None:
        shot.is_hidden = 1 if payload.is_hidden else 0

    if "tags" in fields and payload.tags is not None:
        tag_objs = []
        seen_names: set[str] = set()
        for raw_name in payload.tags:
            name = (raw_name or "").strip()[:64]
            # Skip blanks and repeats: the same Tag appearing twice in the
            # collection would produce duplicate association rows.
            if not name or name in seen_names:
                continue
            seen_names.add(name)
            tag = session.scalar(select(Tag).where(Tag.name == name))
            if tag is None:
                tag = Tag(name=name)
                session.add(tag)
                session.flush()
            tag_objs.append(tag)
        shot.tags = tag_objs

    session.commit()
    session.refresh(shot)
    return get_screenshot(screenshot_id, session)
|
||||
|
||||
|
||||
@router.post("/{screenshot_id}/reanalyze")
def reanalyze(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Queue a full re-analysis job for one screenshot.

    Both ``ai_status`` and ``ocr_status`` are reset to pending and a new
    ``FULL`` job is inserted.

    Raises:
        HTTPException: 404 when the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    pending = ProcessStatus.PENDING.value
    shot.ai_status = pending
    shot.ocr_status = pending

    new_job = Job(
        screenshot_id=shot.id,
        kind=JobKind.FULL.value,
        status=JobStatus.PENDING.value,
    )
    session.add(new_job)
    session.commit()

    # Sync routes run in the thread pool, so the event loop has to be woken
    # through the thread-safe entry point.
    worker.notify_threadsafe()
    return {"ok": True, "job_id": new_job.id}
|
||||
|
||||
|
||||
@router.post("/{screenshot_id}/reocr")
def reocr(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Re-run OCR only, without invoking the AI analysis again.

    If an OCR job for this screenshot is already pending or running, that
    job's id is returned and nothing new is queued.

    Raises:
        HTTPException: 404 when the screenshot does not exist.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    in_flight = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    existing_job_id = session.scalar(
        select(Job.id).where(
            Job.screenshot_id == shot.id,
            Job.kind == JobKind.OCR.value,
            Job.status.in_(in_flight),
        )
    )
    if existing_job_id is not None:
        return {"ok": True, "job_id": existing_job_id, "message": "已有 OCR 任务在队列中"}

    shot.ocr_status = ProcessStatus.PENDING.value
    new_job = Job(
        screenshot_id=shot.id,
        kind=JobKind.OCR.value,
        status=JobStatus.PENDING.value,
    )
    session.add(new_job)
    session.commit()
    worker.notify_threadsafe()
    return {"ok": True, "job_id": new_job.id}
|
||||
|
||||
|
||||
@router.delete("/{screenshot_id}")
def delete_screenshot(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete the database record; the original file on disk is not touched.

    Raises:
        HTTPException: 404 when the screenshot does not exist.
    """
    target = session.get(Screenshot, screenshot_id)
    if target is None:
        raise HTTPException(404, "Screenshot not found")

    session.delete(target)
    session.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}/file")
def get_file(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the original image file of a screenshot.

    Raises:
        HTTPException: 404 when the screenshot record is missing or its file
            is not accessible.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    image_path = path_from_storage(shot.path)
    if is_accessible_file(image_path):
        return FileResponse(str(image_path))
    raise HTTPException(404, "file missing")
|
||||
|
||||
|
||||
@router.get("/{screenshot_id}/thumb")
def get_thumb(
    screenshot_id: int,
    session: Session = Depends(db_session),
) -> FileResponse:
    """Stream the thumbnail, falling back to the original image.

    Raises:
        HTTPException: 404 when the screenshot record is missing, or when
            neither the thumbnail nor the original file can be served.
    """
    shot = session.get(Screenshot, screenshot_id)
    if shot is None:
        raise HTTPException(404, "Screenshot not found")

    thumb = Path(shot.thumb_path) if shot.thumb_path else None
    if thumb is not None and thumb.exists():
        return FileResponse(str(thumb), media_type="image/webp")

    # Fallback: serve the original image when no thumbnail is available.
    original = path_from_storage(shot.path)
    if not is_accessible_file(original):
        raise HTTPException(404, "thumb missing")
    return FileResponse(str(original))
|
||||
@@ -0,0 +1,220 @@
|
||||
"""设置接口:Provider 配置、分类、Tag。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.models.category import Category
|
||||
from app.models.setting import (
|
||||
DEFAULT_RECOGNITION_MODE,
|
||||
KEY_OCR_PROVIDER,
|
||||
KEY_RECOGNITION_MODE,
|
||||
KEY_VLM_PROVIDER,
|
||||
)
|
||||
from app.models.screenshot import Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.providers import RECOGNITION_MODES
|
||||
from app.schemas.common import (
|
||||
CategoryIn,
|
||||
ProviderConfig,
|
||||
ProviderConfigOut,
|
||||
ProviderTestResult,
|
||||
RecognitionModeIn,
|
||||
)
|
||||
from app.services.provider_test import merge_provider_api_key, test_provider_config
|
||||
from app.services.settings_store import all_settings, get_setting, set_setting
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
|
||||
@router.get("")
def get_all(session: Session = Depends(db_session)) -> dict:
    """Return all settings with provider API keys masked.

    For the OCR and VLM provider entries, a non-empty ``api_key`` is blanked
    and a masked form is exposed as ``api_key_mask`` instead.
    """
    settings = all_settings(session)
    for provider_key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        provider_cfg = settings.get(provider_key)
        if not isinstance(provider_cfg, dict):
            continue
        secret = provider_cfg.get("api_key")
        if not secret:
            continue
        provider_cfg["api_key_mask"] = _mask(secret)
        provider_cfg["api_key"] = ""
    return settings
|
||||
|
||||
|
||||
def _mask(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) <= 6:
|
||||
return "*" * len(value)
|
||||
return value[:3] + "*" * (len(value) - 6) + value[-3:]
|
||||
|
||||
|
||||
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
def get_provider(
    key: str,
    session: Session = Depends(db_session),
) -> ProviderConfigOut | None:
    """Read one provider configuration without leaking the API key.

    The plain ``api_key`` is never returned; only a mask is provided so the
    UI can hint that a key is stored.

    Returns:
        The configuration, or ``None`` when nothing is stored for ``key``.

    Raises:
        HTTPException: 400 when ``key`` is not one of the provider keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")

    stored = get_setting(session, key, None)
    if not stored:
        return None

    masked_key = _mask(stored.get("api_key", "") or "")
    extra_options = stored.get("extra", {}) or {}
    return ProviderConfigOut(
        type=stored.get("type", ""),
        base_url=stored.get("base_url"),
        api_key="",
        api_key_mask=masked_key or None,
        model=stored.get("model"),
        extra=extra_options,
    )
|
||||
|
||||
|
||||
@router.put("/providers/{key}")
def put_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> dict:
    """Store a provider configuration.

    When the client sends an empty ``api_key`` and a configuration already
    exists, the previously stored key is kept.

    Raises:
        HTTPException: 400 when ``key`` is not one of the provider keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")

    previous = get_setting(session, key, None)
    data = cfg.model_dump()

    # An empty api_key means "unchanged": carry the stored value over.
    keep_stored_key = not data.get("api_key") and isinstance(previous, dict)
    if keep_stored_key:
        data["api_key"] = previous.get("api_key", "")

    set_setting(session, key, data)
    session.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> ProviderTestResult:
    """Test connectivity of an OCR / vision-AI provider.

    The submitted form configuration is used; ``api_key`` may be left empty
    to reuse the stored value.

    Raises:
        HTTPException: 400 when ``key`` is not one of the provider keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")

    stored = get_setting(session, key, None)
    if not isinstance(stored, dict):
        stored = None
    effective_cfg = merge_provider_api_key(cfg, stored)
    outcome = await test_provider_config(key, effective_cfg)
    return ProviderTestResult(**outcome)
|
||||
|
||||
|
||||
@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
    """Read the text-recognition strategy (ocr / vision / hybrid).

    A stored value outside ``RECOGNITION_MODES`` is replaced by the default.
    """
    stored_mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
    mode = stored_mode if stored_mode in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE
    return {"mode": mode, "options": list(RECOGNITION_MODES)}
|
||||
|
||||
|
||||
@router.put("/recognition-mode")
def put_recognition_mode(
    payload: RecognitionModeIn,
    session: Session = Depends(db_session),
) -> dict:
    """Persist the text-recognition strategy.

    Raises:
        HTTPException: 400 when the mode is not in ``RECOGNITION_MODES``.
    """
    requested = payload.mode
    if requested not in RECOGNITION_MODES:
        raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")

    set_setting(session, KEY_RECOGNITION_MODE, requested)
    session.commit()
    return {"ok": True, "mode": requested}
|
||||
|
||||
|
||||
@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
    """List all categories ordered by id."""
    result = []
    for category in session.scalars(select(Category).order_by(Category.id)).all():
        result.append(
            {
                "id": category.id,
                "name": category.name,
                "color": category.color,
                "prompt_hint": category.prompt_hint,
            }
        )
    return result
|
||||
|
||||
|
||||
@router.post("/categories")
def create_category(
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Create a category and return its new id.

    Raises:
        HTTPException: 400 when a category with the same name exists.
    """
    same_name = session.scalar(select(Category).where(Category.name == payload.name))
    if same_name is not None:
        raise HTTPException(400, "category exists")

    created = Category(
        name=payload.name,
        color=payload.color,
        prompt_hint=payload.prompt_hint,
    )
    session.add(created)
    session.commit()
    session.refresh(created)
    return {"id": created.id}
|
||||
|
||||
|
||||
@router.patch("/categories/{cat_id}")
def update_category(
    cat_id: int,
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Overwrite name, color and prompt hint of an existing category.

    Raises:
        HTTPException: 404 when the category does not exist, 400 when
            another category already uses the requested name.
    """
    cat = session.get(Category, cat_id)
    if cat is None:
        raise HTTPException(404, "category not found")

    # Same uniqueness rule as create_category: a rename must not collide
    # with a different category's name.
    clash = session.scalar(
        select(Category.id).where(
            Category.name == payload.name,
            Category.id != cat_id,
        )
    )
    if clash is not None:
        raise HTTPException(400, "category exists")

    cat.name = payload.name
    cat.color = payload.color
    cat.prompt_hint = payload.prompt_hint
    session.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/categories/{cat_id}")
def delete_category(
    cat_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete a category by id.

    Raises:
        HTTPException: 404 when the category does not exist.
    """
    target = session.get(Category, cat_id)
    if target is None:
        raise HTTPException(404, "category not found")

    session.delete(target)
    session.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/tags")
def list_tags(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="标签名关键词"),
    page: int = Query(1, ge=1),
    size: int = Query(200, ge=1, le=500),
    sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
) -> dict:
    """List tags with their usage count, with search and pagination.

    ``q`` is matched case-insensitively as a substring of the tag name.
    Unknown ``sort`` values fall back to most-used-first.
    """
    keyword = q.strip() if q else None

    id_query = select(Tag.id)
    if q:
        id_query = id_query.where(Tag.name.ilike(f"%{keyword}%"))
    total = session.scalar(select(func.count()).select_from(id_query.subquery())) or 0

    # Outer join so unused tags are listed with a count of 0.
    listing = (
        select(Tag.id, Tag.name, Tag.color, func.count(Screenshot.id))
        .join(Tag.screenshots, isouter=True)
        .group_by(Tag.id)
    )
    if q:
        listing = listing.where(Tag.name.ilike(f"%{keyword}%"))

    orderings = {
        "count_asc": func.count(Screenshot.id).asc(),
        "name_asc": Tag.name.asc(),
        "name_desc": Tag.name.desc(),
    }
    ordering = orderings.get(sort)
    if ordering is None:
        ordering = func.count(Screenshot.id).desc()
    listing = listing.order_by(ordering)

    page_rows = session.execute(listing.offset((page - 1) * size).limit(size)).all()
    items = []
    for row in page_rows:
        items.append(
            {"id": row[0], "name": row[1], "color": row[2], "count": int(row[3] or 0)}
        )
    return {"items": items, "total": int(total), "page": page, "size": size}
|
||||
@@ -0,0 +1,106 @@
|
||||
"""待办清单接口。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.models.todo import Todo, TodoStatus
|
||||
from app.schemas.common import TodoUpdate
|
||||
from app.schemas.screenshot import TodoBrief, TodoListResp
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/todos", tags=["todos"])
|
||||
|
||||
|
||||
def _todo_filters(
    status: Optional[str],
    kind: Optional[str],
    q: Optional[str],
) -> list:
    """Build the SQL conditions for a todo listing.

    Empty/``None`` arguments add no condition. ``q`` is matched
    case-insensitively as a substring of either the title or the note.
    """
    conditions = []
    if status:
        conditions.append(Todo.status == status)
    if kind:
        conditions.append(Todo.kind == kind)
    if q:
        pattern = f"%{q.strip()}%"
        matches_title = Todo.title.ilike(pattern)
        matches_note = Todo.note.ilike(pattern)
        conditions.append(or_(matches_title, matches_note))
    return conditions
|
||||
|
||||
|
||||
@router.get("", response_model=TodoListResp)
def list_todos(
    session: Session = Depends(db_session),
    status: Optional[str] = Query(None),
    kind: Optional[str] = Query(None),
    q: Optional[str] = Query(None, description="标题/备注关键词"),
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=200),
) -> TodoListResp:
    """Paginated todo query filtered by status / kind / keyword.

    Results are ordered newest-created first.
    """
    query = select(Todo)
    conditions = _todo_filters(status, kind, q)
    if conditions:
        query = query.where(and_(*conditions))

    total = session.scalar(select(func.count()).select_from(query.subquery())) or 0

    offset = (page - 1) * size
    page_query = query.order_by(Todo.created_at.desc()).offset(offset).limit(size)
    todos = session.scalars(page_query).all()

    return TodoListResp(
        items=[TodoBrief.model_validate(todo) for todo in todos],
        total=int(total),
        page=page,
        size=size,
    )
|
||||
|
||||
|
||||
@router.get("/summary")
def summary(session: Session = Depends(db_session)) -> dict:
    """Return the number of todos for each ``TodoStatus`` value."""
    counts = {}
    for todo_status in TodoStatus:
        matching = session.scalar(
            select(func.count(Todo.id)).where(Todo.status == todo_status.value)
        )
        counts[todo_status.value] = matching or 0
    return counts
|
||||
|
||||
|
||||
@router.patch("/{todo_id}", response_model=TodoBrief)
def update_todo(
    todo_id: int,
    payload: TodoUpdate,
    session: Session = Depends(db_session),
) -> TodoBrief:
    """Update status / title / note of a todo.

    Fields left as ``None`` in the payload are not changed. Setting the
    status to done stamps ``completed_at``; setting any other status clears
    it so a reopened todo does not keep a stale completion time.

    Raises:
        HTTPException: 404 when the todo does not exist.
    """
    todo = session.get(Todo, todo_id)
    if todo is None:
        raise HTTPException(404, "Todo not found")

    if payload.status is not None:
        todo.status = payload.status
        if payload.status == TodoStatus.DONE.value:
            # Naive UTC timestamp, as before.
            todo.completed_at = datetime.utcnow()
        else:
            # Reopened / moved away from done: drop the completion time.
            # NOTE(review): assumes completed_at is nullable — confirm on the model.
            todo.completed_at = None

    if payload.title is not None:
        todo.title = payload.title
    if payload.note is not None:
        todo.note = payload.note

    session.commit()
    session.refresh(todo)
    return TodoBrief.model_validate(todo)
|
||||
|
||||
|
||||
@router.delete("/{todo_id}")
def delete_todo(
    todo_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete a todo by id.

    Raises:
        HTTPException: 404 when the todo does not exist.
    """
    target = session.get(Todo, todo_id)
    if target is None:
        raise HTTPException(404, "Todo not found")

    session.delete(target)
    session.commit()
    return {"ok": True}
|
||||
@@ -0,0 +1,256 @@
|
||||
"""监听目录的增删改、手动导入、分析队列。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.core.db import session_scope
|
||||
from app.core.path_utils import (
|
||||
count_files_sample,
|
||||
is_accessible_dir,
|
||||
normalize_user_path,
|
||||
)
|
||||
from app.models.job import Job, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.schemas.common import WatchFolderIn, WatchFolderOut
|
||||
from app.schemas.job import JobListResp, JobOut, JobRetryIn
|
||||
from app.services.analyze import enqueue_ocr_jobs
|
||||
from app.services.ingest import ingest_directory
|
||||
from app.services.watcher import watcher_service
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/watch", tags=["watch"])
|
||||
|
||||
|
||||
def _validate_folder_path(raw: str) -> str:
    """Validate and normalize a watch-folder path (UNC network paths included).

    Returns:
        The normalized path.

    Raises:
        HTTPException: 400 when the path is empty after normalization or the
            directory is not accessible.
    """
    normalized = normalize_user_path(raw)
    if not normalized:
        raise HTTPException(400, "路径不能为空")

    if is_accessible_dir(normalized):
        return normalized

    detail = (
        f"目录不存在或无法访问: {normalized}。"
        "请确认 NAS 已挂载、有读权限,UNC 路径形如 \\\\服务器\\共享\\文件夹"
    )
    raise HTTPException(400, detail)
|
||||
|
||||
|
||||
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
    """List all watch folders ordered by id."""
    query = select(WatchFolder).order_by(WatchFolder.id)
    return [WatchFolderOut.model_validate(folder) for folder in session.scalars(query).all()]
|
||||
|
||||
|
||||
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
    payload: WatchFolderIn,
    background: BackgroundTasks,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Register a new watch folder and trigger one initial scan.

    After the folder is stored the watcher is reloaded and a background scan
    of the folder is scheduled.

    Raises:
        HTTPException: 400 when the path is invalid/inaccessible or already
            registered.
    """
    normalized = _validate_folder_path(payload.path)

    duplicate = session.scalar(select(WatchFolder).where(WatchFolder.path == normalized))
    if duplicate is not None:
        raise HTTPException(400, "目录已存在")

    new_folder = WatchFolder(
        path=normalized,
        enabled=payload.enabled,
        recursive=payload.recursive,
        is_sensitive=payload.is_sensitive,
    )
    session.add(new_folder)
    session.commit()
    session.refresh(new_folder)

    watcher_service.reload()
    background.add_task(_scan_folder, normalized, payload.recursive)
    return WatchFolderOut.model_validate(new_folder)
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
    folder_id: int,
    payload: WatchFolderIn,
    session: Session = Depends(db_session),
) -> WatchFolderOut:
    """Overwrite path and flags of a watch folder, then reload the watcher.

    Raises:
        HTTPException: 404 when the folder does not exist; 400 when the path
            is invalid/inaccessible or already used by another folder.
    """
    folder = session.get(WatchFolder, folder_id)
    if folder is None:
        raise HTTPException(404, "folder not found")

    normalized = _validate_folder_path(payload.path)

    # Same rule as add_folder: a path may be registered only once, so an
    # edit must not collide with a different folder's path.
    clash = session.scalar(
        select(WatchFolder.id).where(
            WatchFolder.path == normalized,
            WatchFolder.id != folder_id,
        )
    )
    if clash is not None:
        raise HTTPException(400, "目录已存在")

    folder.path = normalized
    folder.enabled = payload.enabled
    folder.recursive = payload.recursive
    folder.is_sensitive = payload.is_sensitive
    session.commit()
    session.refresh(folder)

    watcher_service.reload()
    return WatchFolderOut.model_validate(folder)
|
||||
|
||||
|
||||
@router.delete("/folders/{folder_id}")
def delete_folder(
    folder_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Remove a watch folder and reload the watcher.

    Raises:
        HTTPException: 404 when the folder does not exist.
    """
    target = session.get(WatchFolder, folder_id)
    if target is None:
        raise HTTPException(404, "folder not found")

    session.delete(target)
    session.commit()
    watcher_service.reload()
    return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/import")
def import_now(
    payload: WatchFolderIn,
    background: BackgroundTasks,
) -> dict:
    """Manually trigger one directory scan.

    The directory does not have to be registered as a watch folder; the scan
    itself runs as a background task.

    Raises:
        HTTPException: 400 when the path is invalid or inaccessible.
    """
    scan_root = _validate_folder_path(payload.path)
    background.add_task(_scan_folder, scan_root, payload.recursive)
    return {"ok": True, "message": "已在后台扫描"}
|
||||
|
||||
|
||||
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
    """Check that a directory is accessible and report a sampled file count.

    UNC network paths are supported. The count and samples come from
    ``count_files_sample`` called with ``limit=3``.

    Raises:
        HTTPException: 400 when the path is invalid or inaccessible.
    """
    normalized = _validate_folder_path(payload.path)
    total, samples = count_files_sample(normalized, limit=3)

    report = {"ok": True, "path": normalized}
    report["sample_image_count"] = total
    report["samples"] = samples
    report["message"] = f"目录可访问,抽样发现约 {total}+ 张图片"
    return report
|
||||
|
||||
|
||||
def _scan_folder(path: str, recursive: bool) -> None:
    """Background task: ingest a directory, then notify the worker.

    Synchronous BackgroundTasks functions run in the thread pool, so the
    worker must be woken through its thread-safe entry point; calling
    ``asyncio.Event.set()`` directly from here would be racy.
    """
    with session_scope() as scan_session:
        ingest_directory(scan_session, path, recursive=recursive)
    # Notify only after the session scope has been left.
    worker.notify_threadsafe()
|
||||
|
||||
|
||||
@router.get("/queue")
async def queue_status() -> dict:
    """Read the worker queue status, extended with two OCR counters.

    ``ocr_retryable`` counts screenshots whose OCR failed while the AI
    analysis is done; ``ocr_pending`` counts pending jobs of kind ``ocr``.
    """
    counts = await worker.status()

    retryable_query = select(func.count(Screenshot.id)).where(
        Screenshot.ocr_status == ProcessStatus.FAILED.value,
        Screenshot.ai_status == ProcessStatus.DONE.value,
    )
    pending_query = select(func.count(Job.id)).where(
        Job.kind == "ocr",
        Job.status == JobStatus.PENDING.value,
    )

    with session_scope() as session:
        counts["ocr_retryable"] = session.scalar(retryable_query) or 0
        counts["ocr_pending"] = session.scalar(pending_query) or 0
    return counts
|
||||
|
||||
|
||||
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
    """Convert a ``Job`` ORM object into ``JobOut`` with screenshot summary fields.

    When ``shot`` is ``None`` all screenshot-derived fields are ``None``.
    """
    thumb_url = None
    path = None
    ai_title = None
    ai_status = None
    ocr_status = None

    if shot:
        if shot.thumb_path:
            thumb_url = f"/api/screenshots/{shot.id}/thumb"
        path = shot.path
        if shot.meta:
            ai_title = shot.meta.ai_title
        ai_status = shot.ai_status
        ocr_status = shot.ocr_status

    return JobOut(
        id=job.id,
        screenshot_id=job.screenshot_id,
        kind=job.kind,
        status=job.status,
        retries=job.retries or 0,
        last_error=job.last_error,
        created_at=job.created_at,
        started_at=job.started_at,
        finished_at=job.finished_at,
        thumb_url=thumb_url,
        path=path,
        ai_title=ai_title,
        ai_status=ai_status,
        ocr_status=ocr_status,
    )
|
||||
|
||||
|
||||
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
    status: Optional[str] = Query(None, description="pending|running|done|failed"),
    kind: Optional[str] = Query(None, description="full|ocr|vlm"),
    page: int = Query(1, ge=1),
    size: int = Query(50, ge=1, le=200),
    session: Session = Depends(db_session),
) -> JobListResp:
    """List analysis jobs page by page, newest (highest id) first.

    Args:
        status: Optional status filter; must be one of the JobStatus values.
        kind: Optional job-kind filter (not validated here).
        page: 1-based page number.
        size: Page size.
        session: Request-scoped database session.

    Raises:
        HTTPException: 400 when ``status`` is not a known JobStatus value.
    """
    allowed_statuses = {member.value for member in JobStatus}
    if status and status not in allowed_statuses:
        raise HTTPException(400, f"无效 status: {status}")

    conditions = []
    if status:
        conditions.append(Job.status == status)
    if kind:
        conditions.append(Job.kind == kind)

    query = select(Job)
    for condition in conditions:
        query = query.where(condition)

    # Count over the filtered query before pagination is applied.
    count_stmt = select(func.count()).select_from(query.subquery())
    total = session.scalar(count_stmt) or 0

    page_stmt = query.order_by(Job.id.desc()).offset((page - 1) * size).limit(size)
    jobs = session.scalars(page_stmt).all()

    # Load every referenced screenshot (with its meta) in a single query.
    shots_by_id: dict[int, Screenshot] = {}
    wanted_ids = [job.screenshot_id for job in jobs]
    if wanted_ids:
        shot_stmt = (
            select(Screenshot)
            .where(Screenshot.id.in_(wanted_ids))
            .options(selectinload(Screenshot.meta))
        )
        shots_by_id = {row.id: row for row in session.scalars(shot_stmt).all()}

    return JobListResp(
        items=[_job_to_out(job, shots_by_id.get(job.screenshot_id)) for job in jobs],
        total=total,
        page=page,
        size=size,
    )
|
||||
|
||||
|
||||
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
    """Re-queue failed jobs through ``worker.retry_failed``.

    The ids from the request body are forwarded when a body is supplied;
    otherwise None is passed (presumably meaning "all failed jobs" — confirm
    against ``worker.retry_failed``).

    Returns:
        ``{"ok": True, "count": <value returned by the worker>}``.
    """
    selected_ids = None if not payload else payload.job_ids
    requeued = worker.retry_failed(selected_ids)
    return {"ok": True, "count": requeued}
|
||||
|
||||
|
||||
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
    minutes: int = Query(5, ge=1, le=1440),
    reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
    """Reset stuck RUNNING jobs (e.g. after a worker crash or an unfinished run).

    Both arguments are forwarded unchanged to ``worker.reset_stale_running``.

    Returns:
        ``{"ok": True, "count": <value returned by the worker>}``.
    """
    reset_count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
    return {"ok": True, "count": reset_count}
|
||||
|
||||
|
||||
@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
    """Create OCR re-run jobs for screenshots whose AI step succeeded but OCR failed.

    Args:
        limit: Maximum number of jobs to create, forwarded to ``enqueue_ocr_jobs``.

    Returns:
        ``{"ok": True, "count": <number of jobs created>}``.
    """
    count = enqueue_ocr_jobs(limit=limit)
    if count:
        # This is a synchronous route, so FastAPI runs it in a worker thread,
        # not on the event loop. Use the thread-safe wake-up, the same entry
        # point the background scan task uses, instead of worker.notify().
        worker.notify_threadsafe()
    return {"ok": True, "count": count}
|
||||
@@ -0,0 +1,66 @@
|
||||
"""全局配置:路径、数据库、并发参数等。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
# 默认数据目录:放在 backend/.data 下,便于零配置启动
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DEFAULT_DATA_DIR = _BACKEND_ROOT / ".data"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Global configuration read from the backend ``.env`` file and the environment."""

    model_config = SettingsConfigDict(
        env_file=str(_BACKEND_ROOT / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Application basics
    app_name: str = "snapAna"
    debug: bool = False
    host: str = "127.0.0.1"
    port: int = 8765

    # Data directory
    data_dir: Path = Field(default=_DEFAULT_DATA_DIR)

    # Job concurrency
    analyze_concurrency: int = 4
    max_retries: int = 3

    # Thumbnails
    thumb_size: int = 320
    vlm_max_side: int = 1280  # long-edge pixel size images are shrunk to before VLM upload

    # CORS
    cors_origins: list[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]

    @property
    def db_path(self) -> Path:
        """Path of the SQLite database file inside ``data_dir``."""
        return self.data_dir / "snapana.db"

    @property
    def db_url(self) -> str:
        """SQLAlchemy connection URL for the SQLite database."""
        return f"sqlite:///{self.db_path.as_posix()}"

    @property
    def thumb_dir(self) -> Path:
        """Directory used as the thumbnail cache."""
        return self.data_dir / "thumbs"

    def ensure_dirs(self) -> None:
        """Create the runtime directories if they do not exist yet."""
        for directory in (self.data_dir, self.thumb_dir):
            directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings.ensure_dirs()
|
||||
@@ -0,0 +1,153 @@
|
||||
"""数据库引擎、会话与初始化。
|
||||
|
||||
使用 SQLAlchemy 2.0 + SQLite。FTS5 虚拟表通过原生 SQL 创建,并配套触发器
|
||||
让 OCR/AI 字段更新时自动同步到全文索引。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Project-wide declarative base; every ORM model derives from it."""
|
||||
|
||||
|
||||
engine = create_engine(
|
||||
settings.db_url,
|
||||
echo=False,
|
||||
future=True,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(engine, "connect")
def _sqlite_pragmas(dbapi_connection, _connection_record):
    """Apply SQLite PRAGMAs on every new DB-API connection.

    Turns on foreign keys, WAL journaling, NORMAL synchronous mode and a
    5000 ms busy timeout.

    Args:
        dbapi_connection: The raw DB-API connection that was just opened.
        _connection_record: Pool record supplied by SQLAlchemy (unused).
    """
    pragmas = (
        "PRAGMA foreign_keys=ON",
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA busy_timeout=5000",
    )
    cursor = dbapi_connection.cursor()
    try:
        for pragma in pragmas:
            cursor.execute(pragma)
    finally:
        # Close the cursor even when a PRAGMA fails, so it is not leaked.
        cursor.close()
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
|
||||
|
||||
|
||||
def get_session() -> Iterator[Session]:
    """FastAPI dependency: yield one session per request and close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
@contextmanager
def session_scope() -> Iterator[Session]:
    """Transactional scope: commit on success, roll back and re-raise on error.

    The session is closed when the scope is left, whatever the outcome.
    """
    # Leaving the Session context manager closes the session.
    with SessionLocal() as session:
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
|
||||
|
||||
|
||||
# FTS5 虚拟表与触发器 SQL(独立维护,便于以后调整字段)
|
||||
_FTS_SCHEMA_SQL = [
|
||||
"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS screenshots_fts
|
||||
USING fts5(
|
||||
ocr_text,
|
||||
ai_title,
|
||||
ai_summary,
|
||||
ai_suggestion,
|
||||
content='screenshot_meta',
|
||||
content_rowid='screenshot_id',
|
||||
tokenize='unicode61'
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS screenshot_meta_ai
|
||||
AFTER INSERT ON screenshot_meta BEGIN
|
||||
INSERT INTO screenshots_fts(rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
|
||||
VALUES (new.screenshot_id,
|
||||
coalesce(new.ocr_text, ''),
|
||||
coalesce(new.ai_title, ''),
|
||||
coalesce(new.ai_summary, ''),
|
||||
coalesce(new.ai_suggestion, ''));
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS screenshot_meta_ad
|
||||
AFTER DELETE ON screenshot_meta BEGIN
|
||||
INSERT INTO screenshots_fts(screenshots_fts, rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
|
||||
VALUES('delete', old.screenshot_id,
|
||||
coalesce(old.ocr_text, ''),
|
||||
coalesce(old.ai_title, ''),
|
||||
coalesce(old.ai_summary, ''),
|
||||
coalesce(old.ai_suggestion, ''));
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS screenshot_meta_au
|
||||
AFTER UPDATE ON screenshot_meta BEGIN
|
||||
INSERT INTO screenshots_fts(screenshots_fts, rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
|
||||
VALUES('delete', old.screenshot_id,
|
||||
coalesce(old.ocr_text, ''),
|
||||
coalesce(old.ai_title, ''),
|
||||
coalesce(old.ai_summary, ''),
|
||||
coalesce(old.ai_suggestion, ''));
|
||||
INSERT INTO screenshots_fts(rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
|
||||
VALUES (new.screenshot_id,
|
||||
coalesce(new.ocr_text, ''),
|
||||
coalesce(new.ai_title, ''),
|
||||
coalesce(new.ai_summary, ''),
|
||||
coalesce(new.ai_suggestion, ''));
|
||||
END;
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def init_db() -> None:
    """Create tables, install the FTS5 schema and seed default categories at startup."""
    # Imported here rather than at module level; the models import this module.
    from app.models import register_all  # noqa: F401

    register_all()
    Base.metadata.create_all(engine)

    with engine.begin() as conn:
        for ddl in _FTS_SCHEMA_SQL:
            conn.execute(text(ddl))
        _migrate_legacy_schema(conn)

    # Seed default categories at startup so they are visible in the settings /
    # filter pages even on the very first run.
    from app.services.analyze import ensure_default_categories

    ensure_default_categories()
|
||||
|
||||
|
||||
def _migrate_legacy_schema(conn) -> None:
    """Lightweight migration for databases whose screenshots table lacks the category FK.

    SQLite cannot add a foreign key with ALTER TABLE, so on such tables
    ON DELETE SET NULL never fires and deleted categories leave dangling
    ``category_id`` values. When no foreign key to ``categories`` is reported,
    those dangling references are set to NULL.

    Args:
        conn: An open SQLAlchemy connection inside a transaction.
    """
    fk_rows = conn.execute(text("PRAGMA foreign_key_list(screenshots)")).fetchall()
    # Index 2 of each PRAGMA row is compared against the referenced table name.
    for fk in fk_rows:
        if fk[2] == "categories":
            return

    # Clear dangling category ids so list statistics stay correct.
    cleanup_sql = (
        "UPDATE screenshots SET category_id = NULL "
        "WHERE category_id IS NOT NULL "
        "AND category_id NOT IN (SELECT id FROM categories)"
    )
    conn.execute(text(cleanup_sql))
|
||||
@@ -0,0 +1,25 @@
|
||||
"""统一日志配置。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False) -> None:
    """Configure the root logger's handler, format and level.

    Args:
        debug: Use DEBUG level when true, INFO otherwise.
    """
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s | %(levelname)-7s | %(name)s | %(message)s")
    )

    root_logger = logging.getLogger()
    # Replace whatever handlers were installed before with the single stdout handler.
    root_logger.handlers.clear()
    root_logger.addHandler(stream_handler)
    root_logger.setLevel(logging.DEBUG if debug else logging.INFO)

    # Quieten chatty third-party libraries.
    for library in ("watchdog", "httpx", "PIL"):
        logging.getLogger(library).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """Single entry point for obtaining a named logger."""
    named_logger = logging.getLogger(name)
    return named_logger
|
||||
@@ -0,0 +1,102 @@
|
||||
"""跨平台路径工具:重点兼容 Windows UNC 网络路径(\\\\NAS\\share\\...)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path, PureWindowsPath
|
||||
|
||||
|
||||
def normalize_user_path(raw: str) -> str:
    """Normalize a user-entered path, keeping Windows UNC paths in backslash form.

    Surrounding whitespace and quote characters are stripped first. On Windows:

    - a path already starting with two backslashes is kept as UNC,
    - ``//server/share/...`` is converted to the UNC backslash form,
    - any other path has its forward slashes turned into backslashes.

    On other platforms only ``~`` expansion is applied.

    Args:
        raw: Text as typed by the user; a falsy value yields an empty string.
    """
    cleaned = (raw or "").strip().strip('"').strip("'")
    if not cleaned:
        return cleaned

    if sys.platform != "win32":
        return str(Path(cleaned).expanduser())

    slash_unc = cleaned.startswith("//") and not cleaned.startswith("///")
    if slash_unc:
        # //server/share -> \\server\share
        cleaned = "\\\\" + cleaned.lstrip("/").replace("/", "\\")
    elif not cleaned.startswith("\\\\"):
        cleaned = cleaned.replace("/", "\\")
    return str(PureWindowsPath(cleaned))
|
||||
|
||||
|
||||
def path_from_storage(stored: str) -> Path:
    """Turn a path string read from the database into a Path.

    On Windows, legacy values written with ``as_posix`` (``//NAS/...``) are
    converted back to the UNC backslash form first.
    """
    legacy_unc = (
        bool(stored)
        and sys.platform == "win32"
        and stored.startswith("//")
        and not stored.startswith("///")
    )
    if legacy_unc:
        # Legacy data: //server/foo/bar -> \\server\foo\bar
        stored = "\\\\" + stored.lstrip("/").replace("/", "\\")
    return Path(stored)
|
||||
|
||||
|
||||
def path_to_storage(path: Path | str) -> str:
    """Return the string form of a path used for DB storage and comparison.

    On Windows backslashes are kept: Path objects are stringified as-is and
    plain strings go through ``normalize_user_path``. Elsewhere Path objects
    use their POSIX form and strings are returned unchanged.
    """
    on_windows = sys.platform == "win32"
    if not isinstance(path, Path):
        text_path = str(path)
        return normalize_user_path(text_path) if on_windows else text_path
    return str(path) if on_windows else path.as_posix()
|
||||
|
||||
|
||||
def is_accessible_dir(path: str | Path) -> bool:
    """Return True when ``path`` is an existing directory (UNC or local)."""
    target = str(path)
    try:
        reachable = os.path.isdir(target)
    except OSError:
        reachable = False
    return reachable
|
||||
|
||||
|
||||
def is_accessible_file(path: str | Path) -> bool:
    """Return True when ``path`` is an existing regular file."""
    target = str(path)
    try:
        reachable = os.path.isfile(target)
    except OSError:
        reachable = False
    return reachable
|
||||
|
||||
|
||||
def path_is_under(parent: str | Path, child: str | Path) -> bool:
    """Tell whether ``child`` is ``parent`` itself or lies beneath it.

    Both paths are normalized with ``normpath``/``normcase`` and compared as
    strings; used for sensitive-directory detection.
    """
    try:
        base = os.path.normcase(os.path.normpath(str(parent)))
        target = os.path.normcase(os.path.normpath(str(child)))
    except OSError:
        return False

    # Compare against "base + separator" so /a/b does not match /a/bc.
    prefix = base if base.endswith(os.sep) else base + os.sep
    if target == prefix.rstrip(os.sep):
        return True
    return target.startswith(prefix)
|
||||
|
||||
|
||||
def count_files_sample(
    root: str | Path,
    limit: int = 5,
    max_count: int = 1000,
) -> tuple[int, list[str]]:
    """Quickly sample how many supported images live under a directory.

    Network paths can be slow to walk, so the scan stops as soon as
    ``max_count`` images have been seen; the returned total is therefore a
    lower bound once the cap is reached.

    Args:
        root: Directory to walk. Strings are converted with ``path_from_storage``.
        limit: Maximum number of sample paths to collect.
        max_count: Stop walking after this many supported images.

    Returns:
        ``(total, samples)``: the number of supported images counted and up to
        ``limit`` of their paths in storage form.
    """
    from app.services.thumbnail import is_supported

    root_p = path_from_storage(root) if isinstance(root, str) else root
    total = 0
    samples: list[str] = []
    try:
        for dirpath, _, filenames in os.walk(str(root_p)):
            for name in filenames:
                candidate = Path(dirpath) / name
                if not is_supported(candidate):
                    continue
                total += 1
                if len(samples) < limit:
                    samples.append(path_to_storage(candidate))
                if total >= max_count:
                    # Return instead of break: a break leaves only one of the
                    # two loops, so the walk would carry on past the cap.
                    return total, samples
    except OSError:
        # Best effort: report whatever was counted before the failure.
        pass
    return total, samples
|
||||
@@ -0,0 +1,60 @@
|
||||
"""FastAPI 应用入口。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import screenshots, settings_api, todos, watch
|
||||
from app.core.config import settings
|
||||
from app.core.db import init_db
|
||||
from app.core.logger import get_logger, setup_logging
|
||||
from app.services.watcher import watcher_service
|
||||
from app.services.worker import worker
|
||||
|
||||
|
||||
setup_logging(settings.debug)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):  # noqa: ARG001
    """Application lifespan: initialize the DB, start the watcher and the worker.

    On shutdown the watcher is stopped first, then the worker.
    """
    init_db()

    async def _wake_worker() -> None:
        # Coroutine handed to the watcher so it can wake the analysis worker.
        worker.notify()

    watcher_service.start(asyncio.get_running_loop(), _wake_worker)
    await worker.start()
    logger.info("snapAna 启动完成 @ http://%s:%d", settings.host, settings.port)
    try:
        yield
    finally:
        watcher_service.stop()
        await worker.stop()
|
||||
|
||||
|
||||
app = FastAPI(title="snapAna", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(screenshots.router)
|
||||
app.include_router(todos.router)
|
||||
app.include_router(settings_api.router)
|
||||
app.include_router(watch.router)
|
||||
|
||||
|
||||
@app.get("/api/health")
def health() -> dict:
    """Health check: report a static status and the application version."""
    report = {"status": "ok", "version": "0.1.0"}
    return report
|
||||
@@ -0,0 +1,13 @@
|
||||
"""SQLAlchemy 模型集中注册入口。"""
|
||||
|
||||
|
||||
def register_all() -> None:
    """Import every model module so its tables are registered on Base.metadata."""
    from . import (  # noqa: F401
        screenshot,
        meta,
        tag,
        category,
        todo,
        job,
        watch_folder,
        setting,
    )
|
||||
@@ -0,0 +1,32 @@
|
||||
"""截图分类。预置常见类目,AI 命中即可写回。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class Category(Base):
    """Screenshot category."""

    __tablename__ = "categories"
    __table_args__ = (UniqueConstraint("name", name="uq_categories_name"),)

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Unique display name of the category.
    name: Mapped[str] = mapped_column(String(64), nullable=False)
    # Optional color string (the default seed data uses "#rrggbb" values).
    color: Mapped[str | None] = mapped_column(String(16), nullable=True)
    # Optional free-text hint describing what belongs in this category.
    prompt_hint: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
|
||||
# 首次启动时灌入的默认分类
|
||||
DEFAULT_CATEGORIES: list[dict[str, str | None]] = [
|
||||
{"name": "知识技术", "color": "#3b82f6", "prompt_hint": "技术文章、代码、教程、文档截图"},
|
||||
{"name": "梗图幽默", "color": "#f59e0b", "prompt_hint": "搞笑图、表情包、梗图"},
|
||||
{"name": "小说文字", "color": "#8b5cf6", "prompt_hint": "长段文字、小说阅读、电子书"},
|
||||
{"name": "聊天记录", "color": "#10b981", "prompt_hint": "微信/QQ/Slack 等聊天截图"},
|
||||
{"name": "UI 设计", "color": "#ec4899", "prompt_hint": "界面设计、网页/App 灵感参考"},
|
||||
{"name": "生活记录", "color": "#22c55e", "prompt_hint": "日常照片、生活记录、票据"},
|
||||
{"name": "购物商品", "color": "#ef4444", "prompt_hint": "商品截图、价格、订单"},
|
||||
{"name": "其他", "color": "#6b7280", "prompt_hint": "无法明确归类"},
|
||||
]
|
||||
@@ -0,0 +1,54 @@
|
||||
"""分析任务队列:持久化到 SQLite,断电可恢复。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class JobKind(str, Enum):
    """Kind of analysis job."""

    OCR = "ocr"
    VLM = "vlm"
    FULL = "full"  # OCR followed by VLM in one pass
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Run state of a job."""

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
|
||||
|
||||
|
||||
class Job(Base):
    """One analysis job record."""

    __tablename__ = "jobs"
    __table_args__ = (
        Index("ix_jobs_status", "status"),
        Index("ix_jobs_kind_status", "kind", "status"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning screenshot; the FK is declared with ON DELETE CASCADE.
    screenshot_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("screenshots.id", ondelete="CASCADE"),
        nullable=False,
    )
    kind: Mapped[str] = mapped_column(String(16), default=JobKind.FULL.value)
    status: Mapped[str] = mapped_column(String(16), default=JobStatus.PENDING.value)
    retries: Mapped[int] = mapped_column(Integer, default=0)
    last_error: Mapped[str | None] = mapped_column(Text, nullable=True)

    # created_at is filled in by the database; the other two start as NULL.
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.now(), nullable=False
    )
    started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""截图的 OCR / AI 元信息。与 screenshot 1:1。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class ScreenshotMeta(Base):
    """OCR text plus the structured AI result for one screenshot."""

    __tablename__ = "screenshot_meta"

    # The primary key doubles as the FK to the owning screenshot.
    screenshot_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("screenshots.id", ondelete="CASCADE"),
        primary_key=True,
    )

    ocr_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    ai_title: Mapped[str | None] = mapped_column(Text, nullable=True)
    ai_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
    ai_suggestion: Mapped[str | None] = mapped_column(Text, nullable=True)
    ai_raw_json: Mapped[str | None] = mapped_column(Text, nullable=True)  # full raw JSON

    screenshot = relationship("Screenshot", back_populates="meta")
|
||||
@@ -0,0 +1,86 @@
|
||||
"""截图主表与处理状态。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class ProcessStatus(str, Enum):
    """State values used by the processing pipeline."""

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
    SKIPPED = "skipped"
|
||||
|
||||
|
||||
class Screenshot(Base):
    """Main record for one screenshot file."""

    __tablename__ = "screenshots"
    __table_args__ = (
        UniqueConstraint("file_hash", name="uq_screenshots_file_hash"),
        Index("ix_screenshots_captured_at", "captured_at"),
        Index("ix_screenshots_ai_status", "ai_status"),
        Index("ix_screenshots_category_id", "category_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    path: Mapped[str] = mapped_column(String(1024), nullable=False)
    file_hash: Mapped[str] = mapped_column(String(64), nullable=False)
    width: Mapped[int] = mapped_column(Integer, default=0)
    height: Mapped[int] = mapped_column(Integer, default=0)
    size: Mapped[int] = mapped_column(BigInteger, default=0)

    captured_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    imported_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.now(), nullable=False
    )

    thumb_path: Mapped[str | None] = mapped_column(String(1024), nullable=True)

    # Per-stage status strings; both default to the PENDING value.
    ocr_status: Mapped[str] = mapped_column(String(16), default=ProcessStatus.PENDING.value)
    ai_status: Mapped[str] = mapped_column(String(16), default=ProcessStatus.PENDING.value)

    # Category written back by the AI: FK with SET NULL, so deleting a
    # category clears the reference automatically.
    category_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey("categories.id", ondelete="SET NULL"),
        nullable=True,
    )

    is_favorite: Mapped[int] = mapped_column(Integer, default=0)  # 0/1, convenient for SQLite indexing
    is_hidden: Mapped[int] = mapped_column(Integer, default=0)

    # One-to-one OCR/AI metadata.
    meta = relationship(
        "ScreenshotMeta",
        back_populates="screenshot",
        uselist=False,
        cascade="all, delete-orphan",
    )
    # Many-to-many tags through the screenshot_tags association table.
    tags = relationship(
        "Tag",
        secondary="screenshot_tags",
        back_populates="screenshots",
        lazy="selectin",
    )
    # To-do items that belong to this screenshot.
    todos = relationship(
        "Todo",
        back_populates="screenshot",
        cascade="all, delete-orphan",
    )
|
||||
@@ -0,0 +1,26 @@
|
||||
"""键值设置:Provider 配置等以 JSON 形式存储。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class Setting(Base):
    """Generic key/value setting row."""

    __tablename__ = "settings"

    key: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Value as JSON text; defaults to the JSON literal "null".
    value_json: Mapped[str] = mapped_column(Text, nullable=False, default="null")
|
||||
|
||||
|
||||
# 设置键名常量
|
||||
KEY_OCR_PROVIDER = "ocr_provider"
|
||||
KEY_VLM_PROVIDER = "vlm_provider"
|
||||
KEY_RECOGNITION_MODE = "recognition_mode" # ocr | vision | hybrid
|
||||
KEY_CATEGORY_HINT = "category_hint"
|
||||
|
||||
# 默认识别模式:混合(OCR 文本 + 视觉 AI 联合分析)
|
||||
DEFAULT_RECOGNITION_MODE = "hybrid"
|
||||
@@ -0,0 +1,42 @@
|
||||
"""标签与多对多关联。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Table, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
screenshot_tags = Table(
|
||||
"screenshot_tags",
|
||||
Base.metadata,
|
||||
Column(
|
||||
"screenshot_id",
|
||||
Integer,
|
||||
ForeignKey("screenshots.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column(
|
||||
"tag_id",
|
||||
Integer,
|
||||
ForeignKey("tags.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
    """Free-form tag shared by users and the AI."""

    __tablename__ = "tags"
    __table_args__ = (UniqueConstraint("name", name="uq_tags_name"),)

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(64), nullable=False)
    color: Mapped[str | None] = mapped_column(String(16), nullable=True)

    # Reverse side of Screenshot.tags, through the association table.
    screenshots = relationship(
        "Screenshot",
        secondary=screenshot_tags,
        back_populates="tags",
    )
|
||||
@@ -0,0 +1,47 @@
|
||||
"""AI 抽取的待办(待看/待读/待办)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class TodoStatus(str, Enum):
    """State of a to-do item."""

    PENDING = "pending"
    DOING = "doing"
    DONE = "done"
    DROPPED = "dropped"
|
||||
|
||||
|
||||
class Todo(Base):
    """To-do item extracted from a screenshot by the AI."""

    __tablename__ = "todos"
    __table_args__ = (
        Index("ix_todos_status", "status"),
        Index("ix_todos_screenshot_id", "screenshot_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Source screenshot; the FK is declared with ON DELETE CASCADE.
    screenshot_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("screenshots.id", ondelete="CASCADE"),
        nullable=False,
    )
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    note: Mapped[str | None] = mapped_column(Text, nullable=True)
    kind: Mapped[str | None] = mapped_column(String(32), nullable=True)  # to-watch / to-read / to-do, etc.
    status: Mapped[str] = mapped_column(String(16), default=TodoStatus.PENDING.value)

    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.now(), nullable=False
    )
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

    screenshot = relationship("Screenshot", back_populates="todos")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""被监听的截图目录列表。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.core.db import Base
|
||||
|
||||
|
||||
class WatchFolder(Base):
    """A screenshot directory that is being watched."""

    __tablename__ = "watch_folders"
    __table_args__ = (UniqueConstraint("path", name="uq_watch_folders_path"),)

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    path: Mapped[str] = mapped_column(String(1024), nullable=False)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    recursive: Mapped[bool] = mapped_column(Boolean, default=True)
    is_sensitive: Mapped[bool] = mapped_column(Boolean, default=False)  # whether cloud upload is forbidden
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.now(), nullable=False
    )
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Provider 工厂,按设置中的 type 字段实例化。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.common import ProviderConfig
|
||||
|
||||
from .base import OCRProvider, VLMProvider
|
||||
from .ocr_http import HttpOCR
|
||||
from .ocr_paddle import PaddleOCRProvider
|
||||
from .ocr_tesseract import TesseractOCR
|
||||
from .ocr_vision import VisionOCR
|
||||
from .vlm_openai import OpenAICompatVLM
|
||||
|
||||
# OCR Provider 类型常量
|
||||
OCR_TYPES = ("tesseract", "paddleocr", "http", "vision", "none")
|
||||
VLM_TYPES = ("openai_compat", "none")
|
||||
RECOGNITION_MODES = ("ocr", "vision", "hybrid")
|
||||
|
||||
|
||||
def build_ocr_provider(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[OCRProvider]:
    """Build a classic or vision-based OCR provider from its configuration.

    Args:
        cfg: Provider configuration; None or a type of "", "none" or
            "disabled" means OCR is switched off.
        allow_upload: Forwarded to the vision OCR provider only.

    Returns:
        The provider instance, or None when OCR is disabled.

    Raises:
        ValueError: The HTTP provider has no ``base_url``, or the type is unknown.
    """
    if cfg is None or cfg.type in ("", "none", "disabled"):
        return None

    kind = cfg.type
    extra = cfg.extra

    if kind == "tesseract":
        return TesseractOCR(
            lang=extra.get("lang", "chi_sim+eng"),
            cmd=extra.get("cmd"),
        )

    if kind == "paddleocr":
        return PaddleOCRProvider(lang=extra.get("lang", "ch"))

    if kind == "http":
        if not cfg.base_url:
            raise ValueError("HTTP OCR 需要配置 base_url")
        # Only a dict is accepted as extra headers; anything else is dropped.
        custom_headers = extra.get("headers")
        if not isinstance(custom_headers, dict):
            custom_headers = None
        return HttpOCR(
            base_url=cfg.base_url,
            api_key=cfg.api_key or "",
            text_path=str(extra.get("text_path", "text")),
            headers=custom_headers,
            timeout=float(extra.get("timeout", 30)),
        )

    if kind == "vision":
        return build_vision_ocr(cfg, allow_upload=allow_upload)

    raise ValueError(f"未知 OCR Provider 类型: {cfg.type}")
|
||||
|
||||
|
||||
def build_vision_ocr(
    cfg: ProviderConfig | None,
    *,
    allow_upload: bool = True,
) -> Optional[VisionOCR]:
    """Build a vision-model OCR provider; it can share the VLM's endpoint config.

    Missing ``base_url`` / ``model`` values fall back to a local endpoint
    (``http://localhost:11434/v1``) and ``qwen2.5vl:7b``.

    Returns:
        A VisionOCR instance, or None when the config is absent or disabled.
    """
    disabled = cfg is None or cfg.type in ("", "none", "disabled")
    if disabled:
        return None

    return VisionOCR(
        base_url=cfg.base_url or "http://localhost:11434/v1",
        api_key=cfg.api_key or "",
        model=cfg.model or "qwen2.5vl:7b",
        timeout=float(cfg.extra.get("timeout", 60)),
        allow_upload=allow_upload,
    )
|
||||
|
||||
|
||||
def build_vlm_provider(cfg: ProviderConfig | None) -> Optional[VLMProvider]:
    """Build the VLM provider used for AI classification, summaries and tags.

    Every accepted type name maps to the same OpenAI-compatible client.

    Returns:
        The provider, or None when the config is absent or disabled.

    Raises:
        ValueError: The provider type is not recognized.
    """
    if cfg is None or cfg.type in ("", "none", "disabled"):
        return None

    openai_compatible = (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    )
    if cfg.type not in openai_compatible:
        raise ValueError(f"未知 VLM Provider 类型: {cfg.type}")

    return OpenAICompatVLM(
        base_url=cfg.base_url or "http://localhost:11434/v1",
        api_key=cfg.api_key or "",
        model=cfg.model or "gpt-4o-mini",
        timeout=float(cfg.extra.get("timeout", 60)),
    )
|
||||
@@ -0,0 +1,46 @@
|
||||
"""OCR / VLM Provider 抽象接口。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class VLMResult:
    """Structured analysis result produced by a VLM."""

    title: str = ""
    summary: str = ""
    category: str | None = None
    # Mutable fields use factories so instances never share a list or dict.
    tags: list[str] = field(default_factory=list)
    todos: list[dict[str, str]] = field(default_factory=list)  # [{title, kind, note}]
    suggestion: str = ""
    raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OCRProvider(ABC):
    """OCR interface: take an image path, return the recognized text."""

    name: str = "ocr"

    @abstractmethod
    async def recognize(self, image_path: Path) -> str:
        """Recognize and return the text contained in the image at ``image_path``."""
        ...
|
||||
|
||||
|
||||
class VLMProvider(ABC):
    """Multimodal interface: build a structured analysis from an image plus OCR text."""

    name: str = "vlm"

    @abstractmethod
    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Analyze one image and return the structured result."""
        ...
|
||||
@@ -0,0 +1,63 @@
|
||||
"""通用 HTTP OCR:向自定义 REST 接口 POST 图片并解析文本。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class HttpOCR(OCRProvider):
    """Generic HTTP OCR provider.

    POSTs a JSON body carrying the base64-encoded image (under both the
    ``image_base64`` and ``image`` keys) to the configured URL and extracts
    the recognized text from the JSON response.

    Extra options:
    - text_path: dotted path into the response, e.g. "data.text" or
      "result"; defaults to "text".
    - headers: dict of additional request headers.
    """

    name = "http"

    def __init__(
        self,
        base_url: str,
        api_key: str = "",
        text_path: str = "text",
        headers: dict[str, str] | None = None,
        timeout: float = 30.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.text_path = text_path
        self.headers = headers or {}
        self.timeout = timeout

    @staticmethod
    def _encode_image(image_path: Path) -> str:
        """Read the image file and return its content as base64 ASCII text."""
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode("ascii")

    async def recognize(self, image_path: Path) -> str:
        """Send the image to the OCR endpoint and return the stripped text.

        Raises:
            OSError: If the image file cannot be read.
            httpx.HTTPStatusError: If the endpoint answers with an error status.
        """
        # Local import: this module's import block does not include asyncio.
        import asyncio

        # File read + base64 encoding are blocking; keep them off the event loop.
        encoded = await asyncio.to_thread(self._encode_image, image_path)

        # Caller-supplied headers may override Content-Type; the bearer token,
        # when configured, always wins over a caller-supplied Authorization.
        headers = {"Content-Type": "application/json", **self.headers}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {"image_base64": encoded, "image": encoded}

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            resp = await client.post(self.base_url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()

        return str(_dig(data, self.text_path) or "").strip()
|
||||
|
||||
|
||||
def _dig(obj: Any, path: str) -> Any:
    """Fetch a value from nested containers by a dotted path.

    Each path part is used as a dict key; when the current value is a list or
    tuple, the part is interpreted as an integer index (e.g. ``"results.0.text"``).

    Returns:
        The value found, or ``None`` when the path cannot be followed (missing
        key, non-numeric or out-of-range index, or a scalar in the middle).
    """
    cur = obj
    for part in path.split("."):
        if isinstance(cur, dict):
            cur = cur.get(part)
        elif isinstance(cur, (list, tuple)):
            try:
                cur = cur[int(part)]
            except (ValueError, IndexError):
                return None
        else:
            return None
    return cur
|
||||
@@ -0,0 +1,43 @@
|
||||
"""PaddleOCR 本地 OCR(可选依赖)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class PaddleOCRProvider(OCRProvider):
    """Local text recognition via PaddleOCR (pip install paddleocr paddlepaddle)."""

    name = "paddleocr"

    def __init__(self, lang: str = "ch") -> None:
        self.lang = lang
        # The engine is created lazily on the first recognition call.
        self._engine = None

    async def recognize(self, image_path: Path) -> str:
        """Run the blocking recognition in a worker thread."""
        return await asyncio.to_thread(self._sync_recognize, image_path)

    def _sync_recognize(self, image_path: Path) -> str:
        """Recognize text synchronously and return it one detected line per row.

        Raises:
            RuntimeError: If the ``paddleocr`` package is not installed.
        """
        try:
            from paddleocr import PaddleOCR  # type: ignore
        except ImportError as exc:
            raise RuntimeError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc

        if self._engine is None:
            self._engine = PaddleOCR(use_angle_cls=True, lang=self.lang, show_log=False)

        result = self._engine.ocr(str(image_path), cls=True)
        if not result or not result[0]:
            return ""

        texts: list[str] = []
        for entry in result[0]:
            if not entry or len(entry) < 2:
                continue
            # The second element carries the text: either a plain string or a
            # sequence whose first item is the text.
            payload = entry[1]
            if isinstance(payload, str):
                texts.append(payload)
            elif isinstance(payload, (list, tuple)) and payload:
                texts.append(str(payload[0]))
        return "\n".join(texts).strip()
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Tesseract 本地 OCR 实现。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .base import OCRProvider
|
||||
|
||||
|
||||
class TesseractOCR(OCRProvider):
    """Local OCR that drives the tesseract binary through pytesseract.

    tesseract-ocr and the Chinese language pack must be installed beforehand.
    """

    name = "tesseract"

    def __init__(self, lang: str = "chi_sim+eng", cmd: Optional[str] = None) -> None:
        self.lang = lang
        self.cmd = cmd

    async def recognize(self, image_path: Path) -> str:
        """Async wrapper: runs the blocking call in a thread to keep the loop free."""
        return await asyncio.to_thread(self._sync_recognize, image_path)

    def _sync_recognize(self, image_path: Path) -> str:
        """Open the image and return the stripped text tesseract extracts.

        Raises:
            RuntimeError: If pytesseract or Pillow is not installed.
        """
        try:
            import pytesseract
            from PIL import Image
        except ImportError as exc:  # pragma: no cover
            raise RuntimeError("未安装 pytesseract / Pillow") from exc

        # A configured executable path is applied to pytesseract's module-level setting.
        if self.cmd:
            pytesseract.pytesseract.tesseract_cmd = self.cmd

        with Image.open(image_path) as image:
            recognized = pytesseract.image_to_string(image, lang=self.lang)
        return recognized.strip()
|
||||
@@ -0,0 +1,52 @@
|
||||
"""视觉大模型 OCR:用多模态 API 从截图中提取文字。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .base import OCRProvider
|
||||
from .openai_vision_client import chat_completions, safe_parse_json
|
||||
|
||||
|
||||
_VISION_OCR_SYSTEM = """你是 OCR 助手。用户会给你一张截图,请尽可能完整地提取其中的文字。
|
||||
只输出 JSON,格式:{"text": "提取到的全部文字,保留换行"}
|
||||
如果没有可识别文字,text 填空字符串。"""
|
||||
|
||||
|
||||
class VisionOCR(OCRProvider):
    """Text recognition through an OpenAI-compatible vision model.

    (GLM-4V / GPT-4o / Qwen-VL / Ollama, etc.)
    """

    name = "vision"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
        allow_upload: bool = True,
    ) -> None:
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.allow_upload = allow_upload

    async def recognize(self, image_path: Path) -> str:
        """Ask the vision model to extract the text of the image.

        Returns:
            The stripped text. An explicit empty ``text`` answer from the model
            yields ``""`` rather than the raw JSON reply.

        Raises:
            RuntimeError: If uploading the image is not allowed.
        """
        if not self.allow_upload:
            raise RuntimeError("敏感目录禁止上传图片,无法使用视觉 OCR")

        content = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_VISION_OCR_SYSTEM,
            user_text="请提取这张截图中的所有文字。",
            image_path=image_path,
            allow_upload=True,
            timeout=self.timeout,
            json_mode=True,
        )
        parsed = safe_parse_json(content)
        if not isinstance(parsed, dict):
            # Unexpected JSON shape: treat the raw reply as the text.
            return str(content).strip()

        key_present = False
        for key in ("text", "ocr_text"):
            value = parsed.get(key)
            if value:
                return str(value).strip()
            key_present = key_present or key in parsed
        # A present-but-empty key means "no text found"; only fall back to the
        # raw reply when the model ignored the requested schema entirely.
        return "" if key_present else str(content).strip()
|
||||
@@ -0,0 +1,107 @@
|
||||
"""OpenAI 兼容视觉 API 的公共封装:图片编码 + chat/completions 调用。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def image_to_data_url(image_path: Path, max_side: int | None = None) -> str:
    """Downscale an image and encode it as a JPEG ``data:`` URL.

    Args:
        image_path: Image file to encode.
        max_side: Upper bound for the longer side in pixels; falls back to
            ``settings.vlm_max_side`` when omitted. A non-positive bound
            disables downscaling.

    Returns:
        A ``data:image/jpeg;base64,...`` string.
    """
    max_side = max_side or settings.vlm_max_side
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        w, h = img.size
        if max_side and max_side > 0:
            scale = max(w, h) / max_side
            if scale > 1:
                # Clamp to 1px so extreme aspect ratios never yield a zero side.
                new_size = (max(1, int(w / scale)), max(1, int(h / scale)))
                img = img.resize(new_size, Image.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=82)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
|
||||
|
||||
|
||||
def safe_parse_json(content: str) -> dict[str, Any]:
    """Parse a model's JSON reply, tolerating markdown code fences.

    The reply is tried as-is (after removing a surrounding fence), then the
    outermost ``{...}`` span is tried. Only a JSON object counts as success.

    Returns:
        The parsed object, or ``{"text": content}`` when no JSON object can be
        recovered (``None`` content is treated as an empty string).
    """
    if content is None:
        raw = ""
    elif isinstance(content, str):
        raw = content
    else:
        raw = str(content)

    text = raw.strip()
    if text.startswith("```"):
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:].strip()

    candidates = [text]
    start = text.find("{")
    end = text.rfind("}")
    if start >= 0 and end > start:
        candidates.append(text[start : end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (list, number, string) would break
        # callers that rely on the dict contract, so keep looking.
        if isinstance(parsed, dict):
            return parsed
    return {"text": raw}
|
||||
|
||||
|
||||
async def chat_completions(
    *,
    base_url: str,
    api_key: str,
    model: str,
    system_prompt: str,
    user_text: str,
    image_path: Path | None = None,
    allow_upload: bool = True,
    timeout: float = 60.0,
    json_mode: bool = True,
) -> str:
    """Call ``/chat/completions`` and return the ``message.content`` string.

    Args:
        base_url: API root; ``/chat/completions`` is appended to it.
        api_key: Sent as a bearer token when non-empty.
        model: Model name.
        system_prompt: Content of the system message.
        user_text: Text part of the user message.
        image_path: Optional image attached as a data URL.
        allow_upload: The image is attached only when this is true.
        timeout: HTTP timeout in seconds.
        json_mode: Request ``response_format={"type": "json_object"}``.

    Raises:
        httpx.HTTPError: On transport failure of the retry or an error status.
        RuntimeError: If the response has no string ``message.content``.
    """
    user_content: list[dict[str, Any]] = [{"type": "text", "text": user_text}]
    if image_path is not None and allow_upload:
        data_url = image_to_data_url(image_path)
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})

    payload: dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        "temperature": 0.2,
    }
    if json_mode:
        payload["response_format"] = {"type": "json_object"}

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    url = f"{base_url.rstrip('/')}/chat/completions"
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.post(url, json=payload, headers=headers)
        except httpx.HTTPError as exc:
            # One retry without response_format; a second failure propagates.
            logger.warning("视觉 API 请求失败,尝试移除 response_format:%s", exc)
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)

        # Backends that reject response_format get one more attempt without it.
        if resp.status_code == 400 and "response_format" in resp.text:
            payload.pop("response_format", None)
            resp = await client.post(url, json=payload, headers=headers)

        resp.raise_for_status()
        data = resp.json()

    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(f"视觉 API 返回结构异常: {data}") from exc
    # A null or non-string content would otherwise break the JSON parsing
    # done by callers with an unrelated error.
    if not isinstance(content, str):
        raise RuntimeError(f"视觉 API 返回结构异常: {data}")
    return content
|
||||
@@ -0,0 +1,107 @@
|
||||
"""OpenAI 兼容 VLM 实现:覆盖 Ollama / GLM / MiniMax / Moonshot / OpenRouter / OpenAI。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.core.logger import get_logger
|
||||
|
||||
from .base import VLMProvider, VLMResult
|
||||
from .openai_vision_client import chat_completions, safe_parse_json
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_SYSTEM_PROMPT = """你是一个截图整理助手。用户会给你一张截图(可能附带 OCR 文本)。
|
||||
请用简洁的中文,按以下 JSON 结构返回分析结果,**只输出 JSON,不要解释**:
|
||||
|
||||
{
|
||||
"title": "一句话标题,不超过 24 个字",
|
||||
"summary": "2-3 句话总结这张截图的内容、要点或笑点",
|
||||
"category": "从给定分类列表中选一个最贴切的名字;如果都不符合就填'其他'",
|
||||
"tags": ["3-6 个能帮助检索的细分标签"],
|
||||
"todos": [
|
||||
{"title": "如果截图里出现'待看/待读/待办/想试试/记一下'的内容,抽成一条 todo", "kind": "待读|待看|待办|学习", "note": "可空"}
|
||||
],
|
||||
"suggestion": "可选:给用户的进一步行动建议或同类资源提示,可空"
|
||||
}
|
||||
|
||||
要求:
|
||||
- 标题要可读,不要复述"这是一张..."。
|
||||
- summary 不要超过 80 字。
|
||||
- todos 没有可识别项时给空数组。"""
|
||||
|
||||
|
||||
class OpenAICompatVLM(VLMProvider):
    """VLM provider that talks to /v1/chat/completions; the image travels as a base64 data URL."""

    name = "openai_compat"

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str,
        timeout: float = 60.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    async def analyze(
        self,
        image_path: Path,
        ocr_text: str,
        categories: list[str],
        allow_upload: bool,
    ) -> VLMResult:
        """Query the model and convert its JSON reply into a ``VLMResult``.

        When ``allow_upload`` is false no image is passed on; the model only
        sees the category list and the OCR text.
        """
        category_line = f"可选分类:{', '.join(categories)}"
        ocr_block = f"OCR 文本(可能不完整或为空):\n{ocr_text or '(无)'}"
        user_prompt = category_line + "\n\n" + ocr_block

        attached_image = image_path if allow_upload else None
        reply = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_SYSTEM_PROMPT,
            user_text=user_prompt,
            image_path=attached_image,
            allow_upload=allow_upload,
            timeout=self.timeout,
            json_mode=True,
        )
        return _to_vlm_result(safe_parse_json(reply))
|
||||
|
||||
|
||||
def _to_vlm_result(data: dict[str, Any]) -> VLMResult:
    """Convert parsed model JSON into a ``VLMResult``, tolerating bad fields.

    - A non-dict ``data`` is treated as an empty object.
    - ``todos`` entries may be dicts with a non-empty ``title`` or plain
      strings; anything else is dropped. A missing or null ``kind`` falls
      back to "待办".
    - ``tags`` must be a list; empty items are dropped and at most 8 are kept.
    - ``title`` is cut to 128 characters, todo titles to 512.
    """
    if not isinstance(data, dict):
        data = {}

    todos_raw = data.get("todos") or []
    todos: list[dict[str, str]] = []
    if isinstance(todos_raw, list):
        for item in todos_raw:
            if isinstance(item, dict) and item.get("title"):
                todos.append(
                    {
                        "title": str(item.get("title", ""))[:512],
                        # `or ""` first: str(None) would yield the text "None".
                        "kind": str(item.get("kind") or "") or "待办",
                        "note": str(item.get("note", "") or ""),
                    }
                )
            elif isinstance(item, str):
                todos.append({"title": item, "kind": "待办", "note": ""})

    tags_raw = data.get("tags") or []
    if not isinstance(tags_raw, list):
        tags_raw = []

    return VLMResult(
        title=str(data.get("title", "") or "")[:128],
        summary=str(data.get("summary", "") or ""),
        category=str(data.get("category") or "") or None,
        tags=[str(t) for t in tags_raw if t][:8],
        todos=todos,
        suggestion=str(data.get("suggestion", "") or ""),
        raw=data,
    )
|
||||
@@ -0,0 +1,76 @@
|
||||
"""通用 Schema:状态、统计、设置。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class StatsResp(BaseModel):
    # Response schema for statistics: overall counters plus two breakdowns.
    total: int
    pending_ocr: int
    pending_ai: int
    failed: int
    # Breakdown rows are free-form dicts; their keys are not fixed by this schema.
    by_category: list[dict[str, Any]]
    by_date: list[dict[str, Any]]
|
||||
|
||||
|
||||
class WatchFolderIn(BaseModel):
    # Input schema for a watched folder.
    path: str
    enabled: bool = True
    recursive: bool = True
    is_sensitive: bool = False
|
||||
|
||||
|
||||
class WatchFolderOut(WatchFolderIn):
    # Output schema: the input fields plus the record id.
    id: int

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
|
||||
|
||||
|
||||
class CategoryIn(BaseModel):
    # Input schema for a category; only the name is required.
    name: str
    color: Optional[str] = None
    prompt_hint: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
    """Configuration of an OCR/VLM provider.

    type: openai_compat / tesseract / anthropic / none
    base_url, api_key, model and so on are all optional; which ones matter
    depends on the provider type.
    """

    type: str
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    model: Optional[str] = None
    # Provider-specific options as a free-form dict.
    extra: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ProviderConfigOut(ProviderConfig):
    """Read-side variant: api_key is always empty; only api_key_mask exposes a hint."""

    api_key_mask: Optional[str] = None
|
||||
|
||||
|
||||
class RecognitionModeIn(BaseModel):
    """Text recognition strategy: traditional OCR / vision AI / hybrid."""

    mode: str  # ocr | vision | hybrid
|
||||
|
||||
|
||||
class ProviderTestResult(BaseModel):
    """Result of a provider connectivity test."""

    ok: bool
    message: str
    detail: Optional[str] = None
    latency_ms: Optional[int] = None
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
    # Update schema for a todo; every field is optional and defaults to None.
    status: Optional[str] = None
    title: Optional[str] = None
    note: Optional[str] = None
|
||||
@@ -0,0 +1,79 @@
|
||||
"""分析任务队列的请求/响应模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class JobOut(BaseModel):
    """Details of a single job, including a summary of the related screenshot."""

    id: int
    screenshot_id: int
    kind: str
    status: str
    retries: int
    last_error: Optional[str] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # Screenshot summary fields; all optional.
    thumb_url: Optional[str] = None
    path: Optional[str] = None
    ai_title: Optional[str] = None
    ai_status: Optional[str] = None
    ocr_status: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class JobListResp(BaseModel):
    # Paged list of jobs.
    items: list[JobOut]
    total: int
    page: int
    size: int
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class JobRetryIn(BaseModel):
    """Optional: retry only the given job ids; when omitted, retry all failed jobs."""

    job_ids: Optional[list[int]] = None
|
||||
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""截图相关的请求/响应模型。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TagOut(BaseModel):
    # Output schema for a tag.
    id: int
    name: str
    color: Optional[str] = None

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
|
||||
|
||||
|
||||
class CategoryOut(BaseModel):
    # Output schema for a category.
    id: int
    name: str
    color: Optional[str] = None

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
|
||||
|
||||
|
||||
class ScreenshotBrief(BaseModel):
    """For the card list: kept as lean as possible."""

    id: int
    path: str
    width: int
    height: int
    captured_at: datetime
    thumb_url: Optional[str] = None
    ai_title: Optional[str] = None
    ai_status: str
    ocr_status: str
    is_favorite: bool = False
    category: Optional[CategoryOut] = None
    tags: list[TagOut] = []
|
||||
|
||||
|
||||
class TodoBrief(BaseModel):
    # Output schema for a todo item, including the id of its screenshot.
    id: int
    title: str
    note: Optional[str] = None
    kind: Optional[str] = None
    status: str
    created_at: datetime
    completed_at: Optional[datetime] = None
    screenshot_id: int

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
|
||||
|
||||
|
||||
class TodoListResp(BaseModel):
    # Paged list of todos.
    items: list[TodoBrief]
    total: int
    page: int
    size: int
|
||||
|
||||
|
||||
class ScreenshotDetail(ScreenshotBrief):
    """For the detail page: adds the OCR and AI text."""

    file_url: str
    size: int
    ocr_text: Optional[str] = None
    ai_summary: Optional[str] = None
    ai_suggestion: Optional[str] = None
    todos: list[TodoBrief] = []
|
||||
|
||||
|
||||
class ScreenshotListResp(BaseModel):
    # Paged list of screenshot cards.
    items: list[ScreenshotBrief]
    total: int
    page: int
    size: int
|
||||
|
||||
|
||||
class ScreenshotUpdate(BaseModel):
    """Fields the frontend is allowed to write."""

    # Every field defaults to None.
    category_id: Optional[int] = None
    is_favorite: Optional[bool] = None
    is_hidden: Optional[bool] = None
    tags: Optional[list[str]] = Field(default=None, description="标签名列表,自动新建")
|
||||
@@ -0,0 +1,484 @@
|
||||
"""单张截图的分析逻辑:OCR -> VLM -> 写回数据库。
|
||||
|
||||
设计要点:
|
||||
- 不在长时间网络调用期间持有 SQLite 写事务,避免 `database is locked`。
|
||||
- 把流程拆为「短事务(取配置/标记状态)」 -> 「无事务(OCR/VLM 网络调用)」
|
||||
-> 「短事务(写回结果)」。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
|
||||
from app.models.category import Category, DEFAULT_CATEGORIES
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.setting import (
|
||||
DEFAULT_RECOGNITION_MODE,
|
||||
KEY_OCR_PROVIDER,
|
||||
KEY_RECOGNITION_MODE,
|
||||
KEY_VLM_PROVIDER,
|
||||
)
|
||||
from app.models.tag import Tag
|
||||
from app.models.todo import Todo, TodoStatus
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.providers import (
|
||||
RECOGNITION_MODES,
|
||||
build_ocr_provider,
|
||||
build_vision_ocr,
|
||||
build_vlm_provider,
|
||||
)
|
||||
from app.providers.base import VLMResult
|
||||
from app.schemas.common import ProviderConfig
|
||||
from app.services.exif_utils import is_exif_location_tag
|
||||
from app.services.settings_store import get_provider_config, get_setting
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class _PreparedContext:
    """Plain data exported from the short transaction; independent of the ORM session."""

    path: Path
    ocr_cfg: Optional[ProviderConfig]
    vlm_cfg: Optional[ProviderConfig]
    recognition_mode: str
    category_names: list[str]
    # False when the file lies inside a watch folder marked as sensitive.
    allow_upload: bool
    # Whether the image file was accessible when the context was prepared.
    exists: bool
|
||||
|
||||
|
||||
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Public entry point: run the full analysis of one screenshot by id.

    Scheduled by the worker. The function manages several short transactions
    itself; text recognition and the VLM call run outside any transaction.

    Raises:
        Exception: The VLM error, re-raised after the result was persisted,
            so the worker can decide about a retry.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return  # the screenshot row was deleted in the meantime
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    ocr_provider = _safe_build(_make_ocr, traditional_cfg, "OCR")
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")

    # ---- text recognition stage (outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)

    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    vlm_error: Optional[Exception] = None
    if vlm_provider is None:
        ai_status = ProcessStatus.SKIPPED.value
    else:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            vlm_error = exc
        else:
            ai_status = ProcessStatus.DONE.value

    # ---- write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )

    if vlm_error is not None:
        raise vlm_error  # let the worker decide whether to retry
|
||||
|
||||
|
||||
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run only OCR / vision text extraction, leaving the AI analysis untouched.

    Meant for screenshots with ai_status=done but ocr_status=failed.

    Raises:
        RuntimeError: If the file is missing or recognition fails again; the
            worker retries according to max_retries.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")

    def _make_ocr(cfg):
        return build_ocr_provider(cfg, allow_upload=ctx.allow_upload)

    traditional_cfg = ctx.ocr_cfg if _use_traditional_ocr(ctx) else None
    provider = _safe_build(_make_ocr, traditional_cfg, "OCR")

    text, status = await _extract_text(screenshot_id, ctx, provider)
    _persist_ocr_only(screenshot_id, text, status)

    if status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
|
||||
|
||||
|
||||
# ---------------- 短事务工具 ---------------- #
|
||||
|
||||
|
||||
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
    """Short transaction: collect everything the analysis needs as plain data.

    Reads the provider configs, the recognition mode, the category names and
    the sensitive-folder verdict.

    Returns:
        The prepared context, or ``None`` when the screenshot row is gone.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return None

        image_path = path_from_storage(shot.path)
        file_ok = is_accessible_file(image_path)

        ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
        vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)

        # An unknown stored mode falls back to the default.
        stored_mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
        mode = stored_mode if stored_mode in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE

        names = [category.name for category in _ensure_default_categories(session)]
        upload_ok = not _is_sensitive(session, image_path)

        return _PreparedContext(
            path=image_path,
            ocr_cfg=ocr_cfg,
            vlm_cfg=vlm_cfg,
            recognition_mode=mode,
            category_names=names,
            allow_upload=upload_ok,
            exists=file_ok,
        )
|
||||
|
||||
|
||||
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
    """Tell whether the OCR-section config should drive a traditional OCR run.

    True only in "ocr" / "hybrid" mode with a config whose type is not one of
    the disabled markers and not "vision" (vision is handled separately).
    """
    if ctx.recognition_mode not in ("ocr", "hybrid"):
        return False
    cfg = ctx.ocr_cfg
    return cfg is not None and cfg.type not in ("", "none", "disabled", "vision")
|
||||
|
||||
|
||||
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: OCR / vision AI / hybrid.

    Returns:
        ``(text, status)`` where status is a ``ProcessStatus`` value. A failed
        vision pass does not downgrade a status that is already "done".
    """
    text = ""
    status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Traditional OCR (Tesseract / Paddle / HTTP)
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            text = await ocr_provider.recognize(ctx.path)
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            status = ProcessStatus.FAILED.value
        else:
            status = ProcessStatus.DONE.value

    # 2) Vision-model text extraction
    if mode == "vision":
        need_vision = True
    elif mode == "hybrid":
        # Hybrid falls back to vision only when OCR produced no text.
        need_vision = not text.strip()
    else:
        # In "ocr" mode the user may have picked the vision model as the OCR engine.
        need_vision = bool(mode == "ocr" and ctx.ocr_cfg and ctx.ocr_cfg.type == "vision")

    if not need_vision:
        return text, status

    vision = _safe_build(
        lambda c: build_vision_ocr(c, allow_upload=ctx.allow_upload),
        _pick_vision_config(ctx),
        "VisionOCR",
    )
    if vision is None:
        return text, status

    _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
    try:
        text = await vision.recognize(ctx.path)
        status = ProcessStatus.DONE.value
    except Exception as exc:  # noqa: BLE001
        logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
        if status != ProcessStatus.DONE.value:
            status = ProcessStatus.FAILED.value

    return text, status
|
||||
|
||||
|
||||
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
    """Choose the config for vision text extraction.

    The OCR-section config wins when its type is "vision"; in every other
    case the VLM-section config is used (it may be ``None``).
    """
    if ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
        return ctx.ocr_cfg
    # The former mode check returned the same value as the fall-through.
    return ctx.vlm_cfg
|
||||
|
||||
|
||||
def _mark_status(
    screenshot_id: int,
    ocr: Optional[str] = None,
    ai: Optional[str] = None,
) -> None:
    """Short transaction: update the screenshot's status so the UI can show progress.

    Only the statuses that are passed (not ``None``) are written; when neither
    is given no session is opened. A missing screenshot is ignored.
    """
    changes = {
        attr: value
        for attr, value in (("ocr_status", ocr), ("ai_status", ai))
        if value is not None
    }
    if not changes:
        return
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        for attr, value in changes.items():
            setattr(shot, attr, value)
|
||||
|
||||
|
||||
def _persist_missing(screenshot_id: int) -> None:
    """Short transaction: record that the screenshot file is gone.

    Both statuses become "failed" and the AI summary is set to a marker text.
    """
    failed = ProcessStatus.FAILED.value
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = failed
        shot.ai_status = failed
        _get_or_create_meta(session, screenshot_id).ai_summary = "(文件丢失)"
|
||||
|
||||
|
||||
def _persist_result(
    screenshot_id: int,
    ocr_text: str,
    ocr_status: str,
    ai_status: str,
    vlm_result: Optional[VLMResult],
) -> None:
    """Short transaction: write OCR/VLM results back (meta, tags, category, todos).

    Statuses and OCR text are always written; the AI fields, category, tags
    and todos only when a VLM result is available.
    """
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return

        shot.ocr_status = ocr_status
        shot.ai_status = ai_status

        meta = _get_or_create_meta(session, screenshot_id)
        meta.ocr_text = ocr_text or None

        if vlm_result is None:
            return

        # Empty strings are stored as NULL.
        meta.ai_title = vlm_result.title or None
        meta.ai_summary = vlm_result.summary or None
        meta.ai_suggestion = vlm_result.suggestion or None
        meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)

        known_categories = list(session.scalars(select(Category)).all())
        matched = _resolve_category(session, vlm_result.category, known_categories)
        if matched is not None:
            shot.category_id = matched.id

        _sync_tags(session, shot, vlm_result.tags)
        _sync_todos(session, shot, vlm_result.todos)
|
||||
|
||||
|
||||
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
    """Short transaction: store only the OCR text and status; AI fields stay as they are."""
    with session_scope() as session:
        shot = session.get(Screenshot, screenshot_id)
        if shot is None:
            return
        shot.ocr_status = ocr_status
        # An empty text is stored as NULL.
        _get_or_create_meta(session, screenshot_id).ocr_text = ocr_text or None
|
||||
|
||||
|
||||
def enqueue_ocr_jobs(*, limit: int = 500) -> int:
    """Create OCR re-run jobs for screenshots whose AI succeeded but OCR failed.

    Screenshots that already have a pending/running OCR job are excluded
    inside the query, so ``limit`` counts eligible screenshots only and a
    backlog of busy rows cannot starve the rest.

    Args:
        limit: Maximum number of jobs to create in this call.

    Returns:
        The number of jobs created.
    """
    from app.models.job import Job, JobKind, JobStatus

    active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
    # Correlated EXISTS: "this screenshot already has an active OCR job".
    has_active_ocr_job = (
        select(Job.screenshot_id)
        .where(
            Job.screenshot_id == Screenshot.id,
            Job.kind == JobKind.OCR.value,
            Job.status.in_(active_status),
        )
        .exists()
    )

    created = 0
    with session_scope() as session:
        shot_ids = session.scalars(
            select(Screenshot.id)
            .where(
                Screenshot.ocr_status == ProcessStatus.FAILED.value,
                Screenshot.ai_status == ProcessStatus.DONE.value,
                ~has_active_ocr_job,
            )
            .order_by(Screenshot.id.asc())
            .limit(limit)
        ).all()
        for shot_id in shot_ids:
            session.add(
                Job(
                    screenshot_id=shot_id,
                    kind=JobKind.OCR.value,
                    status=JobStatus.PENDING.value,
                )
            )
            created += 1
    return created
|
||||
|
||||
|
||||
# ---------------- 内部辅助 ---------------- #
|
||||
|
||||
|
||||
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
    """Load the provider config stored under ``key``.

    Returns ``None`` when nothing is stored or the stored data cannot be
    turned into a ``ProviderConfig`` (the failure is logged).
    """
    raw = get_provider_config(session, key)
    if raw:
        try:
            return ProviderConfig(**raw)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
    return None
|
||||
|
||||
|
||||
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
    """Call ``builder(cfg)`` without letting a construction error escape.

    Returns ``None`` when ``cfg`` is ``None`` or the builder raises; the error
    is logged with ``label`` as the provider kind.
    """
    if cfg is not None:
        try:
            return builder(cfg)
        except Exception as exc:  # noqa: BLE001
            logger.warning("%s Provider 构造失败: %s", label, exc)
    return None
|
||||
|
||||
|
||||
def _is_sensitive(session, image_path: Path) -> bool:
    """Return True when the file lies inside a watch folder flagged as sensitive."""
    flagged_dirs = session.scalars(
        select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
    ).all()
    target = str(image_path)
    return any(path_is_under(folder, target) for folder in flagged_dirs)
|
||||
|
||||
|
||||
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
    """Return the screenshot's meta row, creating and flushing it when absent."""
    found = session.get(ScreenshotMeta, screenshot_id)
    if found is not None:
        return found
    created = ScreenshotMeta(screenshot_id=screenshot_id)
    session.add(created)
    session.flush()
    return created
|
||||
|
||||
|
||||
def ensure_default_categories() -> None:
    """Public helper: seed the default categories at startup."""
    with session_scope() as db:
        _ensure_default_categories(db)
|
||||
|
||||
|
||||
def _ensure_default_categories(session) -> list[Category]:
    """Insert the default categories when none exist and return the current list."""
    current = list(session.scalars(select(Category)).all())
    if current:
        return current
    for spec in DEFAULT_CATEGORIES:
        session.add(Category(**spec))
    session.flush()
    return list(session.scalars(select(Category)).all())
|
||||
|
||||
|
||||
def _resolve_category(
    session,
    name: str | None,
    categories: list[Category],
) -> Optional[Category]:
    """Map a model-provided category name onto a ``Category``.

    Resolution order: exact name match, then substring match in either
    direction, then creation of a new category (name cut to 64 characters),
    which is also appended to ``categories``.

    Returns:
        The category, or ``None`` when ``name`` is empty or whitespace only.
    """
    if not name:
        return None
    normalized = name.strip()
    if not normalized:
        # "" is a substring of every name and would match the first category.
        return None

    # Exact matches win over fuzzy ones regardless of list order.
    for c in categories:
        if c.name == normalized:
            return c
    for c in categories:
        if c.name and (c.name in normalized or normalized in c.name):
            return c

    new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
    session.add(new_cat)
    session.flush()
    categories.append(new_cat)
    return new_cat
|
||||
|
||||
|
||||
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
    """Replace the screenshot's tags with the AI-provided names.

    Tags recognised by ``is_exif_location_tag`` are carried over untouched.
    Incoming names are stripped, truncated to 64 characters and de-duplicated;
    missing ``Tag`` rows are created and flushed on the fly.
    """
    preserved = [
        tag for tag in (screenshot.tags or []) if is_exif_location_tag(tag.name)
    ]
    used_names: set[str] = {tag.name for tag in preserved}
    merged: list[Tag] = list(preserved)

    for candidate in tag_names:
        label = (candidate or "").strip()[:64]
        if label and label not in used_names:
            used_names.add(label)
            record = session.scalar(select(Tag).where(Tag.name == label))
            if record is None:
                record = Tag(name=label)
                session.add(record)
                session.flush()
            merged.append(record)

    screenshot.tags = merged
|
||||
|
||||
|
||||
def _sync_todos(
    session,
    screenshot: Screenshot,
    todos: list[dict[str, str]],
) -> None:
    """Overwrite the screenshot's open todos with the AI output.

    Existing todos whose status is DONE or DROPPED are kept; every other
    existing todo is deleted. Each incoming item with a non-blank ``title``
    becomes a new PENDING todo (title ≤ 512 chars, note ≤ 2000 chars or
    ``None``, kind ≤ 32 chars defaulting to "待办").
    """
    kept_states = (TodoStatus.DONE.value, TodoStatus.DROPPED.value)
    for old in list(screenshot.todos):
        if old.status not in kept_states:
            session.delete(old)
    session.flush()

    for entry in todos:
        title = (entry.get("title") or "").strip()
        if not title:
            continue
        note = (entry.get("note") or "")[:2000] or None
        kind = (entry.get("kind") or "待办")[:32]
        session.add(
            Todo(
                screenshot_id=screenshot.id,
                title=title[:512],
                note=note,
                kind=kind,
                status=TodoStatus.PENDING.value,
            )
        )
    session.flush()
|
||||
@@ -0,0 +1,107 @@
|
||||
"""从图片 EXIF 提取拍摄时间与 GPS 地点标签。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import ExifTags, Image
|
||||
|
||||
from app.core.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Prefix carried by EXIF-derived location tags; tags with this prefix are
# meant to be preserved (not overwritten by AI output) on re-analysis.
EXIF_TAG_PREFIX = "地点:"
|
||||
|
||||
|
||||
def _ratio_to_float(value) -> float:
|
||||
"""EXIF 有理数 → float。"""
|
||||
if isinstance(value, tuple) and len(value) == 2:
|
||||
num, den = value
|
||||
return float(num) / float(den) if den else 0.0
|
||||
if isinstance(value, Fraction):
|
||||
return float(value)
|
||||
return float(value)
|
||||
|
||||
|
||||
def _dms_to_decimal(dms: tuple, ref: str) -> Optional[float]:
    """Convert degrees/minutes/seconds to signed decimal degrees.

    Args:
        dms: Three-item sequence ``(degrees, minutes, seconds)``; each item is
            converted with ``_ratio_to_float``.
        ref: Hemisphere reference; ``"S"`` and ``"W"`` negate the result.

    Returns:
        The value rounded to 6 decimals, or ``None`` when ``dms`` cannot be
        unpacked/converted or the result is not a finite number.
    """
    import math

    try:
        deg, minutes, seconds = dms
        decimal = (
            _ratio_to_float(deg)
            + _ratio_to_float(minutes) / 60
            + _ratio_to_float(seconds) / 3600
        )
    except (TypeError, ValueError, ZeroDivisionError):
        return None

    # A rational object with a zero denominator may convert to NaN instead of
    # raising (NOTE(review): Pillow's IFDRational appears to do this — confirm);
    # without this guard the caller would build a "nan,nan" location tag.
    if not math.isfinite(decimal):
        return None

    if ref in ("S", "W"):
        decimal = -decimal
    return round(decimal, 6)
|
||||
|
||||
|
||||
def extract_image_metadata(path: Path) -> tuple[Optional[datetime], list[str]]:
    """Read EXIF data and return ``(capture time, location tag names)``.

    Capture time is taken from the first parseable tag among DateTimeOriginal
    (36867), DateTimeDigitized (36868) and DateTime (306). Location tags are
    built from GPS coordinates and from ImageDescription / XPComment text that
    ``_looks_like_place`` accepts; each is prefixed with ``EXIF_TAG_PREFIX``.

    Any error while reading is logged at debug level and whatever was
    collected so far is returned; an image without EXIF yields ``(None, [])``.
    """
    captured: Optional[datetime] = None
    location_tags: list[str] = []

    try:
        with Image.open(path) as img:
            exif = img.getexif()
            if not exif:
                return None, []

            can_read_ifd = hasattr(exif, "get_ifd")

            # DateTimeOriginal / DateTimeDigitized are stored in the Exif
            # sub-IFD, not in the base IFD that ``getexif()`` exposes, so look
            # there first and fall back to the base IFD (where DateTime lives).
            exif_ifd = exif.get_ifd(ExifTags.IFD.Exif) if can_read_ifd else {}
            for key in (36867, 36868, 306):  # DateTimeOriginal / DateTimeDigitized / DateTime
                raw = exif_ifd.get(key) or exif.get(key)
                if raw:
                    captured = _parse_exif_datetime(str(raw))
                    if captured:
                        break

            # GPS -> location tag
            gps_ifd = exif.get_ifd(ExifTags.IFD.GPSInfo) if can_read_ifd else None
            if gps_ifd:
                lat = _dms_to_decimal(
                    gps_ifd.get(2),
                    gps_ifd.get(1, "N"),
                )
                lon = _dms_to_decimal(
                    gps_ifd.get(4),
                    gps_ifd.get(3, "E"),
                )
                if lat is not None and lon is not None:
                    location_tags.append(f"{EXIF_TAG_PREFIX}{lat},{lon}")

            # Some devices write a readable place name into descriptive tags.
            for key, val in exif.items():
                tag_name = ExifTags.TAGS.get(key, "")
                if tag_name not in ("ImageDescription", "XPComment") or not val:
                    continue
                if isinstance(val, bytes):
                    # XP* tags are UTF-16LE byte strings; str(bytes) would
                    # produce a "b'...'" repr instead of the actual text.
                    encoding = "utf-16-le" if tag_name == "XPComment" else "utf-8"
                    val = val.decode(encoding, errors="ignore")
                text = str(val).replace("\x00", "").strip()[:64]
                if text and _looks_like_place(text):
                    location_tags.append(f"{EXIF_TAG_PREFIX}{text}")

    except Exception as exc:  # noqa: BLE001
        # Best effort: metadata is optional, so unreadable EXIF never fails ingest.
        logger.debug("读取 EXIF 失败 %s: %s", path.name, exc)

    return captured, location_tags
|
||||
|
||||
|
||||
def _parse_exif_datetime(raw: str) -> Optional[datetime]:
|
||||
"""解析 EXIF 时间字符串。"""
|
||||
for fmt in ("%Y:%m:%d %H:%M:%S", "%Y-%m-%d %H:%M:%S"):
|
||||
try:
|
||||
return datetime.strptime(raw.strip(), fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _looks_like_place(text: str) -> bool:
|
||||
"""粗判字符串是否像地名(含中文或常见地址关键词)。"""
|
||||
keywords = ("市", "省", "区", "县", "镇", "村", "路", "街", "国", "GPS")
|
||||
return any(k in text for k in keywords) or any("\u4e00" <= c <= "\u9fff" for c in text)
|
||||
|
||||
|
||||
def is_exif_location_tag(name: str) -> bool:
    """Return whether a tag name carries the ``EXIF_TAG_PREFIX`` location prefix."""
    has_location_prefix = name.startswith(EXIF_TAG_PREFIX)
    return has_location_prefix
|
||||
@@ -0,0 +1,147 @@
|
||||
"""将磁盘上的截图文件入库 + 排队分析。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from PIL import Image
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.path_utils import (
|
||||
is_accessible_dir,
|
||||
is_accessible_file,
|
||||
path_from_storage,
|
||||
path_to_storage,
|
||||
)
|
||||
from app.core.logger import get_logger
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.services.exif_utils import extract_image_metadata
|
||||
from app.services.thumbnail import file_hash, generate_thumbnail, is_supported
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def ingest_path(session: Session, path: Path) -> Optional[Screenshot]:
    """Ingest a single file and queue it for analysis.

    Returns:
        ``None`` when the path is not an accessible, supported image file or
        cannot be read. When a row with the same content hash already exists,
        that row is returned (its path is updated if the file moved) and no
        new job is queued. Otherwise the newly created ``Screenshot`` is
        returned, with a pending FULL job added to the session.
    """
    if not is_accessible_file(path) or not path.is_file():
        return None
    if not is_supported(path):
        return None

    stored_path = path_to_storage(path)

    try:
        digest = file_hash(path)
    except OSError as exc:
        logger.warning("无法读取文件 %s: %s", path, exc)
        return None

    existing = session.scalar(select(Screenshot).where(Screenshot.file_hash == digest))
    if existing:
        # Same content renamed/moved: just refresh the stored path.
        if existing.path != stored_path:
            existing.path = stored_path
            session.flush()
        return existing

    # The file may disappear or become unreadable after hashing; treat that
    # like any other unreadable file instead of raising, so a directory scan
    # is not aborted by one bad entry.
    try:
        stat = path.stat()
    except OSError as exc:
        logger.warning("无法读取文件 %s: %s", path, exc)
        return None

    try:
        with Image.open(path) as img:
            width, height = img.size
    except Exception as exc:  # noqa: BLE001
        logger.warning("无法读取图片尺寸 %s: %s", path, exc)
        width, height = 0, 0

    # Default to the file mtime; an EXIF capture time wins when available.
    captured_at = datetime.fromtimestamp(stat.st_mtime)
    exif_time, location_tags = extract_image_metadata(path)
    if exif_time is not None:
        captured_at = exif_time

    try:
        thumb = generate_thumbnail(path)
        thumb_path = thumb.as_posix()
    except Exception as exc:  # noqa: BLE001
        logger.warning("生成缩略图失败 %s: %s", path, exc)
        thumb_path = None

    shot = Screenshot(
        path=stored_path,
        file_hash=digest,
        width=width,
        height=height,
        size=stat.st_size,
        captured_at=captured_at,
        thumb_path=thumb_path,
        ocr_status=ProcessStatus.PENDING.value,
        ai_status=ProcessStatus.PENDING.value,
    )
    session.add(shot)
    session.flush()

    if location_tags:
        _attach_location_tags(session, shot, location_tags)

    job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
    session.add(job)
    logger.info("入库 #%d %s", shot.id, path.name)
    return shot
|
||||
|
||||
|
||||
def _attach_location_tags(session: Session, shot: Screenshot, tag_names: list[str]) -> None:
    """Attach EXIF location tags to a freshly ingested screenshot.

    Names are stripped, truncated to 64 characters and de-duplicated; missing
    ``Tag`` rows are created and flushed. The screenshot's tag collection is
    replaced with the resulting list.
    """
    tag_objs: list[Tag] = []
    # Two EXIF fields can yield the same text (and truncation can collapse
    # distinct names), so de-duplicate; otherwise the same Tag would be put in
    # the collection twice.
    seen: set[str] = set()
    for raw in tag_names:
        name = (raw or "").strip()[:64]
        if not name or name in seen:
            continue
        seen.add(name)
        tag = session.scalar(select(Tag).where(Tag.name == name))
        if tag is None:
            tag = Tag(name=name)
            session.add(tag)
            session.flush()
        tag_objs.append(tag)
    shot.tags = tag_objs
|
||||
|
||||
|
||||
def ingest_directory(
    session: Session,
    root: Path | str,
    recursive: bool = True,
) -> tuple[int, int]:
    """Walk a directory and ingest every supported file.

    A string ``root`` is converted with ``path_from_storage``. Returns
    ``(added, skipped)``; an inaccessible root yields ``(0, 0)``. A file counts
    as added when ``ingest_path`` returns a row and no row with the same stored
    path existed beforehand. The session is committed every 50 counted files
    and once more at the end.
    """
    root_p = path_from_storage(str(root)) if isinstance(root, str) else root
    if not is_accessible_dir(root_p):
        return 0, 0

    entries: Iterable[Path] = root_p.rglob("*") if recursive else root_p.iterdir()

    added = skipped = 0
    for candidate in entries:
        if not candidate.is_file() or not is_supported(candidate):
            continue
        stored = path_to_storage(candidate)
        known_id = session.scalar(
            select(Screenshot.id).where(Screenshot.path == stored)
        )
        if ingest_path(session, candidate) is None:
            skipped += 1
            continue
        if known_id is None:
            added += 1
        else:
            skipped += 1
        # Commit in batches to avoid one huge transaction.
        if (added + skipped) % 50 == 0:
            session.commit()
    session.commit()
    return added, skipped
|
||||
@@ -0,0 +1,207 @@
|
||||
"""Provider 连通性测试:OCR / 视觉 AI。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.providers import build_ocr_provider
|
||||
from app.schemas.common import ProviderConfig
|
||||
|
||||
|
||||
class ProviderTestError(Exception):
    """Raised when a provider test fails; its message is user-readable."""
|
||||
|
||||
|
||||
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
    """Test connectivity of an OCR or VLM provider.

    Returns a dict with ``ok``, ``message``, ``detail`` and ``latency_ms``.
    A ``ProviderTestError`` is reported with its message and no detail; any
    other exception is reported with a generic message and ``repr(exc)`` as
    detail. This function does not raise for test failures.
    """
    started = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - started) * 1000)

    try:
        if cfg.type in ("", "none", "disabled"):
            raise ProviderTestError("当前类型为「不使用」,无需测试")

        if key == KEY_OCR:
            message, detail = await _test_ocr(cfg)
        elif key == KEY_VLM:
            message, detail = await _test_vlm(cfg)
        else:
            raise ProviderTestError(f"未知配置键: {key}")
    except ProviderTestError as exc:
        latency = _elapsed_ms()
        return {"ok": False, "message": str(exc), "detail": None, "latency_ms": latency}
    except Exception as exc:  # noqa: BLE001
        latency = _elapsed_ms()
        return {
            "ok": False,
            "message": f"测试失败: {exc}",
            "detail": repr(exc),
            "latency_ms": latency,
        }

    latency = _elapsed_ms()
    return {"ok": True, "message": message, "detail": detail, "latency_ms": latency}
|
||||
|
||||
|
||||
KEY_OCR = "ocr_provider"
|
||||
KEY_VLM = "vlm_provider"
|
||||
|
||||
# 1x1 白图,用于 HTTP OCR / 视觉测试
|
||||
_TINY_PNG_B64 = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
|
||||
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Dispatch the OCR connectivity test according to ``cfg.type``.

    Raises:
        ProviderTestError: when the OCR type is not supported.
    """
    kind = cfg.type
    if kind == "vision":
        return await _test_openai_compat(cfg, label="视觉 OCR")

    handlers = {
        "tesseract": _test_tesseract,
        "paddleocr": _test_paddle,
        "http": _test_http_ocr,
    }
    handler = handlers.get(kind)
    if handler is None:
        raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
    return await handler(cfg)
|
||||
|
||||
|
||||
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Run the vision-model connectivity test for the OpenAI-compatible types.

    Raises:
        ProviderTestError: when ``cfg.type`` is not one of the supported types.
    """
    supported_types = (
        "openai_compat",
        "openai",
        "ollama",
        "glm",
        "minimax",
        "moonshot",
        "vision",
    )
    if cfg.type not in supported_types:
        raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
    return await _test_openai_compat(cfg, label="视觉 AI")
|
||||
|
||||
|
||||
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that a Tesseract provider can be built and report its version.

    The version lookup runs in a worker thread. When ``cfg.extra["cmd"]`` is
    set it is assigned to pytesseract's module-level ``tesseract_cmd``.

    Raises:
        ProviderTestError: when the provider cannot be constructed.
    """
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 Tesseract Provider")

    def _probe_version() -> str:
        import pytesseract

        custom_cmd = cfg.extra.get("cmd")
        if custom_cmd:
            pytesseract.pytesseract.tesseract_cmd = custom_cmd
        return str(pytesseract.get_tesseract_version())

    version = await asyncio.to_thread(_probe_version)
    lang = cfg.extra.get("lang", "chi_sim+eng")
    return f"Tesseract 可用,版本 {version}", f"lang={lang}"
|
||||
|
||||
|
||||
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Check that PaddleOCR is importable and that a provider can be built.

    Raises:
        ProviderTestError: when the ``paddleocr`` module is missing or the
            provider cannot be constructed.
    """

    def _probe_import() -> str:
        try:
            import paddleocr  # noqa: F401
        except ImportError as exc:
            raise ProviderTestError(
                "未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
            ) from exc
        return "PaddleOCR 模块已安装"

    detail = await asyncio.to_thread(_probe_import)
    if build_ocr_provider(cfg, allow_upload=True) is None:
        raise ProviderTestError("无法构造 PaddleOCR Provider")
    return "PaddleOCR 可用", detail
|
||||
|
||||
|
||||
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
    """Send the tiny sample PNG through the HTTP OCR provider.

    The sample is written to a temporary file, passed to
    ``provider.recognize`` and removed afterwards (removal errors ignored).

    Raises:
        ProviderTestError: when no base URL is configured or the provider
            cannot be constructed.
    """
    if not cfg.base_url:
        raise ProviderTestError("请填写 OCR API URL")
    provider = build_ocr_provider(cfg, allow_upload=True)
    if provider is None:
        raise ProviderTestError("无法构造 HTTP OCR Provider")

    from pathlib import Path
    import tempfile

    sample_bytes = base64.b64decode(_TINY_PNG_B64)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(sample_bytes)
        sample_path = Path(handle.name)

    try:
        recognized = await provider.recognize(sample_path)
    finally:
        try:
            sample_path.unlink(missing_ok=True)
        except OSError:
            pass

    preview = (recognized or "").strip()[:80] or "(空响应,但接口可达)"
    return "HTTP OCR 接口可达", f"响应预览: {preview}"
|
||||
|
||||
|
||||
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
    """Probe an OpenAI-compatible endpoint.

    Step 1 requests ``{base_url}/models``; a 200 response with a parseable
    body counts as reachable (with a hint when the configured model is not
    listed). Otherwise step 2 sends a minimal chat completion.

    Args:
        cfg: Provider configuration (``base_url``, ``api_key``, ``model`` and
            ``extra["timeout"]`` are read).
        label: Human-readable provider name used in the returned message.

    Returns:
        ``(message, detail)`` describing the successful probe.

    Raises:
        ProviderTestError: when the chat probe returns an error status, the
            server cannot be reached, or the chat response is malformed.
    """
    base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
    api_key = cfg.api_key or ""
    model = cfg.model or "gpt-4o-mini"
    timeout = float(cfg.extra.get("timeout", 30))

    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # 1) Try /models (Ollama and OpenAI-compatible servers).
    models_url = f"{base_url}/models"
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.get(models_url, headers=headers)
            if resp.status_code == 200:
                data = resp.json()
                ids = _extract_model_ids(data)
                if model and ids and model not in ids:
                    return (
                        f"{label} 服务可达",
                        f"已连接 /models,但未找到模型「{model}」。可用: {', '.join(ids[:8])}",
                    )
                return f"{label} 服务可达", f"已连接 /models,目标模型: {model}"
        except (httpx.HTTPError, ValueError):
            # ValueError covers a 200 response whose body is not JSON (e.g. an
            # HTML page from a proxy); fall through to the chat probe instead
            # of failing the whole test.
            pass

        # 2) Minimal chat liveness probe.
        chat_url = f"{base_url}/chat/completions"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "请只回复 OK"}],
            "max_tokens": 16,
            "temperature": 0,
        }
        try:
            resp = await client.post(chat_url, json=payload, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            content = data["choices"][0]["message"]["content"]
        except httpx.HTTPStatusError as exc:
            body = exc.response.text[:200]
            raise ProviderTestError(
                f"API 返回 {exc.response.status_code}: {body}"
            ) from exc
        except httpx.HTTPError as exc:
            raise ProviderTestError(f"无法连接 {base_url}: {exc}") from exc
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # The server answered, but not with an OpenAI-style completion.
            raise ProviderTestError(f"响应格式不符合 OpenAI 兼容接口: {exc!r}") from exc
        return f"{label} 对话成功", f"模型 {model} 回复: {str(content).strip()[:60]}"
|
||||
|
||||
|
||||
def _extract_model_ids(data: Any) -> list[str]:
|
||||
"""从 /models 响应中提取 model id 列表。"""
|
||||
if not isinstance(data, dict):
|
||||
return []
|
||||
items = data.get("data") or data.get("models") or []
|
||||
ids: list[str] = []
|
||||
if isinstance(items, list):
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
mid = item.get("id") or item.get("name") or item.get("model")
|
||||
if mid:
|
||||
ids.append(str(mid))
|
||||
elif isinstance(item, str):
|
||||
ids.append(item)
|
||||
return ids
|
||||
|
||||
|
||||
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
    """Return a copy of ``cfg`` whose empty ``api_key`` is filled from ``existing``.

    When ``cfg`` already has an API key, or ``existing`` is not a dict, the
    copy carries the same values as ``cfg``.
    """
    payload = cfg.model_dump()
    key_missing = not payload.get("api_key")
    if key_missing and isinstance(existing, dict):
        payload["api_key"] = existing.get("api_key", "")
    return ProviderConfig(**payload)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""截图列表搜索:FTS + 子串模糊(兼容中文标签/标题)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.meta import ScreenshotMeta
|
||||
from app.models.screenshot import Screenshot
|
||||
from app.models.tag import Tag
|
||||
|
||||
|
||||
def fts_query_string(raw: str) -> str:
    """Turn user input into an FTS5 query of quoted prefix terms.

    The input is split on whitespace, double quotes are removed from each
    term, and every non-empty term becomes ``"term"*``. Input without any
    term is returned unchanged.
    """
    tokens = raw.replace("\n", " ").split()
    if not tokens:
        return raw
    unquoted = (token.replace('"', "").strip() for token in tokens)
    return " ".join(f'"{token}"*' for token in unquoted if token)
|
||||
|
||||
|
||||
def collect_search_ids(session: Session, q: str, *, limit: int = 5000) -> set[int]:
    """Collect screenshot ids matching ``q`` via FTS5 plus substring search.

    Three sources are unioned, each capped at ``limit`` rows:
      1. the ``screenshots_fts`` full-text index (errors are ignored);
      2. case-insensitive substring match on OCR/AI text columns;
      3. case-insensitive substring match on tag names.

    Args:
        session: Database session.
        q: Raw user query; blank input yields an empty set.
        limit: Per-source row cap.
    """
    q = q.strip()
    if not q:
        return set()

    ids: set[int] = set()
    # Escape LIKE metacharacters so a query such as "50%" or "a_b" is matched
    # literally instead of acting as a wildcard pattern.
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    like = f"%{escaped}%"

    # 1) FTS5 full-text index
    try:
        fts_sql = text(
            "SELECT rowid FROM screenshots_fts WHERE screenshots_fts MATCH :q LIMIT :lim"
        )
        rows = session.execute(fts_sql, {"q": fts_query_string(q), "lim": limit}).fetchall()
        ids.update(int(row[0]) for row in rows)
    except Exception:  # noqa: BLE001
        # Best effort: when the FTS query fails, the substring searches below
        # still run.
        pass

    # 2) Substring match on OCR/AI text (so "三花" also finds "三花猫").
    meta_ids = session.scalars(
        select(ScreenshotMeta.screenshot_id).where(
            or_(
                ScreenshotMeta.ocr_text.ilike(like, escape="\\"),
                ScreenshotMeta.ai_title.ilike(like, escape="\\"),
                ScreenshotMeta.ai_summary.ilike(like, escape="\\"),
                ScreenshotMeta.ai_suggestion.ilike(like, escape="\\"),
            )
        ).limit(limit)
    ).all()
    ids.update(int(i) for i in meta_ids)

    # 3) Substring match on tag names.
    tag_ids = session.scalars(
        select(Screenshot.id)
        .join(Screenshot.tags)
        .where(Tag.name.ilike(like, escape="\\"))
        .limit(limit)
    ).all()
    ids.update(int(i) for i in tag_ids)

    return ids
|
||||
@@ -0,0 +1,50 @@
|
||||
"""读取/写入键值设置。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.setting import Setting
|
||||
|
||||
|
||||
def get_setting(session: Session, key: str, default: Any = None) -> Any:
    """Read a setting and JSON-decode it.

    Returns ``default`` when the key is missing or its stored value is not
    valid JSON.
    """
    row = session.get(Setting, key)
    if row is not None:
        try:
            return json.loads(row.value_json)
        except json.JSONDecodeError:
            pass
    return default
|
||||
|
||||
|
||||
def set_setting(session: Session, key: str, value: Any) -> None:
    """JSON-encode ``value`` and upsert it under ``key``, then flush."""
    row = session.get(Setting, key)
    encoded = json.dumps(value, ensure_ascii=False)
    if row is not None:
        row.value_json = encoded
    else:
        session.add(Setting(key=key, value_json=encoded))
    session.flush()
|
||||
|
||||
|
||||
def all_settings(session: Session) -> dict[str, Any]:
    """Return every setting keyed by name (for front-end debugging / export).

    Values are JSON-decoded; a value that is not valid JSON is returned as the
    raw stored string.
    """

    def _decode(raw_value):
        try:
            return json.loads(raw_value)
        except json.JSONDecodeError:
            return raw_value

    return {row.key: _decode(row.value_json) for row in session.query(Setting).all()}
|
||||
|
||||
|
||||
def get_provider_config(session: Session, key: str) -> Optional[dict[str, Any]]:
    """Read an OCR/VLM provider config; return it only when it is a dict."""
    stored = get_setting(session, key, None)
    return stored if isinstance(stored, dict) else None
|
||||
@@ -0,0 +1,48 @@
|
||||
"""生成并缓存缩略图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.path_utils import path_to_storage
|
||||
|
||||
|
||||
# Lower-case file suffixes accepted as images.
SUPPORTED_EXTS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".gif",
    ".bmp",
    ".tif",
    ".tiff",
}
|
||||
|
||||
|
||||
def is_supported(path: Path) -> bool:
    """Return whether the path's suffix (case-insensitive) is a supported image type."""
    extension = path.suffix.lower()
    return extension in SUPPORTED_EXTS
|
||||
|
||||
|
||||
def generate_thumbnail(image_path: Path, max_side: int | None = None) -> Path:
    """Create a WebP thumbnail in the cache directory and return its path.

    The cache key hashes the stored path, the source mtime and ``max_side``,
    so a modified source file produces a new thumbnail. An existing cache
    file is returned without touching the image.

    Args:
        image_path: Source image.
        max_side: Maximum width/height; falls back to ``settings.thumb_size``.

    Raises:
        OSError: when the source cannot be stat'ed or the thumbnail cannot be
            written. Errors raised while decoding or encoding the image
            propagate unchanged.
    """
    max_side = max_side or settings.thumb_size
    stat = image_path.stat()
    key = hashlib.md5(
        f"{path_to_storage(image_path)}|{stat.st_mtime_ns}|{max_side}".encode("utf-8")
    ).hexdigest()
    out = settings.thumb_dir / f"{key}.webp"
    if out.exists():
        return out

    out.parent.mkdir(parents=True, exist_ok=True)
    try:
        with Image.open(image_path) as img:
            img = img.convert("RGB")
            img.thumbnail((max_side, max_side), Image.LANCZOS)
            img.save(out, format="WEBP", quality=80)
    except Exception:
        # A failed save can leave a partial file behind; remove it, otherwise
        # the ``out.exists()`` shortcut above would serve it on every later call.
        try:
            out.unlink(missing_ok=True)
        except OSError:
            pass
        raise
    return out
|
||||
|
||||
|
||||
def file_hash(image_path: Path, chunk: int = 1024 * 1024) -> str:
    """Return the SHA-256 hex digest of a file, read in ``chunk``-byte pieces.

    The digest serves as the de-duplication key.
    """
    digest = hashlib.sha256()
    with open(image_path, "rb") as stream:
        for block in iter(lambda: stream.read(chunk), b""):
            digest.update(block)
    return digest.hexdigest()
|
||||
@@ -0,0 +1,122 @@
|
||||
"""watchdog 监听被关注的目录。
|
||||
|
||||
中文路径与 OneDrive 同步盘下 NTFS 事件偶发不稳,因此默认使用 PollingObserver。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select
|
||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||
from watchdog.observers.polling import PollingObserver
|
||||
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.core.path_utils import is_accessible_dir, path_from_storage
|
||||
from app.models.watch_folder import WatchFolder
|
||||
from app.services.ingest import ingest_path
|
||||
from app.services.thumbnail import is_supported
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class _ScreenshotEventHandler(FileSystemEventHandler):
    """File-system event handler: new file -> ingest -> wake the analyzer."""

    def __init__(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """Store the target event loop and the notify callable.

        ``notify`` is called with no arguments and its result is submitted to
        ``loop`` as a coroutine.
        """
        self._loop = loop
        self._notify = notify

    def on_created(self, event: FileSystemEvent) -> None:
        """Handle a newly created file (directories are ignored)."""
        if event.is_directory:
            return
        self._handle(Path(event.src_path))

    def on_moved(self, event: FileSystemEvent) -> None:
        """Handle a moved file using its destination path when available."""
        if event.is_directory:
            return
        self._handle(Path(getattr(event, "dest_path", event.src_path)))

    def _handle(self, path: Path) -> None:
        """Ingest one supported file and notify the worker after the commit."""
        if not is_supported(path):
            return
        try:
            # Wait for the write to finish (screenshot tools often create an
            # empty file first and fill it afterwards).
            try:
                self._wait_file_ready(path)
            except FileNotFoundError:
                return
            with session_scope() as session:
                ingested = ingest_path(session, path) is not None
            # Notify only after the session scope has exited, so the new job
            # is visible to the worker when it wakes up.
            if ingested:
                asyncio.run_coroutine_threadsafe(self._notify(), self._loop)
        except Exception:  # noqa: BLE001
            # This runs on the observer's thread; log and swallow so a single
            # bad file cannot terminate file watching.
            logger.exception("处理截图文件失败 %s", path)

    @staticmethod
    def _wait_file_ready(path: Path, retries: int = 10, interval: float = 0.3) -> None:
        """Poll until the file size is non-zero and stable between two checks.

        Returns silently after ``retries`` attempts even if the size never
        stabilised.

        Raises:
            FileNotFoundError: when the file does not exist during polling.
        """
        import time

        last = -1
        for _ in range(retries):
            if not path.exists():
                raise FileNotFoundError(path)
            size = path.stat().st_size
            if size > 0 and size == last:
                return
            last = size
            time.sleep(interval)
|
||||
|
||||
|
||||
class WatcherService:
    """Manage the lifecycle of the observer that watches all enabled folders."""

    def __init__(self) -> None:
        self._observer: PollingObserver | None = None
        self._lock = threading.Lock()
        self._loop: asyncio.AbstractEventLoop | None = None
        self._notify_cb = None

    def start(self, loop: asyncio.AbstractEventLoop, notify) -> None:  # noqa: ANN001
        """(Re)start watching the enabled folders stored in the database.

        Any running observer is stopped first. Folders that are not accessible
        are skipped with a warning.
        """
        with self._lock:
            self._loop, self._notify_cb = loop, notify
            self._stop_locked()

            observer = PollingObserver(timeout=2.0)
            self._observer = observer
            handler = _ScreenshotEventHandler(loop, notify)

            with session_scope() as session:
                enabled = session.scalars(
                    select(WatchFolder).where(WatchFolder.enabled.is_(True))
                ).all()
                targets = [(folder.path, folder.recursive) for folder in enabled]

            for stored, recursive in targets:
                directory = path_from_storage(stored)
                if is_accessible_dir(directory):
                    logger.info("开始监听 %s (recursive=%s)", stored, recursive)
                    observer.schedule(handler, str(directory), recursive=recursive)
                else:
                    logger.warning("监听目录不存在或不可访问,跳过: %s", stored)

            observer.start()

    def reload(self) -> None:
        """Restart with the previously supplied loop/callback; no-op before the first start."""
        if self._loop is None or self._notify_cb is None:
            return
        self.start(self._loop, self._notify_cb)

    def stop(self) -> None:
        """Stop the current observer, if any."""
        with self._lock:
            self._stop_locked()

    def _stop_locked(self) -> None:
        """Stop and join the observer (3 s timeout); the caller holds ``_lock``."""
        observer = self._observer
        if observer is None:
            return
        try:
            observer.stop()
            observer.join(timeout=3)
        finally:
            self._observer = None


watcher_service = WatcherService()
|
||||
@@ -0,0 +1,237 @@
|
||||
"""异步任务调度器:从 jobs 表取任务并并发执行。
|
||||
|
||||
事务规则:
|
||||
- 调度循环只用短事务 claim 任务、汇总状态。
|
||||
- 真正的 OCR/VLM 调用由 `analyze_screenshot_by_id` 自己管理短事务,
|
||||
绝不在 worker 这一层包裹长事务。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import case, func, or_, select
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.db import session_scope
|
||||
from app.core.logger import get_logger
|
||||
from app.models.job import Job, JobKind, JobStatus
|
||||
from app.models.screenshot import ProcessStatus, Screenshot
|
||||
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnalyzeWorker:
    """Single-instance background worker that drains pending rows of the jobs table."""

    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None
        self._event = asyncio.Event()
        self._stop = False
        self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
        self._inflight: int = 0
        self._lock = asyncio.Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to running job tasks: the event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._job_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Start the main loop on the running event loop."""
        # Jobs left RUNNING by a previous, interrupted run go back to PENDING.
        with session_scope() as session:
            running = session.scalars(
                select(Job).where(Job.status == JobStatus.RUNNING.value)
            ).all()
            for job in running:
                job.status = JobStatus.PENDING.value
        self._stop = False
        self._loop = asyncio.get_running_loop()
        self._task = asyncio.create_task(self._run(), name="analyze-worker")
        self.notify()

    async def stop(self) -> None:
        """Request shutdown and cancel the main loop task."""
        self._stop = True
        self._event.set()
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    def notify(self) -> None:
        """Wake the worker; must be called from the event-loop thread."""
        self._event.set()

    def notify_threadsafe(self) -> None:
        """Wake the worker from any thread.

        ``asyncio.Event.set()`` is not thread-safe, so the call is scheduled
        onto the loop with ``call_soon_threadsafe``. Does nothing when the
        worker has no loop yet or the loop is closed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return
        loop.call_soon_threadsafe(self._event.set)

    async def status(self) -> dict[str, int]:
        """Return queue statistics for the API.

        The result has one count per ``JobStatus`` value plus ``inflight``
        (jobs currently being processed) and ``has_more`` (1 when at least one
        PENDING job exists, else 0).
        """
        with session_scope() as session:
            pending = session.scalar(
                select(Job.id)
                .where(Job.status == JobStatus.PENDING.value)
                .limit(1)
            )
            rows = session.execute(
                select(Job.status, func.count())
                .group_by(Job.status)
            ).all()
        counts = {st.value: 0 for st in JobStatus}
        for status, cnt in rows:
            counts[status] = int(cnt)
        counts["inflight"] = self._inflight
        counts["has_more"] = 1 if pending is not None else 0
        return counts

    def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
        """Put RUNNING jobs back to PENDING and return how many were reset.

        Args:
            minutes: Only jobs started more than this many minutes ago are
                reset (minimum 1); ignored when ``reset_all`` is true.
            reset_all: Reset every RUNNING job regardless of age.
        """
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.RUNNING.value)
            if not reset_all:
                cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
                q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
            stale = session.scalars(q).all()
            for job in stale:
                job.status = JobStatus.PENDING.value
                job.started_at = None
            count = len(stale)
        if count:
            logger.info("复位 %d 条 RUNNING 任务为 PENDING", count)
            # This synchronous method may run outside the event-loop thread,
            # so use the thread-safe wake-up.
            self.notify_threadsafe()
        return count

    def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
        """Re-queue FAILED jobs (optionally limited to ``job_ids``); return the count."""
        with session_scope() as session:
            q = select(Job).where(Job.status == JobStatus.FAILED.value)
            if job_ids:
                q = q.where(Job.id.in_(job_ids))
            failed = session.scalars(q).all()
            for job in failed:
                job.status = JobStatus.PENDING.value
                job.retries = 0
                job.last_error = None
                job.started_at = None
                job.finished_at = None
            count = len(failed)
        if count:
            logger.info("重试 %d 条 failed 任务", count)
            # May be called from a non-loop thread; see notify_threadsafe().
            self.notify_threadsafe()
        return count

    async def _run(self) -> None:
        """Main loop: claim jobs one by one and run them under the semaphore."""
        idle_rounds = 0
        while not self._stop:
            job = self._claim_one()
            if job is None:
                idle_rounds += 1
                # While idle, periodically reset zombie RUNNING rows so the DB
                # does not show "running" when nothing is in flight.
                if idle_rounds >= 3 and self._inflight == 0:
                    idle_rounds = 0
                    if self.reset_stale_running(minutes=5):
                        continue
                self._event.clear()
                try:
                    await asyncio.wait_for(self._event.wait(), timeout=10)
                except asyncio.TimeoutError:
                    pass
                continue
            idle_rounds = 0
            await self._semaphore.acquire()
            async with self._lock:
                self._inflight += 1
            task = asyncio.create_task(
                self._process(job["id"], job["screenshot_id"], job["kind"])
            )
            # Hold a reference until the task completes.
            self._job_tasks.add(task)
            task.add_done_callback(self._job_tasks.discard)

    def _claim_one(self) -> Optional[dict]:
        """Short transaction: claim one PENDING job.

        Jobs are ordered FULL first, then VLM, then everything else, and by
        ascending id within a kind. Jobs whose retry count has reached
        ``settings.max_retries`` are not claimed. The claimed job is marked
        RUNNING with ``started_at`` set to the current UTC time.

        Returns:
            ``{"id", "screenshot_id", "kind"}`` or ``None`` when nothing is
            claimable.
        """
        with session_scope() as session:
            job = session.scalar(
                select(Job)
                .where(
                    Job.status == JobStatus.PENDING.value,
                    or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
                )
                .order_by(
                    case(
                        (Job.kind == JobKind.FULL.value, 0),
                        (Job.kind == JobKind.VLM.value, 1),
                        else_=2,
                    ),
                    Job.id.asc(),
                )
                .limit(1)
            )
            if job is None:
                return None
            job.status = JobStatus.RUNNING.value
            job.started_at = datetime.utcnow()
            session.flush()
            return {
                "id": job.id,
                "screenshot_id": job.screenshot_id,
                "kind": job.kind,
            }

    async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
        """Run one job and record its outcome.

        OCR jobs call ``analyze_ocr_only_by_id``; every other kind calls
        ``analyze_screenshot_by_id``. The semaphore slot and in-flight counter
        are always released, and the main loop is woken afterwards.
        """
        try:
            try:
                if kind == JobKind.OCR.value:
                    await analyze_ocr_only_by_id(screenshot_id)
                else:
                    await analyze_screenshot_by_id(screenshot_id)
                self._finish(job_id, success=True, kind=kind)
            except Exception as exc:  # noqa: BLE001
                logger.exception("分析失败 #%d (%s): %s", screenshot_id, kind, exc)
                self._finish(job_id, success=False, error=str(exc), kind=kind)
        finally:
            self._semaphore.release()
            async with self._lock:
                self._inflight -= 1
            self.notify()

    def _finish(
        self,
        job_id: int,
        success: bool,
        error: Optional[str] = None,
        kind: str = JobKind.FULL.value,
    ) -> None:
        """Short transaction: record the job's final (or retry) state.

        On failure the retry counter is incremented; once it reaches
        ``settings.max_retries`` the job becomes FAILED (and, for non-OCR
        jobs, the screenshot's ``ai_status`` too), otherwise it returns to
        PENDING. ``last_error`` keeps at most 1000 characters.
        """
        with session_scope() as session:
            job = session.get(Job, job_id)
            if job is None:
                return
            if success:
                job.status = JobStatus.DONE.value
                job.last_error = None
            else:
                job.retries = (job.retries or 0) + 1
                if job.retries >= settings.max_retries:
                    job.status = JobStatus.FAILED.value
                    # A failed OCR back-fill must not affect ai_status.
                    if kind != JobKind.OCR.value:
                        shot = session.get(Screenshot, job.screenshot_id)
                        if shot is not None:
                            shot.ai_status = ProcessStatus.FAILED.value
                else:
                    job.status = JobStatus.PENDING.value
                job.last_error = (error or "")[:1000]
            job.finished_at = datetime.utcnow()


worker = AnalyzeWorker()
|
||||
Reference in New Issue
Block a user