Initial commit: snapAna, an intelligent screenshot organizer

Includes the FastAPI backend, the React frontend, and the full feature set: queue, OCR, tags, to-dos, and more.

Co-authored-by: Cursor <cursoragent@cursor.com>
wjl
2026-05-27 15:45:50 +08:00
commit 5c028d7952
76 changed files with 10467 additions and 0 deletions
+24
@@ -0,0 +1,24 @@
# Python
__pycache__/
*.py[cod]
*.egg-info/
.venv/
.env
# Node
node_modules/
dist/
.vite/
# Project data
backend/.data/
*.db
*.db-journal
*.db-wal
*.db-shm
# IDE / OS
.DS_Store
.idea/
.vscode/
Thumbs.db
+119
@@ -0,0 +1,119 @@
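# snapAna Progress Log
> Maintenance rule: after finishing or advancing a feature, append or update an entry here; when picking the project up again, read this file before reading the code.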
## v0.1 (MVP) · 2026-05-22
### Backend (FastAPI + SQLite)
- [x] Project skeleton: `backend/app/{core,models,schemas,providers,services,api}`, zero-migration startup (`init_db` creates the tables and installs the FTS5 triggers automatically).
- [x] Data model:
  - `screenshots` (unique `file_hash`; indexes on `captured_at` / `ai_status`)
  - `screenshot_meta` (1:1) + `screenshots_fts` (FTS5 over ocr_text/ai_title/ai_summary/ai_suggestion)
  - `tags` + `screenshot_tags` (many-to-many)
  - `categories` (8 default categories are seeded on first startup)
  - `todos` (to-dos extracted by AI; statuses: pending/doing/done/dropped)
  - `jobs` (three kinds: OCR/VLM/FULL; includes a retry count and the last error)
  - `watch_folders` (with an `is_sensitive` flag; sensitive folders are never uploaded)
  - `settings` (key-value table that stores the Provider JSON)
- [x] File watching: `watcher.py` uses PollingObserver; both new files and renames trigger it, and it waits for the file size to stabilize after a write before ingesting.
- [x] Ingestion: sha256 dedup (identical content only updates the path); a webp thumbnail is generated at the same time.
- [x] Manual bulk import: `POST /api/watch/import`, which scans the directory in a background task.
- [x] Provider abstraction: `OCRProvider` / `VLMProvider`; built-in `TesseractOCR` and `OpenAICompatVLM` (compatible with Ollama / GLM / MiniMax / OpenAI / OpenRouter).
- [x] Analysis pipeline: `analyze_screenshot` -> OCR -> VLM -> write meta -> parse category/tags/todos; base64 upload is blocked for sensitive folders.
- [x] Async worker: `AnalyzeWorker` singleton; on startup it resets jobs left in `running`, uses asyncio.Semaphore to cap concurrency, and re-queues failures up to `MAX_RETRIES`.
- [x] REST endpoints:
  - `GET /api/screenshots` (pagination + filters + FTS5 keyword)
  - `GET /api/screenshots/{id}`, `PATCH`, `DELETE`, `POST .../reanalyze`, `/file`, `/thumb`
  - `GET /api/screenshots/random`, `/stats`
  - `GET/PATCH/DELETE /api/todos`, `GET /api/todos/summary`
  - `GET/POST/PATCH/DELETE /api/watch/folders`, `POST /api/watch/import`, `GET /api/watch/queue`
  - `GET /api/settings` (api_key masked), `GET/PUT /api/settings/providers/{key}`
  - `GET/POST/PATCH/DELETE /api/settings/categories`, `GET /api/settings/tags`
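
The async worker entry above (a semaphore to cap concurrency, bounded retries) can be illustrated with a small stand-alone sketch. This is not the project's `AnalyzeWorker`: the names `run_with_retries` and `flaky_analyze` are hypothetical, the concurrency value is assumed, and the retry is done inline here for brevity, whereas the entry above describes re-queuing failed jobs.

```python
import asyncio

MAX_RETRIES = 3          # mirrors the MAX_RETRIES setting named above
ANALYZE_CONCURRENCY = 2  # assumed value, for illustration only


async def run_with_retries(job_id, analyze, sem, max_retries=MAX_RETRIES):
    """Run one job under the semaphore; try at most max_retries times."""
    last_error = None
    for attempt in range(1, max_retries + 1):
        async with sem:  # at most ANALYZE_CONCURRENCY jobs run at once
            try:
                await analyze(job_id)
                return {"job_id": job_id, "status": "done",
                        "attempts": attempt, "last_error": None}
            except Exception as exc:  # remember the error, then try again
                last_error = str(exc)
    return {"job_id": job_id, "status": "failed",
            "attempts": max_retries, "last_error": last_error}


async def demo():
    """Job 1 succeeds, job 2 succeeds on its second try, job 3 always fails."""
    sem = asyncio.Semaphore(ANALYZE_CONCURRENCY)
    calls = {}

    async def flaky_analyze(job_id):  # hypothetical stand-in for OCR + VLM
        calls[job_id] = calls.get(job_id, 0) + 1
        if job_id == 2 and calls[job_id] < 2:
            raise RuntimeError("temporary provider error")
        if job_id == 3:
            raise RuntimeError("always fails")

    return await asyncio.gather(
        *(run_with_retries(i, flaky_analyze, sem) for i in (1, 2, 3))
    )
```

Calling `asyncio.run(demo())` returns one result dict per job, including the attempt count and the last error, which corresponds to the retry count and last error that the `jobs` table is described as storing.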
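### Frontend (React + Vite + Tailwind + react-window)
- [x] Routes: Home / Library / Random / To-dos / Settings
- [x] Home: four status cards + a daily review of 6 random screenshots + category distribution
- [x] Library: left filter sidebar (keyword, category, date range, sort, status, favorites, popular tags) + react-window virtualized card grid + pagination
- [x] Detail drawer: original image + AI title/summary/suggestion + OCR text (copyable) + category switcher + tag editor + to-do list + metadata + favorite/reanalyze/remove
- [x] Random page: single large image + metadata sidebar + "Show another"
- [x] To-do page: four status tabs + card list + complete/shelve/reset + jump back to the source image
- [x] Settings page:
  - Add, edit, and remove watch folders + "Rescan"
  - OCR / VLM Provider configuration (OCR supports Tesseract; VLM goes through the OpenAI-compatible API)
  - Category management (color + prompt hint)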
### Engineering
- [x] `start-dev.ps1` one-command startup (creates the venv, runs npm install, and starts the backend and frontend concurrently)
- [x] `.gitignore`, root `README.md`, `backend/README.md`, `.env.example`
- [x] CORS, SQLite WAL / busy_timeout, and similar tuning
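
As an illustration of the SQLite tuning mentioned in the last item, the sketch below opens a connection with WAL journaling and a busy timeout using only the standard library. The helper name `open_tuned` and the 5000 ms value are assumptions; the project itself configures SQLite through SQLAlchemy, and that code is not shown in this commit view.

```python
import os
import sqlite3
import tempfile


def open_tuned(db_path, busy_timeout_ms=5000):
    """Open a SQLite connection with WAL journaling and a busy timeout."""
    conn = sqlite3.connect(db_path)
    # WAL lets readers keep reading while a single writer is active.
    conn.execute("PRAGMA journal_mode=WAL")
    # Wait up to busy_timeout_ms for a lock instead of failing immediately.
    conn.execute(f"PRAGMA busy_timeout={int(busy_timeout_ms)}")
    return conn


def demo():
    """Return the effective (journal_mode, busy_timeout) of a scratch database."""
    with tempfile.TemporaryDirectory() as tmp:
        conn = open_tuned(os.path.join(tmp, "demo.db"))
        try:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
        finally:
            conn.close()
        return mode, timeout
```

WAL mode needs a file-backed database, which is why the demo uses a temporary directory rather than `:memory:`.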
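## v0.1.6 · To-dos / search / tags / sorting
- [x] **To-dos**: pagination + keyword search over title/note
- [x] **Library search**: FTS prefix + LIKE substring ("三花" matches "三花猫") + fuzzy tag-name match
- [x] **Tags page** `/tags`: browse all tags, search, sort, click to jump to the filtered library
- [x] **Library sorting**: 8 options, including import time, title, and file size
- [x] **EXIF location**: GPS and capture time are read at ingestion and a `地点:` (location) tag is added automatically; it is preserved on reanalysis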
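
The "FTS prefix + LIKE substring" combination in the library search entry can be sketched as two small query-building helpers. Both function names are hypothetical and this is not the project's `collect_search_ids`, whose body is not part of this view; the sketch only shows how a keyword might be turned into an FTS5 prefix expression and into an escaped LIKE pattern.

```python
def build_fts_prefix_query(keyword: str) -> str:
    """Turn 'foo bar' into '"foo"* "bar"*' for use in an FTS5 MATCH clause."""
    terms = [t for t in keyword.split() if t]
    # Double embedded quotes so each term stays a single FTS5 string literal.
    return " ".join('"' + t.replace('"', '""') + '"*' for t in terms)


def build_like_pattern(keyword: str) -> str:
    """Escape LIKE wildcards and wrap the keyword for substring matching.

    The pattern is meant to be used with an explicit ESCAPE '\\' clause.
    """
    escaped = (
        keyword.strip()
        .replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )
    return f"%{escaped}%"
```

The LIKE branch is what lets a substring match succeed when the tokenizer does not split the text at the position where the keyword ends; the prefix branch keeps the common case fast.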
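## v0.1.5 · OCR re-run queue
- [x] **Dedicated OCR job** `JobKind.OCR`: re-runs only OCR / vision text recognition and leaves the AI results untouched
- [x] **Bulk enqueue**: `POST /api/watch/jobs/enqueue-ocr-failed` (AI succeeded + OCR failed)
- [x] **Single re-run**: `POST /api/screenshots/{id}/reocr`
- [x] Worker: FULL jobs take priority over OCR jobs; an OCR failure does not pollute `ai_status`
- [x] Queue page: "OCR to re-run" counter + re-run button; detail page: "Re-run OCR"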
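
The "FULL before OCR" ordering could be expressed as a simple priority key, as in the sketch below. The `PendingJob` class, the `KIND_PRIORITY` table, and the `"full"` kind string are assumptions for illustration (only the `"ocr"` string appears in the code shown in this commit); the real worker picks jobs from the database.

```python
from dataclasses import dataclass

# Lower number = picked first. The "full" value is an assumed kind string.
KIND_PRIORITY = {"full": 0, "ocr": 1}


@dataclass
class PendingJob:
    id: int
    kind: str


def next_job(pending):
    """Pick the next job: FULL jobs first, then OCR, oldest id first within a kind."""
    if not pending:
        return None
    return min(pending, key=lambda j: (KIND_PRIORITY.get(j.kind, 99), j.id))
```

With a queue holding an OCR job with id 1 and FULL jobs with ids 2 and 3, `next_job` returns the FULL job with id 2.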
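## v0.1.4 · Queue detail page
- [x] **Queue API**: `GET /api/watch/jobs` lists jobs with pagination (including `last_error`, thumbnail, and path)
- [x] **Queue actions**: `POST /api/watch/jobs/retry-failed` retries failed jobs; `POST /api/watch/jobs/reset-stale` resets zombie RUNNING jobs
- [x] **Worker tuning**: `status()` now counts with `GROUP BY`; timed-out RUNNING jobs are reset automatically while the worker is idle
- [x] **Frontend queue page**: "Queue" entry in the sidebar; Failed / Running / Queued / Done tabs + pagination; shows the full error message; the queue card on the home page is clickable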
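
The `GROUP BY` counting mentioned under worker tuning replaces one `COUNT` query per status with a single query. A minimal sketch with the standard-library `sqlite3` module follows; the two-column `jobs` table is a simplified assumption, not the project's actual schema.

```python
import sqlite3


def count_jobs_by_status(conn):
    """Return {status: count} with one GROUP BY query instead of one query per status."""
    rows = conn.execute(
        "SELECT status, COUNT(*) FROM jobs GROUP BY status"
    ).fetchall()
    return {status: count for status, count in rows}


def demo():
    """Build a tiny in-memory jobs table and count it."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(
            "CREATE TABLE jobs (id INTEGER PRIMARY KEY, status TEXT NOT NULL)"
        )
        conn.executemany(
            "INSERT INTO jobs (status) VALUES (?)",
            [("pending",), ("pending",), ("running",), ("failed",)],
        )
        return count_jobs_by_status(conn)
    finally:
        conn.close()
```

Statuses with no rows are simply absent from the result, so a caller that needs every status would fill in zeros itself.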
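## v0.1.3 · UNC network paths + Provider tests
- [x] **UNC / network paths**: `path_utils.py` normalizes `\\NAS\share\...`; ingestion, watching, and original-image reads no longer use `as_posix()`, which broke UNC paths
- [x] **Watch folders**: `POST /api/watch/validate-path` tests path reachability; "Test path" button on the Settings page
- [x] **Provider tests**: `POST /api/settings/providers/{key}/test`; OCR (Tesseract/Paddle/HTTP/vision) and the vision AI both support a connectivity probe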
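
One way to keep the `\\server\share` prefix intact is to normalize through `pathlib.PureWindowsPath`, which parses UNC drives on any operating system. The helper name `normalize_unc` is hypothetical and this is only a sketch of the idea; the real `path_utils.py` is not shown in this view.

```python
from pathlib import PureWindowsPath


def normalize_unc(raw: str) -> str:
    """Normalize a Windows or UNC path string to backslash form.

    PureWindowsPath treats //server/share and \\\\server\\share as the same
    UNC drive, so forward-slash input comes back with backslashes and the
    leading double backslash preserved.
    """
    return str(PureWindowsPath(raw.strip()))
```

For example, `normalize_unc("//NAS/share/shots/a.png")` yields the backslash form `\\NAS\share\shots\a.png`, and a path that is already in backslash form is returned unchanged.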
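## v0.1.2 · Multiple OCR engines + recognition modes
- [x] **More OCR engines**: Tesseract, PaddleOCR, HTTP API, vision-model text recognition (OpenAI-compatible)
- [x] **Text-recognition mode**: the Settings page offers "Traditional OCR / Vision AI / Hybrid"
  - `ocr`: text recognition by the OCR engine only
  - `vision`: text recognition by the vision LLM (uses the VLM configuration)
  - `hybrid`: OCR first; on failure, falls back to vision text recognition automatically, then hands off to AI analysis
- [x] New Providers: `ocr_vision.py`, `ocr_http.py`, `ocr_paddle.py`, `openai_vision_client.py`
- [x] API: `GET/PUT /api/settings/recognition-mode`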
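
The three recognition modes can be summarized as a small dispatch function. This is a sketch under assumptions: `recognize_text` and its callable parameters are hypothetical, and treating an empty OCR result as a failure in `hybrid` mode is a guess, since the entry above only says "on failure".

```python
def recognize_text(image, mode, ocr, vision):
    """Return (text, engine_used) for mode in {'ocr', 'vision', 'hybrid'}.

    `ocr` and `vision` are callables that take the image and return text.
    """
    if mode == "ocr":
        return ocr(image), "ocr"
    if mode == "vision":
        return vision(image), "vision"
    if mode == "hybrid":
        try:
            text = ocr(image)
            if text and text.strip():
                return text, "ocr"
        except Exception:
            pass  # OCR failed: fall through to the vision model
        return vision(image), "vision"
    raise ValueError(f"unknown recognition mode: {mode!r}")
```

Returning the engine name alongside the text makes it easy to record which path produced the result.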
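## v0.1.1 · Code-review fixes
- [x] **P1**: Split the worker/analyze transaction. AI calls now run entirely outside any transaction, with three short transactions around OCR/VLM to mark progress and write results back, which eliminates the "SQLite write lock held during analysis" problem.
- [x] **P2**: On the detail page, PATCH `{ category_id: null }` now really clears the category (Pydantic `model_fields_set` distinguishes an omitted field from an explicit null), and the category's existence is validated.
- [x] **P2**: `screenshots.category_id` is upgraded to `ForeignKey(..., ondelete="SET NULL")`; `init_db()` includes a lightweight migration that cleans up dangling `category_id` values in old databases.
- [x] **P2**: Waking the worker across threads now uses `loop.call_soon_threadsafe` (new entry point `worker.notify_threadsafe()`). Synchronous FastAPI routes and BackgroundTasks both go through the thread-safe entry point.
- [x] **P3**: `GET /api/settings/providers/{key}` returns the new `ProviderConfigOut`, which includes an `api_key_mask` field; the frontend `Settings.tsx` drops its type cast.
- [x] **P3**: `init_db()` now calls `ensure_default_categories()` directly, so all default categories are visible on the first visit to the settings/filter pages.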
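
The cross-thread wake-up fix relies on `loop.call_soon_threadsafe`, because `asyncio.Event.set()` is not safe to call from another thread. The sketch below shows the pattern in isolation; the `Notifier` class is hypothetical and only loosely modeled on the `worker.notify_threadsafe()` entry point named above.

```python
import asyncio
import threading


class Notifier:
    """Wake an asyncio task from any thread."""

    def __init__(self):
        self._loop = None
        self._event = None

    def bind(self):
        """Call from inside the running event loop."""
        self._loop = asyncio.get_running_loop()
        self._event = asyncio.Event()

    def notify_threadsafe(self):
        """Safe from any thread: schedules Event.set() on the owning loop."""
        if self._loop is not None and self._event is not None:
            self._loop.call_soon_threadsafe(self._event.set)

    async def wait(self, timeout):
        """Wait until notified; raises a timeout error if nothing arrives."""
        await asyncio.wait_for(self._event.wait(), timeout)
        self._event.clear()


async def demo():
    notifier = Notifier()
    notifier.bind()
    # Simulate a sync route or background task running in another thread.
    thread = threading.Thread(target=notifier.notify_threadsafe)
    thread.start()
    await notifier.wait(timeout=2.0)
    thread.join()
    return True
```

`asyncio.run(demo())` returns `True` once the notification from the helper thread has been delivered on the event loop.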
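## Known limitations and possible follow-ups
- No semantic search / CLIP vectors yet (the plan's "extension points" section reserves a design for it)
- No perceptual dedup (pHash); the same content at different resolutions counts as two records
- VLM calls have no RPS rate limit; `ANALYZE_CONCURRENCY` is the only throttle
- On Windows, users must install Tesseract themselves; a one-click detection script could be added
- In the frontend filter, only the first tag takes effect (the backend does not support multi-tag intersection)
- Thumbnails have no LRU cleanup; long-running installs need `.data/thumbs/` cleared manually
- SQLite cannot add a foreign key via ALTER TABLE. Existing databases will no longer contain dangling `category_id` values, but the underlying constraint is still missing; if that matters, delete `.data/snapana.db` so the new table definition takes effect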
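## Walkthrough (first deployment)
1. `git clone`, or just change into the repository root
2. `.\start-dev.ps1` (the first run takes about 1-2 minutes to install dependencies)
3. Open <http://127.0.0.1:5173> in a browser
4. Settings → add a watch folder (for example `D:/Pictures/Screenshots`); check "Sensitive" to block uploads
5. Settings → configure the OCR / VLM Providers and save
6. Wait for the "Queue" counter at the top right of the home page to reach zero (or watch the `分析中 / 完成` (analyzing / done) badges in the Library)
7. Enjoy filtering, random browsing, and to-dos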
+94
@@ -0,0 +1,94 @@
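# snapAna · Intelligent Screenshot Organizer
Let AI get to know every screenshot you take. A locally run web app: it watches your screenshot folders, extracts text, recognizes content, produces a title/summary/tags/to-dos, and organizes everything by time and category into a card library that you can filter and browse at random.
## Features
- Watches one or more folders and ingests new screenshots automatically (with a polling fallback for OneDrive and other sync drives)
- Hash-based dedup; renaming or moving a file only updates its path
- Pluggable Providers: Tesseract / PaddleOCR / HTTP OCR / vision-model text recognition + OpenAI-compatible vision AI (Ollama / GLM / MiniMax / OpenAI…)
- Selectable text-recognition mode: traditional OCR, vision AI, or hybrid (falls back to vision when OCR fails)
- One pass per image yields a structured result: title + summary + category + tags + to-dos + suggestion
- SQLite + FTS5 full-text search (OCR text / AI summary / AI title)
- Category color chips, tag cloud, favorites, date range, and status filters; virtualized card grid
- Random page + a "Daily review" on the home page
- To-do list: AI extracts "to watch / to read / to do" items automatically, and each can be marked done
- Sensitive-folder blacklist: once checked, screenshots in that folder are never uploaded to a cloud VLM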
## Directory layout
```
snapAna/
├── backend/              # FastAPI + SQLite + watchdog
│   ├── app/              # application code
│   ├── run.py            # development entry point
│   └── requirements.txt
├── frontend/             # Vite + React + Tailwind
│   └── src/
├── start-dev.ps1         # one-command startup script (Windows)
└── PROGRESS.md           # progress log
```
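## Quick start
### One-command startup (Windows PowerShell)
```powershell
# From the repository root
.\start-dev.ps1 # the first run creates the venv and installs dependencies automatically
# Or reinstall dependencies explicitly
.\start-dev.ps1 -InstallDeps
```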
After startup, open:
- Frontend: <http://127.0.0.1:5173>
- Backend API: <http://127.0.0.1:8765/docs>
### Manual startup
```powershell
# Backend
cd backend
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
python run.py
```
```powershell
# Frontend (in a separate terminal)
cd frontend
npm install
npm run dev
```
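## Configuring AI Providers
Open the Settings page:
- **Text-recognition mode**: traditional OCR / vision-AI text recognition / hybrid (recommended)
- **OCR engine**: `tesseract` (local), `paddleocr` (requires `pip install paddleocr`), `http` (custom API), `vision` (vision model used purely for text recognition)
- **Vision AI model**: choose `openai_compat` and fill in:

| Provider | Base URL | Example model |
| --- | --- | --- |
| Local Ollama | `http://localhost:11434/v1` | `qwen2.5vl:7b` |
| OpenAI | `https://api.openai.com/v1` | `gpt-4o-mini` |
| Zhipu GLM | `https://open.bigmodel.cn/api/paas/v4` | `glm-4v-flash` |
| MiniMax | `https://api.minimaxi.com/v1` | `MiniMax-VL-01` |
| OpenRouter | `https://openrouter.ai/api/v1` | `qwen/qwen2.5-vl-72b-instruct` |

After saving, go back to Settings → Watch folders and add a screenshot folder; the system scans and ingests it automatically, and the worker shown on the right analyzes screenshots at the configured concurrency (default 2, adjustable via `.env`).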
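## Data and privacy
- All data lives in `backend/.data/`: `snapana.db` (SQLite) + `thumbs/` (thumbnail cache)
- Binds to `127.0.0.1` by default, so only the local machine can access it
- Screenshots from folders marked "sensitive" are never uploaded to a cloud VLM; if both Providers are local, the app stays fully offline
- Before upload to a VLM, images are automatically downscaled to a long side of 1280 pixels to save cost and latency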
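
The long-side downscaling in the last item is simple arithmetic: scale both dimensions by `max_side / long_side`. The helper below is a hypothetical sketch of that calculation (the name `fit_long_side` is not from the source, and skipping images that are already small enough is an assumption); the default of 1280 matches the `VLM_MAX_SIDE` value in `.env.example`.

```python
def fit_long_side(width, height, max_side=1280):
    """Scale (width, height) so the longer side is at most max_side.

    The aspect ratio is preserved and smaller images are left untouched.
    """
    long_side = max(width, height)
    if long_side <= max_side:
        return width, height  # assumption: never upscale
    new_width = max(1, round(width * max_side / long_side))
    new_height = max(1, round(height * max_side / long_side))
    return new_width, new_height
```

A 2560×1440 desktop screenshot becomes 1280×720, and a 1080×2400 phone screenshot becomes 576×1280.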
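## Development notes
- The frontend reverse-proxies `/api/*` to `127.0.0.1:8765` through Vite
- The queue and the watcher start inside the FastAPI lifespan and are reused automatically across hot reloads
- SQLAlchemy models and the FTS5 triggers are created by `init_db()` on first startup; no separate migration command is needed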
+15
@@ -0,0 +1,15 @@
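# Example environment variables for the snapAna backend. Copy to .env and adjust as needed.
DEBUG=false
HOST=127.0.0.1
PORT=8765
# Data directory (default: backend/.data)
# DATA_DIR=D:/snapAna-data
# Concurrency and retries
ANALYZE_CONCURRENCY=4
MAX_RETRIES=3
# Thumbnail size / VLM upload downscaling
THUMB_SIZE=320
VLM_MAX_SIDE=1280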
+44
@@ -0,0 +1,44 @@
# snapAna Backend
A FastAPI-based backend for screenshot analysis and classification.
## Installation
```bash
cd backend
python -m venv .venv
.venv\Scripts\activate # Windows
pip install -r requirements.txt
```
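Optional dependencies:
- Local OCR: install [Tesseract OCR](https://github.com/UB-Mannheim/tesseract/wiki), add it to PATH, and download the `chi_sim` Chinese language pack.
- Local VLM: install [Ollama](https://ollama.com) and pull a multimodal model such as `qwen2.5vl:7b`.
## Running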
```bash
copy .env.example .env # edit as needed
python run.py
```
Listens on `http://127.0.0.1:8765` by default. The OpenAPI docs are at `/docs`.
## Data directory
- SQLite main database: `backend/.data/snapana.db`
- Thumbnail cache: `backend/.data/thumbs/`

The location can be customized through `DATA_DIR` in `.env`.
## Provider configuration
Configure Providers on the frontend Settings page or through the `/api/settings/providers/{key}` endpoint:
- OCR: `tesseract` (local) or `none` (rely on the VLM alone to read the image)
- VLM: `openai_compat`; `base_url` looks like:
  - Local Ollama: `http://localhost:11434/v1`, model e.g. `qwen2.5vl:7b`
  - Zhipu GLM: `https://open.bigmodel.cn/api/paas/v4`, model e.g. `glm-4v-flash`
  - MiniMax: `https://api.minimaxi.com/v1`
  - OpenAI: `https://api.openai.com/v1`, model e.g. `gpt-4o-mini`
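
For configuring a Provider through the API instead of the Settings page, the sketch below builds (but does not send) a `PUT` request with the standard library. The helper name `build_provider_request` is hypothetical. The key `vlm_provider` and the field names `type`, `base_url`, `api_key`, `model`, and `extra` mirror what the settings code in this commit reads and returns; that the request schema accepts exactly these fields is an assumption.

```python
import json
import urllib.request


def build_provider_request(base, key, config):
    """Build a PUT request for /api/settings/providers/{key} without sending it."""
    url = f"{base.rstrip('/')}/api/settings/providers/{key}"
    body = json.dumps(config).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        method="PUT",
        headers={"Content-Type": "application/json"},
    )


# Example payload for a local Ollama model (values taken from the list above).
request = build_provider_request(
    "http://127.0.0.1:8765",
    "vlm_provider",
    {
        "type": "openai_compat",
        "base_url": "http://localhost:11434/v1",
        "api_key": "",
        "model": "qwen2.5vl:7b",
        "extra": {},
    },
)
```

With the backend running, passing `request` to `urllib.request.urlopen` would submit it.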
+3
@@ -0,0 +1,3 @@
"""snapAna screenshot-analysis backend application package."""
__version__ = "0.1.0"
+17
@@ -0,0 +1,17 @@
"""Shared API dependencies."""
from __future__ import annotations

from typing import Iterator

from sqlalchemy.orm import Session

from app.core.db import SessionLocal


def db_session() -> Iterator[Session]:
    """Yield one session per request and close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
+367
@@ -0,0 +1,367 @@
"""Screenshot list / detail / random / reanalyze / file streaming."""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from sqlalchemy import and_, func, or_, select, text
from sqlalchemy.orm import Session, selectinload
from app.api.deps import db_session
from app.core.path_utils import is_accessible_file, path_from_storage
from app.models.category import Category
from app.models.job import Job, JobKind, JobStatus
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.search_utils import collect_search_ids
from app.schemas.screenshot import (
CategoryOut,
ScreenshotBrief,
ScreenshotDetail,
ScreenshotListResp,
ScreenshotUpdate,
TagOut,
TodoBrief,
)
from app.services.worker import worker
router = APIRouter(prefix="/api/screenshots", tags=["screenshots"])
def _to_brief(shot: Screenshot, cat_map: dict[int, Category]) -> ScreenshotBrief:
"""Convert an ORM Screenshot into a ScreenshotBrief."""
return ScreenshotBrief(
id=shot.id,
path=shot.path,
width=shot.width,
height=shot.height,
captured_at=shot.captured_at,
thumb_url=f"/api/screenshots/{shot.id}/thumb" if shot.thumb_path else None,
ai_title=(shot.meta.ai_title if shot.meta else None),
ai_status=shot.ai_status,
ocr_status=shot.ocr_status,
is_favorite=bool(shot.is_favorite),
category=(
CategoryOut.model_validate(cat_map[shot.category_id])
if shot.category_id and shot.category_id in cat_map
else None
),
tags=[TagOut.model_validate(t) for t in (shot.tags or [])],
)
def _category_map(session: Session) -> dict[int, Category]:
return {c.id: c for c in session.scalars(select(Category)).all()}
@router.get("", response_model=ScreenshotListResp)
def list_screenshots(
session: Session = Depends(db_session),
q: Optional[str] = Query(None, description="Full-text search keyword (OCR + AI fields)"),
category_id: Optional[int] = Query(None),
tag: Optional[str] = Query(None),
date_from: Optional[datetime] = Query(None),
date_to: Optional[datetime] = Query(None),
favorite: Optional[bool] = Query(None),
status: Optional[str] = Query(None, description="Filter by ai_status"),
show_hidden: bool = Query(False),
sort: str = Query("captured_desc"),
page: int = Query(1, ge=1),
size: int = Query(40, ge=1, le=200),
) -> ScreenshotListResp:
"""Main list query: filter by time, category, tag, favorite, status, and search keyword."""
stmt = select(Screenshot).options(
selectinload(Screenshot.meta),
selectinload(Screenshot.tags),
)
filters = []
if not show_hidden:
filters.append(Screenshot.is_hidden == 0)
if category_id is not None:
filters.append(Screenshot.category_id == category_id)
if date_from is not None:
filters.append(Screenshot.captured_at >= date_from)
if date_to is not None:
filters.append(Screenshot.captured_at <= date_to)
if favorite is True:
filters.append(Screenshot.is_favorite == 1)
if status:
filters.append(Screenshot.ai_status == status)
if tag:
stmt = stmt.join(Screenshot.tags).where(Tag.name.ilike(f"%{tag}%"))
if q:
ids = collect_search_ids(session, q)
if not ids:
return ScreenshotListResp(items=[], total=0, page=page, size=size)
filters.append(Screenshot.id.in_(ids))
if filters:
stmt = stmt.where(and_(*filters))
# Sorting
stmt = _apply_sort(stmt, sort)
total = session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
rows = session.scalars(stmt.offset((page - 1) * size).limit(size)).unique().all()
cat_map = _category_map(session)
items = [_to_brief(r, cat_map) for r in rows]
return ScreenshotListResp(items=items, total=int(total), page=page, size=size)
def _apply_sort(stmt, sort: str):
"""List sorting: capture time / import time / title / file size."""
if sort == "captured_asc":
return stmt.order_by(Screenshot.captured_at.asc())
if sort == "imported_desc":
return stmt.order_by(Screenshot.imported_at.desc())
if sort == "imported_asc":
return stmt.order_by(Screenshot.imported_at.asc())
if sort == "title_asc":
return stmt.outerjoin(ScreenshotMeta).order_by(
ScreenshotMeta.ai_title.asc().nulls_last()
)
if sort == "title_desc":
return stmt.outerjoin(ScreenshotMeta).order_by(
ScreenshotMeta.ai_title.desc().nulls_last()
)
if sort == "size_desc":
return stmt.order_by(Screenshot.size.desc())
if sort == "size_asc":
return stmt.order_by(Screenshot.size.asc())
return stmt.order_by(Screenshot.captured_at.desc())
@router.get("/random", response_model=list[ScreenshotBrief])
def random_screenshots(
session: Session = Depends(db_session),
n: int = Query(1, ge=1, le=20),
category_id: Optional[int] = Query(None),
) -> list[ScreenshotBrief]:
"""Return random screenshots."""
stmt = select(Screenshot).options(
selectinload(Screenshot.meta),
selectinload(Screenshot.tags),
).where(Screenshot.is_hidden == 0)
if category_id is not None:
stmt = stmt.where(Screenshot.category_id == category_id)
stmt = stmt.order_by(func.random()).limit(n)
rows = session.scalars(stmt).unique().all()
cat_map = _category_map(session)
return [_to_brief(r, cat_map) for r in rows]
@router.get("/stats")
def stats(session: Session = Depends(db_session)) -> dict:
"""Summary statistics: total, status distribution, per category, per month."""
total = session.scalar(select(func.count(Screenshot.id))) or 0
by_status = {
st.value: session.scalar(
select(func.count(Screenshot.id)).where(Screenshot.ai_status == st.value)
)
or 0
for st in ProcessStatus
}
by_category_rows = session.execute(
select(Category.id, Category.name, Category.color, func.count(Screenshot.id))
.join(Screenshot, Screenshot.category_id == Category.id, isouter=True)
.group_by(Category.id)
.order_by(func.count(Screenshot.id).desc())
).all()
by_category = [
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
for r in by_category_rows
]
by_month_rows = session.execute(
text(
"SELECT strftime('%Y-%m', captured_at) AS m, COUNT(1) AS c "
"FROM screenshots WHERE is_hidden=0 GROUP BY m ORDER BY m DESC LIMIT 36"
)
).all()
by_month = [{"month": r[0], "count": int(r[1])} for r in by_month_rows]
return {
"total": int(total),
"by_status": by_status,
"by_category": by_category,
"by_month": by_month,
"queue": _queue_summary(session),
}
def _queue_summary(session: Session) -> dict:
    """Summarize the jobs queue as a count per status."""
    out: dict[str, int] = {}
    for st in JobStatus:
        out[st.value] = (
            session.scalar(select(func.count(Job.id)).where(Job.status == st.value)) or 0
        )
    return out
@router.get("/{screenshot_id}", response_model=ScreenshotDetail)
def get_screenshot(
screenshot_id: int,
session: Session = Depends(db_session),
) -> ScreenshotDetail:
"""Single screenshot detail."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
cat_map = _category_map(session)
brief = _to_brief(shot, cat_map)
meta = shot.meta
todos = [TodoBrief.model_validate(t) for t in shot.todos]
return ScreenshotDetail(
**brief.model_dump(),
file_url=f"/api/screenshots/{shot.id}/file",
size=shot.size,
ocr_text=(meta.ocr_text if meta else None),
ai_summary=(meta.ai_summary if meta else None),
ai_suggestion=(meta.ai_suggestion if meta else None),
todos=todos,
)
@router.patch("/{screenshot_id}", response_model=ScreenshotDetail)
def update_screenshot(
screenshot_id: int,
payload: ScreenshotUpdate,
session: Session = Depends(db_session),
) -> ScreenshotDetail:
"""Frontend edits: category, favorite, hidden, tags."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
# Use model_fields_set to tell an omitted field from an explicit null,
# so that a frontend PATCH {"category_id": null} really clears the category.
fields = payload.model_fields_set
if "category_id" in fields:
if payload.category_id is not None:
cat = session.get(Category, payload.category_id)
if cat is None:
raise HTTPException(400, "category not found")
shot.category_id = payload.category_id
if "is_favorite" in fields and payload.is_favorite is not None:
shot.is_favorite = 1 if payload.is_favorite else 0
if "is_hidden" in fields and payload.is_hidden is not None:
shot.is_hidden = 1 if payload.is_hidden else 0
if "tags" in fields and payload.tags is not None:
tag_objs = []
for name in payload.tags:
name = (name or "").strip()[:64]
if not name:
continue
tag = session.scalar(select(Tag).where(Tag.name == name))
if tag is None:
tag = Tag(name=name)
session.add(tag)
session.flush()
tag_objs.append(tag)
shot.tags = tag_objs
session.commit()
session.refresh(shot)
return get_screenshot(screenshot_id, session)
@router.post("/{screenshot_id}/reanalyze")
def reanalyze(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""Enqueue the screenshot for a full reanalysis."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
shot.ai_status = ProcessStatus.PENDING.value
shot.ocr_status = ProcessStatus.PENDING.value
job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
session.add(job)
session.commit()
# Sync routes run in the thread pool, so the event loop must be woken thread-safely.
worker.notify_threadsafe()
return {"ok": True, "job_id": job.id}
@router.post("/{screenshot_id}/reocr")
def reocr(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""Re-run OCR only, without calling AI analysis again."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
active = session.scalar(
select(Job.id).where(
Job.screenshot_id == shot.id,
Job.kind == JobKind.OCR.value,
Job.status.in_((JobStatus.PENDING.value, JobStatus.RUNNING.value)),
)
)
if active is not None:
return {"ok": True, "job_id": active, "message": "An OCR job is already queued"}
shot.ocr_status = ProcessStatus.PENDING.value
job = Job(
screenshot_id=shot.id,
kind=JobKind.OCR.value,
status=JobStatus.PENDING.value,
)
session.add(job)
session.commit()
worker.notify_threadsafe()
return {"ok": True, "job_id": job.id}
@router.delete("/{screenshot_id}")
def delete_screenshot(
screenshot_id: int,
session: Session = Depends(db_session),
) -> dict:
"""Delete the record (the original file is not deleted)."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
session.delete(shot)
session.commit()
return {"ok": True}
@router.get("/{screenshot_id}/file")
def get_file(
screenshot_id: int,
session: Session = Depends(db_session),
) -> FileResponse:
"""Stream the original image file."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
p = path_from_storage(shot.path)
if not is_accessible_file(p):
raise HTTPException(404, "file missing")
return FileResponse(str(p))
@router.get("/{screenshot_id}/thumb")
def get_thumb(
screenshot_id: int,
session: Session = Depends(db_session),
) -> FileResponse:
"""Stream the thumbnail."""
shot = session.get(Screenshot, screenshot_id)
if shot is None:
raise HTTPException(404, "Screenshot not found")
if shot.thumb_path:
p = Path(shot.thumb_path)
if p.exists():
return FileResponse(str(p), media_type="image/webp")
# Fallback: return the original image
p = path_from_storage(shot.path)
if is_accessible_file(p):
return FileResponse(str(p))
raise HTTPException(404, "thumb missing")
+220
@@ -0,0 +1,220 @@
"""Settings endpoints: Provider configuration, categories, tags."""
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.api.deps import db_session
from app.models.category import Category
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.screenshot import Screenshot
from app.models.tag import Tag
from app.providers import RECOGNITION_MODES
from app.schemas.common import (
CategoryIn,
ProviderConfig,
ProviderConfigOut,
ProviderTestResult,
RecognitionModeIn,
)
from app.services.provider_test import merge_provider_api_key, test_provider_config
from app.services.settings_store import all_settings, get_setting, set_setting
router = APIRouter(prefix="/api/settings", tags=["settings"])
@router.get("")
def get_all(session: Session = Depends(db_session)) -> dict:
"""Return all non-sensitive settings, with the api_key field masked."""
raw = all_settings(session)
for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
cfg = raw.get(key)
if isinstance(cfg, dict) and cfg.get("api_key"):
cfg["api_key_mask"] = _mask(cfg["api_key"])
cfg["api_key"] = ""
return raw
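def _mask(value: str) -> str:
    """Mask a secret, keeping only the first and last three characters."""
    if not value:
        return ""
    if len(value) <= 6:
        # Too short to reveal anything safely: mask every character.
        return "*" * len(value)
    return value[:3] + "*" * (len(value) - 6) + value[-3:]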
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
def get_provider(
key: str,
session: Session = Depends(db_session),
) -> ProviderConfigOut | None:
"""Read a Provider config: the plaintext api_key is never returned, only a mask for UI hints."""
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
raw = get_setting(session, key, None)
if not raw:
return None
mask = _mask(raw.get("api_key", "") or "")
return ProviderConfigOut(
type=raw.get("type", ""),
base_url=raw.get("base_url"),
api_key="",
api_key_mask=mask or None,
model=raw.get("model"),
extra=raw.get("extra", {}) or {},
)
@router.put("/providers/{key}")
def put_provider(
key: str,
cfg: ProviderConfig,
session: Session = Depends(db_session),
) -> dict:
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
# If the client sends no new api_key (empty string), keep the stored value
existing = get_setting(session, key, None)
payload = cfg.model_dump()
if (not payload.get("api_key")) and isinstance(existing, dict):
payload["api_key"] = existing.get("api_key", "")
set_setting(session, key, payload)
session.commit()
return {"ok": True}
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
key: str,
cfg: ProviderConfig,
session: Session = Depends(db_session),
) -> ProviderTestResult:
"""Test OCR / vision-AI Provider connectivity (uses the current form values; api_key may be left empty to reuse the saved one)."""
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
existing = get_setting(session, key, None)
merged = merge_provider_api_key(cfg, existing if isinstance(existing, dict) else None)
result = await test_provider_config(key, merged)
return ProviderTestResult(**result)
@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
"""Read the text-recognition strategy: ocr / vision / hybrid."""
mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
if mode not in RECOGNITION_MODES:
mode = DEFAULT_RECOGNITION_MODE
return {"mode": mode, "options": list(RECOGNITION_MODES)}
@router.put("/recognition-mode")
def put_recognition_mode(
payload: RecognitionModeIn,
session: Session = Depends(db_session),
) -> dict:
"""Save the text-recognition strategy."""
if payload.mode not in RECOGNITION_MODES:
raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
set_setting(session, KEY_RECOGNITION_MODE, payload.mode)
session.commit()
return {"ok": True, "mode": payload.mode}
@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
rows = session.scalars(select(Category).order_by(Category.id)).all()
return [
{"id": c.id, "name": c.name, "color": c.color, "prompt_hint": c.prompt_hint}
for c in rows
]
@router.post("/categories")
def create_category(
payload: CategoryIn,
session: Session = Depends(db_session),
) -> dict:
exists = session.scalar(select(Category).where(Category.name == payload.name))
if exists is not None:
raise HTTPException(400, "category exists")
cat = Category(name=payload.name, color=payload.color, prompt_hint=payload.prompt_hint)
session.add(cat)
session.commit()
session.refresh(cat)
return {"id": cat.id}
@router.patch("/categories/{cat_id}")
def update_category(
cat_id: int,
payload: CategoryIn,
session: Session = Depends(db_session),
) -> dict:
cat = session.get(Category, cat_id)
if cat is None:
raise HTTPException(404, "category not found")
cat.name = payload.name
cat.color = payload.color
cat.prompt_hint = payload.prompt_hint
session.commit()
return {"ok": True}
@router.delete("/categories/{cat_id}")
def delete_category(
cat_id: int,
session: Session = Depends(db_session),
) -> dict:
cat = session.get(Category, cat_id)
if cat is None:
raise HTTPException(404, "category not found")
session.delete(cat)
session.commit()
return {"ok": True}
@router.get("/tags")
def list_tags(
session: Session = Depends(db_session),
q: Optional[str] = Query(None, description="Tag-name keyword"),
page: int = Query(1, ge=1),
size: int = Query(200, ge=1, le=500),
sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
) -> dict:
"""List tags with usage counts; supports search and pagination."""
base = select(Tag.id)
if q:
base = base.where(Tag.name.ilike(f"%{q.strip()}%"))
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
stmt = (
select(Tag.id, Tag.name, Tag.color, func.count(Screenshot.id))
.join(Tag.screenshots, isouter=True)
.group_by(Tag.id)
)
if q:
stmt = stmt.where(Tag.name.ilike(f"%{q.strip()}%"))
if sort == "count_asc":
stmt = stmt.order_by(func.count(Screenshot.id).asc())
elif sort == "name_asc":
stmt = stmt.order_by(Tag.name.asc())
elif sort == "name_desc":
stmt = stmt.order_by(Tag.name.desc())
else:
stmt = stmt.order_by(func.count(Screenshot.id).desc())
rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
items = [
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)} for r in rows
]
return {"items": items, "total": int(total), "page": page, "size": size}
+106
@@ -0,0 +1,106 @@
"""To-do list endpoints."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from app.api.deps import db_session
from app.models.todo import Todo, TodoStatus
from app.schemas.common import TodoUpdate
from app.schemas.screenshot import TodoBrief, TodoListResp
router = APIRouter(prefix="/api/todos", tags=["todos"])
def _todo_filters(
    status: Optional[str],
    kind: Optional[str],
    q: Optional[str],
) -> list:
    """Build the filter conditions for a to-do query."""
    filters = []
    if status:
        filters.append(Todo.status == status)
    if kind:
        filters.append(Todo.kind == kind)
    if q:
        # Case-insensitive substring match on title or note.
        like = f"%{q.strip()}%"
        filters.append(or_(Todo.title.ilike(like), Todo.note.ilike(like)))
    return filters
@router.get("", response_model=TodoListResp)
def list_todos(
session: Session = Depends(db_session),
status: Optional[str] = Query(None),
kind: Optional[str] = Query(None),
q: Optional[str] = Query(None, description="Title/note keyword"),
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=200),
) -> TodoListResp:
"""Paginated query by status, kind, and keyword."""
filters = _todo_filters(status, kind, q)
base = select(Todo)
if filters:
base = base.where(and_(*filters))
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
rows = session.scalars(
base.order_by(Todo.created_at.desc()).offset((page - 1) * size).limit(size)
).all()
return TodoListResp(
items=[TodoBrief.model_validate(r) for r in rows],
total=int(total),
page=page,
size=size,
)
@router.get("/summary")
def summary(session: Session = Depends(db_session)) -> dict:
"""To-do counts per status."""
return {
st.value: session.scalar(select(func.count(Todo.id)).where(Todo.status == st.value)) or 0
for st in TodoStatus
}
@router.patch("/{todo_id}", response_model=TodoBrief)
def update_todo(
todo_id: int,
payload: TodoUpdate,
session: Session = Depends(db_session),
) -> TodoBrief:
"""Update status, title, or note."""
todo = session.get(Todo, todo_id)
if todo is None:
raise HTTPException(404, "Todo not found")
if payload.status is not None:
todo.status = payload.status
if payload.status == TodoStatus.DONE.value:
todo.completed_at = datetime.utcnow()
if payload.title is not None:
todo.title = payload.title
if payload.note is not None:
todo.note = payload.note
session.commit()
session.refresh(todo)
return TodoBrief.model_validate(todo)
@router.delete("/{todo_id}")
def delete_todo(
todo_id: int,
session: Session = Depends(db_session),
) -> dict:
todo = session.get(Todo, todo_id)
if todo is None:
raise HTTPException(404, "Todo not found")
session.delete(todo)
session.commit()
return {"ok": True}
+256
@@ -0,0 +1,256 @@
"""Watch folders: create/update/delete, manual import, and the analysis queue."""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session, selectinload
from app.api.deps import db_session
from app.core.db import session_scope
from app.core.path_utils import (
count_files_sample,
is_accessible_dir,
normalize_user_path,
)
from app.models.job import Job, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.watch_folder import WatchFolder
from app.schemas.common import WatchFolderIn, WatchFolderOut
from app.schemas.job import JobListResp, JobOut, JobRetryIn
from app.services.analyze import enqueue_ocr_jobs
from app.services.ingest import ingest_directory
from app.services.watcher import watcher_service
from app.services.worker import worker
router = APIRouter(prefix="/api/watch", tags=["watch"])
def _validate_folder_path(raw: str) -> str:
    """Validate and normalize a watch-folder path (UNC network paths included)."""
    normalized = normalize_user_path(raw)
    if not normalized:
        raise HTTPException(400, "Path must not be empty")
    if not is_accessible_dir(normalized):
        raise HTTPException(
            400,
            f"Directory does not exist or is not accessible: {normalized}. "
            "Make sure the NAS is mounted and readable; UNC paths look like \\\\server\\share\\folder",
        )
    return normalized
@router.get("/folders", response_model=list[WatchFolderOut])
def list_folders(session: Session = Depends(db_session)) -> list[WatchFolderOut]:
rows = session.scalars(select(WatchFolder).order_by(WatchFolder.id)).all()
return [WatchFolderOut.model_validate(r) for r in rows]
@router.post("/folders", response_model=WatchFolderOut)
def add_folder(
payload: WatchFolderIn,
background: BackgroundTasks,
session: Session = Depends(db_session),
) -> WatchFolderOut:
"""Add a watch folder and trigger an initial scan automatically."""
normalized = _validate_folder_path(payload.path)
exists = session.scalar(select(WatchFolder).where(WatchFolder.path == normalized))
if exists is not None:
raise HTTPException(400, "Folder already exists")
folder = WatchFolder(
path=normalized,
enabled=payload.enabled,
recursive=payload.recursive,
is_sensitive=payload.is_sensitive,
)
session.add(folder)
session.commit()
session.refresh(folder)
watcher_service.reload()
background.add_task(_scan_folder, normalized, payload.recursive)
return WatchFolderOut.model_validate(folder)
@router.patch("/folders/{folder_id}", response_model=WatchFolderOut)
def update_folder(
folder_id: int,
payload: WatchFolderIn,
session: Session = Depends(db_session),
) -> WatchFolderOut:
folder = session.get(WatchFolder, folder_id)
if folder is None:
raise HTTPException(404, "folder not found")
normalized = _validate_folder_path(payload.path)
folder.path = normalized
folder.enabled = payload.enabled
folder.recursive = payload.recursive
folder.is_sensitive = payload.is_sensitive
session.commit()
session.refresh(folder)
watcher_service.reload()
return WatchFolderOut.model_validate(folder)
@router.delete("/folders/{folder_id}")
def delete_folder(
folder_id: int,
session: Session = Depends(db_session),
) -> dict:
folder = session.get(WatchFolder, folder_id)
if folder is None:
raise HTTPException(404, "folder not found")
session.delete(folder)
session.commit()
watcher_service.reload()
return {"ok": True}
@router.post("/import")
def import_now(
payload: WatchFolderIn,
background: BackgroundTasks,
) -> dict:
"""手动触发一次目录扫描(不一定要登记为监听)。"""
normalized = _validate_folder_path(payload.path)
background.add_task(_scan_folder, normalized, payload.recursive)
return {"ok": True, "message": "已在后台扫描"}
@router.post("/validate-path")
def validate_path(payload: WatchFolderIn) -> dict:
"""测试目录是否可访问(含 UNC 网络路径),返回抽样文件数。"""
normalized = _validate_folder_path(payload.path)
total, samples = count_files_sample(normalized, limit=3)
return {
"ok": True,
"path": normalized,
"sample_image_count": total,
"samples": samples,
"message": f"目录可访问,抽样发现约 {total}+ 张图片",
}
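def _scan_folder(path: str, recursive: bool) -> None:
    """Background task: ingest a directory, then wake the worker.

    Sync functions handed to BackgroundTasks run in a thread pool, so the
    event loop has to be woken through the thread-safe entry point; calling
    asyncio.Event.set() directly from another thread is a race.
    """
    with session_scope() as session:
        ingest_directory(session, path, recursive=recursive)
    worker.notify_threadsafe()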
@router.get("/queue")
async def queue_status() -> dict:
"""读取 worker 队列状态。"""
counts = await worker.status()
with session_scope() as session:
counts["ocr_retryable"] = (
session.scalar(
select(func.count(Screenshot.id)).where(
Screenshot.ocr_status == ProcessStatus.FAILED.value,
Screenshot.ai_status == ProcessStatus.DONE.value,
)
)
or 0
)
counts["ocr_pending"] = (
session.scalar(
select(func.count(Job.id)).where(
Job.kind == "ocr",
Job.status == JobStatus.PENDING.value,
)
)
or 0
)
return counts
def _job_to_out(job: Job, shot: Screenshot | None) -> JobOut:
"""ORM -> JobOut,附带截图摘要字段。"""
return JobOut(
id=job.id,
screenshot_id=job.screenshot_id,
kind=job.kind,
status=job.status,
retries=job.retries or 0,
last_error=job.last_error,
created_at=job.created_at,
started_at=job.started_at,
finished_at=job.finished_at,
thumb_url=(
f"/api/screenshots/{shot.id}/thumb" if shot and shot.thumb_path else None
),
path=shot.path if shot else None,
ai_title=(shot.meta.ai_title if shot and shot.meta else None),
ai_status=shot.ai_status if shot else None,
ocr_status=shot.ocr_status if shot else None,
)
@router.get("/jobs", response_model=JobListResp)
def list_jobs(
status: Optional[str] = Query(None, description="pending|running|done|failed"),
kind: Optional[str] = Query(None, description="full|ocr|vlm"),
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=200),
session: Session = Depends(db_session),
) -> JobListResp:
"""分页列出分析任务,默认按 id 倒序(最新的在前)。"""
if status and status not in {s.value for s in JobStatus}:
raise HTTPException(400, f"无效 status: {status}")
base = select(Job)
if status:
base = base.where(Job.status == status)
if kind:
base = base.where(Job.kind == kind)
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
jobs = session.scalars(
base.order_by(Job.id.desc()).offset((page - 1) * size).limit(size)
).all()
shot_ids = [j.screenshot_id for j in jobs]
shots: dict[int, Screenshot] = {}
if shot_ids:
rows = session.scalars(
select(Screenshot)
.where(Screenshot.id.in_(shot_ids))
.options(selectinload(Screenshot.meta))
).all()
shots = {s.id: s for s in rows}
items = [_job_to_out(j, shots.get(j.screenshot_id)) for j in jobs]
return JobListResp(items=items, total=total, page=page, size=size)
@router.post("/jobs/retry-failed")
def retry_failed_jobs(payload: JobRetryIn | None = None) -> dict:
"""将全部或指定 failed 任务重新排队。"""
job_ids = payload.job_ids if payload else None
count = worker.retry_failed(job_ids)
return {"ok": True, "count": count}
@router.post("/jobs/reset-stale")
def reset_stale_jobs(
minutes: int = Query(5, ge=1, le=1440),
reset_all: bool = Query(False, description="为 true 时复位全部 RUNNING"),
) -> dict:
"""复位僵尸 RUNNING 任务(worker 崩溃或未正常 finish 时)。"""
count = worker.reset_stale_running(minutes=minutes, reset_all=reset_all)
return {"ok": True, "count": count}
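@router.post("/jobs/enqueue-ocr-failed")
def enqueue_ocr_failed(limit: int = Query(500, ge=1, le=5000)) -> dict:
    """Bulk-create OCR re-run jobs for screenshots whose AI step succeeded but whose OCR failed."""
    count = enqueue_ocr_jobs(limit=limit)
    if count:
        # Sync endpoints run in the thread pool, so wake the worker through
        # the thread-safe entry point (same reasoning as in _scan_folder).
        worker.notify_threadsafe()
    return {"ok": True, "count": count}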
@@ -0,0 +1,66 @@
"""Global configuration: paths, database, concurrency parameters, etc."""
from __future__ import annotations
import os
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
# 默认数据目录:放在 backend/.data 下,便于零配置启动
_BACKEND_ROOT = Path(__file__).resolve().parents[2]
_DEFAULT_DATA_DIR = _BACKEND_ROOT / ".data"
class Settings(BaseSettings):
"""读取 .env 与环境变量的全局配置。"""
model_config = SettingsConfigDict(
env_file=str(_BACKEND_ROOT / ".env"),
env_file_encoding="utf-8",
extra="ignore",
)
# 应用基础
app_name: str = "snapAna"
debug: bool = False
host: str = "127.0.0.1"
port: int = 8765
# 数据目录
data_dir: Path = Field(default=_DEFAULT_DATA_DIR)
# 任务并发
analyze_concurrency: int = 4
max_retries: int = 3
# 缩略图
thumb_size: int = 320
vlm_max_side: int = 1280 # 上传 VLM 前压缩的长边像素
# CORS
cors_origins: list[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]
@property
def db_path(self) -> Path:
"""SQLite 数据库文件路径。"""
return self.data_dir / "snapana.db"
@property
def db_url(self) -> str:
"""SQLAlchemy 连接串。"""
return f"sqlite:///{self.db_path.as_posix()}"
@property
def thumb_dir(self) -> Path:
"""缩略图缓存目录。"""
return self.data_dir / "thumbs"
def ensure_dirs(self) -> None:
"""确保所有运行期目录存在。"""
self.data_dir.mkdir(parents=True, exist_ok=True)
self.thumb_dir.mkdir(parents=True, exist_ok=True)
settings = Settings()
settings.ensure_dirs()
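@@ -0,0 +1,153 @@
"""Database engine, sessions, and initialization.

Uses SQLAlchemy 2.0 + SQLite. The FTS5 virtual table is created with raw SQL,
together with triggers that keep the full-text index in sync whenever the
OCR/AI fields change.
"""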
from __future__ import annotations
from contextlib import contextmanager
from typing import Iterator
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.core.config import settings
class Base(DeclarativeBase):
"""全局声明性 Base。"""
engine = create_engine(
settings.db_url,
echo=False,
future=True,
connect_args={"check_same_thread": False},
)
@event.listens_for(engine, "connect")
def _sqlite_pragmas(dbapi_connection, _connection_record):
"""启用外键、WAL、忙等待等 SQLite 优化项。"""
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA busy_timeout=5000")
cursor.close()
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
def get_session() -> Iterator[Session]:
"""FastAPI 依赖注入:每个请求一个会话。"""
with SessionLocal() as session:
yield session
@contextmanager
def session_scope() -> Iterator[Session]:
"""常规上下文管理:自动 commit/rollback。"""
session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
# FTS5 虚拟表与触发器 SQL(独立维护,便于以后调整字段)
_FTS_SCHEMA_SQL = [
"""
CREATE VIRTUAL TABLE IF NOT EXISTS screenshots_fts
USING fts5(
ocr_text,
ai_title,
ai_summary,
ai_suggestion,
content='screenshot_meta',
content_rowid='screenshot_id',
tokenize='unicode61'
);
""",
"""
CREATE TRIGGER IF NOT EXISTS screenshot_meta_ai
AFTER INSERT ON screenshot_meta BEGIN
INSERT INTO screenshots_fts(rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
VALUES (new.screenshot_id,
coalesce(new.ocr_text, ''),
coalesce(new.ai_title, ''),
coalesce(new.ai_summary, ''),
coalesce(new.ai_suggestion, ''));
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS screenshot_meta_ad
AFTER DELETE ON screenshot_meta BEGIN
INSERT INTO screenshots_fts(screenshots_fts, rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
VALUES('delete', old.screenshot_id,
coalesce(old.ocr_text, ''),
coalesce(old.ai_title, ''),
coalesce(old.ai_summary, ''),
coalesce(old.ai_suggestion, ''));
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS screenshot_meta_au
AFTER UPDATE ON screenshot_meta BEGIN
INSERT INTO screenshots_fts(screenshots_fts, rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
VALUES('delete', old.screenshot_id,
coalesce(old.ocr_text, ''),
coalesce(old.ai_title, ''),
coalesce(old.ai_summary, ''),
coalesce(old.ai_suggestion, ''));
INSERT INTO screenshots_fts(rowid, ocr_text, ai_title, ai_summary, ai_suggestion)
VALUES (new.screenshot_id,
coalesce(new.ocr_text, ''),
coalesce(new.ai_title, ''),
coalesce(new.ai_summary, ''),
coalesce(new.ai_suggestion, ''));
END;
""",
]
def init_db() -> None:
"""启动时建表并装配 FTS5、灌入默认分类。"""
from app.models import register_all # noqa: F401
register_all()
Base.metadata.create_all(engine)
with engine.begin() as conn:
for stmt in _FTS_SCHEMA_SQL:
conn.execute(text(stmt))
_migrate_legacy_schema(conn)
# 启动期 seed 默认分类(即使首次启动也能在「设置」/筛选页看到分类)
from app.services.analyze import ensure_default_categories
ensure_default_categories()
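def _migrate_legacy_schema(conn) -> None:
    """Lightweight migration for legacy databases whose screenshots.category_id has no foreign key.

    SQLite cannot add a foreign key with ALTER TABLE, so on such tables
    ON DELETE SET NULL never fires and deleting a category leaves dangling
    references. When the legacy table is detected, a one-off UPDATE clears the
    invalid references. Rebuilding from the category management page is still
    advisable afterwards.
    """
    pragma_rows = conn.execute(
        text("PRAGMA foreign_key_list(screenshots)")
    ).fetchall()
    # Column 2 of each foreign_key_list row is the referenced table name.
    has_cat_fk = any(row[2] == "categories" for row in pragma_rows)
    if not has_cat_fk:
        # Clear dangling category_id values so list statistics stay correct.
        conn.execute(
            text(
                "UPDATE screenshots SET category_id = NULL "
                "WHERE category_id IS NOT NULL "
                "AND category_id NOT IN (SELECT id FROM categories)"
            )
        )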
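The trigger scheme above follows the FTS5 external-content pattern: the index stores no copy of the text, so every change to the content table has to be mirrored by hand, with a special `'delete'` row that passes the old values. Below is a minimal sketch of that pattern using only the standard `sqlite3` module. It assumes the interpreter's SQLite build includes FTS5; the `meta` / `meta_fts` tables and the `search` helper are illustrative names, not part of snapAna.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE meta (id INTEGER PRIMARY KEY, body TEXT);

    -- External-content index: rows live in `meta`, only the index lives here.
    CREATE VIRTUAL TABLE meta_fts
    USING fts5(body, content='meta', content_rowid='id');

    CREATE TRIGGER meta_ai AFTER INSERT ON meta BEGIN
      INSERT INTO meta_fts(rowid, body) VALUES (new.id, coalesce(new.body, ''));
    END;

    -- An update is "remove the old tokens, then add the new ones".
    CREATE TRIGGER meta_au AFTER UPDATE ON meta BEGIN
      INSERT INTO meta_fts(meta_fts, rowid, body)
      VALUES ('delete', old.id, coalesce(old.body, ''));
      INSERT INTO meta_fts(rowid, body) VALUES (new.id, coalesce(new.body, ''));
    END;
    """
)


def search(term: str) -> list[int]:
    """Return the rowids whose indexed text matches `term`."""
    rows = conn.execute(
        "SELECT rowid FROM meta_fts WHERE meta_fts MATCH ? ORDER BY rowid",
        (term,),
    ).fetchall()
    return [r[0] for r in rows]


conn.execute("INSERT INTO meta (id, body) VALUES (1, 'invoice from shop')")
conn.execute("UPDATE meta SET body = 'receipt from shop' WHERE id = 1")
```

After the update, a search for the old word should find nothing while the new word matches row 1. The values passed to the `'delete'` row must equal what was indexed, which is why both triggers apply the same `coalesce`.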
@@ -0,0 +1,25 @@
"""Unified logging configuration."""
from __future__ import annotations
import logging
import sys
def setup_logging(debug: bool = False) -> None:
"""初始化根 logger 的格式与级别。"""
level = logging.DEBUG if debug else logging.INFO
fmt = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(fmt))
root = logging.getLogger()
root.handlers.clear()
root.addHandler(handler)
root.setLevel(level)
# 降低第三方库噪音
for noisy in ("watchdog", "httpx", "PIL"):
logging.getLogger(noisy).setLevel(logging.WARNING)
def get_logger(name: str) -> logging.Logger:
"""统一入口获取 logger。"""
return logging.getLogger(name)
@@ -0,0 +1,102 @@
"""Cross-platform path utilities, focused on Windows UNC network paths (\\\\NAS\\share\\...)."""
from __future__ import annotations
import os
import sys
from pathlib import Path, PureWindowsPath
def normalize_user_path(raw: str) -> str:
    """Normalize a user-supplied path, keeping the UNC backslash form.

    Examples (on Windows):
    - \\\\JIULUGNAS\\personal_folder\\Photos -> kept as-is
    - //JIULUGNAS/personal_folder/Photos -> converted to UNC
    - D:/Pictures/Screenshots -> D:\\Pictures\\Screenshots
    """
    raw = (raw or "").strip().strip('"').strip("'")
    if not raw:
        return raw
    if sys.platform == "win32":
        # //server/share -> \\server\share
        if raw.startswith("//") and not raw.startswith("///"):
            raw = "\\\\" + raw.lstrip("/").replace("/", "\\")
        elif raw.startswith("\\\\"):
            pass
        else:
            raw = raw.replace("/", "\\")
        return str(PureWindowsPath(raw))
    return str(Path(raw).expanduser())
def path_from_storage(stored: str) -> Path:
    """Turn a path read from the database into a Path (repairs legacy //NAS/... values produced by as_posix)."""
    if not stored:
        return Path(stored)
    if sys.platform == "win32":
        # Legacy data: //JIULUGNAS/foo/bar -> \\JIULUGNAS\foo\bar
        if stored.startswith("//") and not stored.startswith("///"):
            stored = "\\\\" + stored.lstrip("/").replace("/", "\\")
    return Path(stored)


def path_to_storage(path: Path | str) -> str:
    """Path string used for database writes and comparisons; backslashes are kept on Windows."""
    if isinstance(path, Path):
        if sys.platform == "win32":
            return str(path)
        return path.as_posix()
    return normalize_user_path(str(path)) if sys.platform == "win32" else str(path)


def is_accessible_dir(path: str | Path) -> bool:
    """Whether the directory is accessible (works for UNC and local paths)."""
    try:
        return os.path.isdir(str(path))
    except OSError:
        return False


def is_accessible_file(path: str | Path) -> bool:
    """Whether the file is accessible."""
    try:
        return os.path.isfile(str(path))
    except OSError:
        return False


def path_is_under(parent: str | Path, child: str | Path) -> bool:
    """Whether child lies under the parent directory (used for sensitive-folder detection)."""
    try:
        parent_norm = os.path.normcase(os.path.normpath(str(parent)))
        child_norm = os.path.normcase(os.path.normpath(str(child)))
        # Compare against "parent/" so that /a/bc is not treated as being under /a/b.
        if not parent_norm.endswith(os.sep):
            parent_norm += os.sep
        return child_norm.startswith(parent_norm) or child_norm == parent_norm.rstrip(os.sep)
    except OSError:
        return False
def count_files_sample(root: str | Path, limit: int = 5) -> tuple[int, list[str]]:
    """Quickly sample-count the images under a directory.

    Network paths can be slow, so the walk stops once about 1000 images have
    been counted (checked after each directory). ``limit`` caps how many
    sample paths are returned; it does not limit traversal depth.
    """
    from app.services.thumbnail import is_supported

    root_p = path_from_storage(str(root)) if isinstance(root, str) else root
    total = 0
    samples: list[str] = []
    try:
        for dirpath, _, filenames in os.walk(str(root_p)):
            for name in filenames:
                p = Path(dirpath) / name
                if not is_supported(p):
                    continue
                total += 1
                if len(samples) < limit:
                    samples.append(path_to_storage(p))
            if total >= 1000:
                break
    except OSError:
        pass
    return total, samples
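The UNC handling and the containment check above only take their Windows branches when run on Windows. The sketch below reproduces the same two ideas with `PureWindowsPath` and `ntpath`, which apply Windows path rules on any platform, so the behavior can be tried anywhere. `to_unc` and `is_under` are hypothetical names for this illustration, not functions in the module.

```python
import ntpath
from pathlib import PureWindowsPath


def to_unc(raw: str) -> str:
    """Normalize a user-typed path using Windows rules, keeping UNC prefixes."""
    raw = raw.strip().strip('"').strip("'")
    if raw.startswith("//") and not raw.startswith("///"):
        # //server/share/... -> \\server\share\...
        raw = "\\\\" + raw.lstrip("/").replace("/", "\\")
    elif not raw.startswith("\\\\"):
        raw = raw.replace("/", "\\")
    return str(PureWindowsPath(raw))


def is_under(parent: str, child: str) -> bool:
    """Case-insensitive check that `child` is `parent` itself or lies below it."""
    p = ntpath.normcase(ntpath.normpath(parent)).rstrip("\\")
    c = ntpath.normcase(ntpath.normpath(child))
    # Appending the separator keeps D:\ShotsBackup from matching D:\Shots.
    return c == p or c.startswith(p + "\\")
```

Comparing against the parent with a trailing separator is the important detail: a plain prefix test would wrongly treat a sibling folder with a longer name as a child of the sensitive directory.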
@@ -0,0 +1,60 @@
"""FastAPI application entry point."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import screenshots, settings_api, todos, watch
from app.core.config import settings
from app.core.db import init_db
from app.core.logger import get_logger, setup_logging
from app.services.watcher import watcher_service
from app.services.worker import worker
setup_logging(settings.debug)
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI): # noqa: ARG001
"""启动时初始化 DB、启动监听器与分析 worker。"""
init_db()
loop = asyncio.get_running_loop()
async def notify() -> None:
worker.notify()
watcher_service.start(loop, notify)
await worker.start()
logger.info("snapAna 启动完成 @ http://%s:%d", settings.host, settings.port)
try:
yield
finally:
watcher_service.stop()
await worker.stop()
app = FastAPI(title="snapAna", version="0.1.0", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(screenshots.router)
app.include_router(todos.router)
app.include_router(settings_api.router)
app.include_router(watch.router)
@app.get("/api/health")
def health() -> dict:
"""健康检查。"""
return {"status": "ok", "version": "0.1.0"}
@@ -0,0 +1,13 @@
"""Central registration entry point for SQLAlchemy models."""
def register_all() -> None:
"""显式导入以触发模型注册到 Base.metadata。"""
from . import screenshot # noqa: F401
from . import meta # noqa: F401
from . import tag # noqa: F401
from . import category # noqa: F401
from . import todo # noqa: F401
from . import job # noqa: F401
from . import watch_folder # noqa: F401
from . import setting # noqa: F401
@@ -0,0 +1,32 @@
"""Screenshot categories. Common categories are preset; the AI writes back whichever one it matches."""
from __future__ import annotations
from sqlalchemy import Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from app.core.db import Base
class Category(Base):
"""截图分类。"""
__tablename__ = "categories"
__table_args__ = (UniqueConstraint("name", name="uq_categories_name"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(64), nullable=False)
color: Mapped[str | None] = mapped_column(String(16), nullable=True)
prompt_hint: Mapped[str | None] = mapped_column(Text, nullable=True)
# 首次启动时灌入的默认分类
DEFAULT_CATEGORIES: list[dict[str, str | None]] = [
{"name": "知识技术", "color": "#3b82f6", "prompt_hint": "技术文章、代码、教程、文档截图"},
{"name": "梗图幽默", "color": "#f59e0b", "prompt_hint": "搞笑图、表情包、梗图"},
{"name": "小说文字", "color": "#8b5cf6", "prompt_hint": "长段文字、小说阅读、电子书"},
{"name": "聊天记录", "color": "#10b981", "prompt_hint": "微信/QQ/Slack 等聊天截图"},
{"name": "UI 设计", "color": "#ec4899", "prompt_hint": "界面设计、网页/App 灵感参考"},
{"name": "生活记录", "color": "#22c55e", "prompt_hint": "日常照片、生活记录、票据"},
{"name": "购物商品", "color": "#ef4444", "prompt_hint": "商品截图、价格、订单"},
{"name": "其他", "color": "#6b7280", "prompt_hint": "无法明确归类"},
]
@@ -0,0 +1,54 @@
"""Analysis job queue: persisted in SQLite so it can recover after a power loss."""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from app.core.db import Base
class JobKind(str, Enum):
"""任务种类。"""
OCR = "ocr"
VLM = "vlm"
FULL = "full" # OCR + VLM 一条龙
class JobStatus(str, Enum):
"""任务运行状态。"""
PENDING = "pending"
RUNNING = "running"
DONE = "done"
FAILED = "failed"
class Job(Base):
"""单条分析任务记录。"""
__tablename__ = "jobs"
__table_args__ = (
Index("ix_jobs_status", "status"),
Index("ix_jobs_kind_status", "kind", "status"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
screenshot_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("screenshots.id", ondelete="CASCADE"),
nullable=False,
)
kind: Mapped[str] = mapped_column(String(16), default=JobKind.FULL.value)
status: Mapped[str] = mapped_column(String(16), default=JobStatus.PENDING.value)
retries: Mapped[int] = mapped_column(Integer, default=0)
last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False
)
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@@ -0,0 +1,27 @@
"""OCR / AI metadata for a screenshot. 1:1 with screenshot."""
from __future__ import annotations
from sqlalchemy import ForeignKey, Integer, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.db import Base
class ScreenshotMeta(Base):
"""OCR 文本 + AI 结构化结果。"""
__tablename__ = "screenshot_meta"
screenshot_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("screenshots.id", ondelete="CASCADE"),
primary_key=True,
)
ocr_text: Mapped[str | None] = mapped_column(Text, nullable=True)
ai_title: Mapped[str | None] = mapped_column(Text, nullable=True)
ai_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
ai_suggestion: Mapped[str | None] = mapped_column(Text, nullable=True)
ai_raw_json: Mapped[str | None] = mapped_column(Text, nullable=True) # 完整原始 JSON
screenshot = relationship("Screenshot", back_populates="meta")
@@ -0,0 +1,86 @@
"""Main screenshot table and processing status."""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from sqlalchemy import (
BigInteger,
DateTime,
ForeignKey,
Index,
Integer,
String,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.db import Base
class ProcessStatus(str, Enum):
"""处理流水线的状态枚举。"""
PENDING = "pending"
RUNNING = "running"
DONE = "done"
FAILED = "failed"
SKIPPED = "skipped"
class Screenshot(Base):
"""截图文件主记录。"""
__tablename__ = "screenshots"
__table_args__ = (
UniqueConstraint("file_hash", name="uq_screenshots_file_hash"),
Index("ix_screenshots_captured_at", "captured_at"),
Index("ix_screenshots_ai_status", "ai_status"),
Index("ix_screenshots_category_id", "category_id"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
path: Mapped[str] = mapped_column(String(1024), nullable=False)
file_hash: Mapped[str] = mapped_column(String(64), nullable=False)
width: Mapped[int] = mapped_column(Integer, default=0)
height: Mapped[int] = mapped_column(Integer, default=0)
size: Mapped[int] = mapped_column(BigInteger, default=0)
captured_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
imported_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False
)
thumb_path: Mapped[str | None] = mapped_column(String(1024), nullable=True)
ocr_status: Mapped[str] = mapped_column(String(16), default=ProcessStatus.PENDING.value)
ai_status: Mapped[str] = mapped_column(String(16), default=ProcessStatus.PENDING.value)
# AI 写回的分类:外键 + SET NULL,删除分类时自动把引用置空
category_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey("categories.id", ondelete="SET NULL"),
nullable=True,
)
is_favorite: Mapped[int] = mapped_column(Integer, default=0) # 0/1,便于 SQLite 索引
is_hidden: Mapped[int] = mapped_column(Integer, default=0)
meta = relationship(
"ScreenshotMeta",
back_populates="screenshot",
uselist=False,
cascade="all, delete-orphan",
)
tags = relationship(
"Tag",
secondary="screenshot_tags",
back_populates="screenshots",
lazy="selectin",
)
todos = relationship(
"Todo",
back_populates="screenshot",
cascade="all, delete-orphan",
)
@@ -0,0 +1,26 @@
"""Key-value settings: provider configuration and the like, stored as JSON."""
from __future__ import annotations
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.core.db import Base
class Setting(Base):
"""通用键值设置。"""
__tablename__ = "settings"
key: Mapped[str] = mapped_column(String(64), primary_key=True)
value_json: Mapped[str] = mapped_column(Text, nullable=False, default="null")
# 设置键名常量
KEY_OCR_PROVIDER = "ocr_provider"
KEY_VLM_PROVIDER = "vlm_provider"
KEY_RECOGNITION_MODE = "recognition_mode" # ocr | vision | hybrid
KEY_CATEGORY_HINT = "category_hint"
# 默认识别模式:混合(OCR 文本 + 视觉 AI 联合分析)
DEFAULT_RECOGNITION_MODE = "hybrid"
@@ -0,0 +1,42 @@
"""Tags and their many-to-many association."""
from __future__ import annotations
from sqlalchemy import Column, ForeignKey, Integer, String, Table, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.db import Base
screenshot_tags = Table(
"screenshot_tags",
Base.metadata,
Column(
"screenshot_id",
Integer,
ForeignKey("screenshots.id", ondelete="CASCADE"),
primary_key=True,
),
Column(
"tag_id",
Integer,
ForeignKey("tags.id", ondelete="CASCADE"),
primary_key=True,
),
)
class Tag(Base):
"""用户/AI 共享的自由标签。"""
__tablename__ = "tags"
__table_args__ = (UniqueConstraint("name", name="uq_tags_name"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(64), nullable=False)
color: Mapped[str | None] = mapped_column(String(16), nullable=True)
screenshots = relationship(
"Screenshot",
secondary=screenshot_tags,
back_populates="tags",
)
@@ -0,0 +1,47 @@
"""Todos extracted by the AI (to watch / to read / to do)."""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.core.db import Base
class TodoStatus(str, Enum):
"""待办状态。"""
PENDING = "pending"
DOING = "doing"
DONE = "done"
DROPPED = "dropped"
class Todo(Base):
"""AI 从截图中抽取的待办项。"""
__tablename__ = "todos"
__table_args__ = (
Index("ix_todos_status", "status"),
Index("ix_todos_screenshot_id", "screenshot_id"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
screenshot_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("screenshots.id", ondelete="CASCADE"),
nullable=False,
)
title: Mapped[str] = mapped_column(String(512), nullable=False)
note: Mapped[str | None] = mapped_column(Text, nullable=True)
kind: Mapped[str | None] = mapped_column(String(32), nullable=True) # 待看/待读/待办等
status: Mapped[str] = mapped_column(String(16), default=TodoStatus.PENDING.value)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
screenshot = relationship("Screenshot", back_populates="todos")
@@ -0,0 +1,25 @@
"""List of watched screenshot directories."""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Integer, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from app.core.db import Base
class WatchFolder(Base):
"""监听的截图目录。"""
__tablename__ = "watch_folders"
__table_args__ = (UniqueConstraint("path", name="uq_watch_folders_path"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
path: Mapped[str] = mapped_column(String(1024), nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
recursive: Mapped[bool] = mapped_column(Boolean, default=True)
is_sensitive: Mapped[bool] = mapped_column(Boolean, default=False) # 是否禁止上传云端
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False
)
@@ -0,0 +1,81 @@
"""Provider factory: instantiates providers from the `type` field in settings."""
from __future__ import annotations
from typing import Optional
from app.schemas.common import ProviderConfig
from .base import OCRProvider, VLMProvider
from .ocr_http import HttpOCR
from .ocr_paddle import PaddleOCRProvider
from .ocr_tesseract import TesseractOCR
from .ocr_vision import VisionOCR
from .vlm_openai import OpenAICompatVLM
# OCR Provider 类型常量
OCR_TYPES = ("tesseract", "paddleocr", "http", "vision", "none")
VLM_TYPES = ("openai_compat", "none")
RECOGNITION_MODES = ("ocr", "vision", "hybrid")
def build_ocr_provider(
cfg: ProviderConfig | None,
*,
allow_upload: bool = True,
) -> Optional[OCRProvider]:
"""根据配置构造传统 OCR / 视觉 OCR Provider。"""
if cfg is None or cfg.type in ("", "none", "disabled"):
return None
if cfg.type == "tesseract":
return TesseractOCR(
lang=cfg.extra.get("lang", "chi_sim+eng"),
cmd=cfg.extra.get("cmd"),
)
if cfg.type == "paddleocr":
return PaddleOCRProvider(lang=cfg.extra.get("lang", "ch"))
if cfg.type == "http":
if not cfg.base_url:
raise ValueError("HTTP OCR 需要配置 base_url")
return HttpOCR(
base_url=cfg.base_url,
api_key=cfg.api_key or "",
text_path=str(cfg.extra.get("text_path", "text")),
headers=cfg.extra.get("headers") if isinstance(cfg.extra.get("headers"), dict) else None,
timeout=float(cfg.extra.get("timeout", 30)),
)
if cfg.type == "vision":
return build_vision_ocr(cfg, allow_upload=allow_upload)
raise ValueError(f"未知 OCR Provider 类型: {cfg.type}")
def build_vision_ocr(
cfg: ProviderConfig | None,
*,
allow_upload: bool = True,
) -> Optional[VisionOCR]:
"""从 ProviderConfig 构造视觉 OCR(可与 VLM 共用同一套接口配置)。"""
if cfg is None or cfg.type in ("", "none", "disabled"):
return None
base_url = cfg.base_url or "http://localhost:11434/v1"
model = cfg.model or "qwen2.5vl:7b"
return VisionOCR(
base_url=base_url,
api_key=cfg.api_key or "",
model=model,
timeout=float(cfg.extra.get("timeout", 60)),
allow_upload=allow_upload,
)
def build_vlm_provider(cfg: ProviderConfig | None) -> Optional[VLMProvider]:
"""根据配置构造 VLM Provider(AI 分类/摘要/标签)。"""
if cfg is None or cfg.type in ("", "none", "disabled"):
return None
if cfg.type in ("openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision"):
return OpenAICompatVLM(
base_url=cfg.base_url or "http://localhost:11434/v1",
api_key=cfg.api_key or "",
model=cfg.model or "gpt-4o-mini",
timeout=float(cfg.extra.get("timeout", 60)),
)
raise ValueError(f"未知 VLM Provider 类型: {cfg.type}")
@@ -0,0 +1,46 @@
"""Abstract interfaces for OCR / VLM providers."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class VLMResult:
"""VLM 结构化分析结果。"""
title: str = ""
summary: str = ""
category: str | None = None
tags: list[str] = field(default_factory=list)
todos: list[dict[str, str]] = field(default_factory=list) # [{title, kind, note}]
suggestion: str = ""
raw: dict[str, Any] = field(default_factory=dict)
class OCRProvider(ABC):
"""OCR 接口:输入图片路径,返回文本。"""
name: str = "ocr"
@abstractmethod
async def recognize(self, image_path: Path) -> str:
...
class VLMProvider(ABC):
"""多模态接口:根据图片 + OCR 文本生成结构化分析。"""
name: str = "vlm"
@abstractmethod
async def analyze(
self,
image_path: Path,
ocr_text: str,
categories: list[str],
allow_upload: bool,
) -> VLMResult:
...
@@ -0,0 +1,63 @@
"""Generic HTTP OCR: POST the image to a custom REST endpoint and parse the text."""
from __future__ import annotations
import base64
import json
from pathlib import Path
from typing import Any
import httpx
from .base import OCRProvider
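class HttpOCR(OCRProvider):
    """POST the base64-encoded image as JSON to the configured URL and read the text from the JSON response.

    The payload carries the same base64 string under both "image_base64" and "image".

    extra options:
    - text_path: dotted path such as "data.text" or "result"; defaults to "text"
    - headers: dict of additional request headers
    """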
name = "http"
def __init__(
self,
base_url: str,
api_key: str = "",
text_path: str = "text",
headers: dict[str, str] | None = None,
timeout: float = 30.0,
) -> None:
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self.text_path = text_path
self.headers = headers or {}
self.timeout = timeout
async def recognize(self, image_path: Path) -> str:
with open(image_path, "rb") as f:
encoded = base64.b64encode(f.read()).decode("ascii")
headers = {"Content-Type": "application/json", **self.headers}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {"image_base64": encoded, "image": encoded}
async with httpx.AsyncClient(timeout=self.timeout) as client:
resp = await client.post(self.base_url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
return str(_dig(data, self.text_path) or "").strip()
def _dig(obj: Any, path: str) -> Any:
"""按点分路径从嵌套 dict 取值。"""
cur = obj
for part in path.split("."):
if not isinstance(cur, dict):
return None
cur = cur.get(part)
return cur
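The `text_path` option is easiest to understand by example. The sketch below restates the dotted-path lookup as a standalone function (named `dig` here only for illustration) and shows how it behaves when the path runs into a non-dict value.

```python
from typing import Any


def dig(obj: Any, path: str) -> Any:
    """Follow a dotted path such as "data.text" through nested dicts."""
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            # The path goes deeper than the data does: report "not found".
            return None
        cur = cur.get(part)
    return cur


nested = {"data": {"text": "hello"}}
flat = {"text": "hi"}
```

With these inputs, `dig(nested, "data.text")` yields the inner string, while a path that walks through a string or a missing key yields `None`, which the provider then turns into an empty OCR result.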
@@ -0,0 +1,43 @@
"""PaddleOCR local OCR (optional dependency)."""
from __future__ import annotations
import asyncio
from pathlib import Path
from .base import OCRProvider
class PaddleOCRProvider(OCRProvider):
"""通过 PaddleOCR 本地识文。需 pip install paddleocr paddlepaddle。"""
name = "paddleocr"
def __init__(self, lang: str = "ch") -> None:
self.lang = lang
self._engine = None
async def recognize(self, image_path: Path) -> str:
return await asyncio.to_thread(self._sync_recognize, image_path)
def _sync_recognize(self, image_path: Path) -> str:
try:
from paddleocr import PaddleOCR # type: ignore
except ImportError as exc:
raise RuntimeError(
"未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
) from exc
if self._engine is None:
self._engine = PaddleOCR(use_angle_cls=True, lang=self.lang, show_log=False)
result = self._engine.ocr(str(image_path), cls=True)
lines: list[str] = []
if result and result[0]:
for line in result[0]:
if line and len(line) >= 2:
text_part = line[1]
if isinstance(text_part, (list, tuple)) and text_part:
lines.append(str(text_part[0]))
elif isinstance(text_part, str):
lines.append(text_part)
return "\n".join(lines).strip()
@@ -0,0 +1,39 @@
"""Tesseract local OCR implementation."""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Optional
from .base import OCRProvider
class TesseractOCR(OCRProvider):
"""通过 pytesseract 调用本地 tesseract。
需提前安装 tesseract-ocr 及中文语言包。
"""
name = "tesseract"
def __init__(self, lang: str = "chi_sim+eng", cmd: Optional[str] = None) -> None:
self.lang = lang
self.cmd = cmd
async def recognize(self, image_path: Path) -> str:
"""异步包装:避免阻塞事件循环。"""
return await asyncio.to_thread(self._sync_recognize, image_path)
def _sync_recognize(self, image_path: Path) -> str:
try:
import pytesseract
from PIL import Image
except ImportError as exc: # pragma: no cover
raise RuntimeError("未安装 pytesseract / Pillow") from exc
if self.cmd:
pytesseract.pytesseract.tesseract_cmd = self.cmd
with Image.open(image_path) as img:
text = pytesseract.image_to_string(img, lang=self.lang)
return text.strip()
@@ -0,0 +1,52 @@
"""Vision-model OCR: extract text from screenshots through a multimodal API."""
from __future__ import annotations
from pathlib import Path
from .base import OCRProvider
from .openai_vision_client import chat_completions, safe_parse_json
_VISION_OCR_SYSTEM = """你是 OCR 助手。用户会给你一张截图,请尽可能完整地提取其中的文字。
只输出 JSON,格式:{"text": "提取到的全部文字,保留换行"}
如果没有可识别文字,text 填空字符串。"""
class VisionOCR(OCRProvider):
"""OpenAI 兼容视觉模型识文(GLM-4V / GPT-4o / Qwen-VL / Ollama 等)。"""
name = "vision"
def __init__(
self,
base_url: str,
api_key: str,
model: str,
timeout: float = 60.0,
allow_upload: bool = True,
) -> None:
self.base_url = base_url
self.api_key = api_key
self.model = model
self.timeout = timeout
self.allow_upload = allow_upload
    async def recognize(self, image_path: Path) -> str:
        """Extract text from the image with a vision model."""
        if not self.allow_upload:
            raise RuntimeError("敏感目录禁止上传图片,无法使用视觉 OCR")
        content = await chat_completions(
            base_url=self.base_url,
            api_key=self.api_key,
            model=self.model,
            system_prompt=_VISION_OCR_SYSTEM,
            user_text="请提取这张截图中的所有文字。",
            image_path=image_path,
            allow_upload=True,
            timeout=self.timeout,
            json_mode=True,
        )
        parsed = safe_parse_json(content)
        # Test for key presence, not truthiness: an empty "text" means "no text
        # found" and must not fall back to the raw JSON string.
        for key in ("text", "ocr_text"):
            if isinstance(parsed, dict) and key in parsed:
                return str(parsed[key] or "").strip()
        return content.strip()
@@ -0,0 +1,107 @@
"""Shared wrapper for OpenAI-compatible vision APIs: image encoding plus the chat/completions call."""
from __future__ import annotations
import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any
import httpx
from PIL import Image
from app.core.config import settings
from app.core.logger import get_logger
logger = get_logger(__name__)
def image_to_data_url(image_path: Path, max_side: int | None = None) -> str:
    """Downscale the image, re-encode it as JPEG and return it as a data URL."""
    max_side = max_side or settings.vlm_max_side
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        w, h = img.size
        scale = max(w, h) / max_side
        if scale > 1:
            # Clamp to 1 px so an extreme aspect ratio cannot yield a zero-sized image.
            new_size = (max(1, int(w / scale)), max(1, int(h / scale)))
            img = img.resize(new_size, Image.LANCZOS)
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=82)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
def safe_parse_json(content: str) -> dict[str, Any]:
    """Parse the model's JSON output, tolerating a Markdown code fence around it."""
    text = content.strip()
    if text.startswith("```"):
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:].strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
        start = text.find("{")
        end = text.rfind("}")
        if start >= 0 and end > start:
            try:
                parsed = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                pass
    # Callers use .get(), so anything other than a JSON object falls back to the raw text.
    if isinstance(parsed, dict):
        return parsed
    return {"text": content}
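To make the tolerated inputs concrete, here is a standalone sketch of the same parsing idea with a few sample replies. `parse_model_json` is a hypothetical name used only for this illustration, and the non-dict fallback is an assumption of the sketch rather than a statement about the project's helper.

```python
import json
from typing import Any

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the example readable


def parse_model_json(content: str) -> dict[str, Any]:
    """Hypothetical helper: pull a JSON object out of a model reply."""
    text = content.strip()
    if text.startswith(FENCE):
        # Drop the fence characters and an optional "json" language tag.
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:].strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
        start, end = text.find("{"), text.rfind("}")
        if start >= 0 and end > start:
            try:
                # Last resort: parse the outermost {...} span inside the chatter.
                parsed = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                parsed = None
    # Anything that is not a JSON object is returned as raw text.
    return parsed if isinstance(parsed, dict) else {"text": content}


if __name__ == "__main__":
    samples = [
        '{"text": "hi"}',
        FENCE + 'json\n{"text": "hi"}\n' + FENCE,
        'Sure! {"a": 1} Done.',
        "no json here",
    ]
    for sample in samples:
        print(parse_model_json(sample))
```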
async def chat_completions(
*,
base_url: str,
api_key: str,
model: str,
system_prompt: str,
user_text: str,
image_path: Path | None = None,
allow_upload: bool = True,
timeout: float = 60.0,
json_mode: bool = True,
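) -> str:
"""POST to `{base_url}/chat/completions` and return the `message.content` string; `base_url` is expected to already include any version prefix such as `/v1`."""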
user_content: list[dict[str, Any]] = [{"type": "text", "text": user_text}]
if image_path is not None and allow_upload:
data_url = image_to_data_url(image_path)
user_content.append({"type": "image_url", "image_url": {"url": data_url}})
payload: dict[str, Any] = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
],
"temperature": 0.2,
}
if json_mode:
payload["response_format"] = {"type": "json_object"}
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
url = f"{base_url.rstrip('/')}/chat/completions"
async with httpx.AsyncClient(timeout=timeout) as client:
try:
resp = await client.post(url, json=payload, headers=headers)
except httpx.HTTPError as exc:
logger.warning("视觉 API 请求失败,尝试移除 response_format: %s", exc)
payload.pop("response_format", None)
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code == 400 and "response_format" in resp.text:
payload.pop("response_format", None)
resp = await client.post(url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
try:
return data["choices"][0]["message"]["content"]
except (KeyError, IndexError) as exc:
raise RuntimeError(f"视觉 API 返回结构异常: {data}") from exc
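The JSON-mode fallback in `chat_completions` can be sketched without any network access. The `Send` callable and `fake_server` below are hypothetical stand-ins for the HTTP client and a backend that rejects `response_format`; only the retry decision is illustrated.

```python
from typing import Any, Callable

# Hypothetical transport: takes a payload, returns (status_code, body_text).
Send = Callable[[dict[str, Any]], tuple[int, str]]


def post_with_json_mode_fallback(send: Send, payload: dict[str, Any]) -> tuple[int, str]:
    """Try JSON mode first; retry once without it if the server rejects the field."""
    status, body = send(payload)
    if status == 400 and "response_format" in body:
        retry_payload = {k: v for k, v in payload.items() if k != "response_format"}
        status, body = send(retry_payload)
    return status, body


def fake_server(payload: dict[str, Any]) -> tuple[int, str]:
    """Stand-in for a backend that does not understand response_format."""
    if "response_format" in payload:
        return 400, "unknown field: response_format"
    return 200, '{"text": "ok"}'


if __name__ == "__main__":
    request = {"model": "demo-model", "response_format": {"type": "json_object"}}
    print(post_with_json_mode_fallback(fake_server, request))
```

Building a new payload for the retry, instead of mutating the caller's dict, keeps the original request available for logging.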
@@ -0,0 +1,107 @@
"""OpenAI-compatible VLM implementation covering Ollama / GLM / MiniMax / Moonshot / OpenRouter / OpenAI."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from app.core.logger import get_logger
from .base import VLMProvider, VLMResult
from .openai_vision_client import chat_completions, safe_parse_json
logger = get_logger(__name__)
_SYSTEM_PROMPT = """你是一个截图整理助手。用户会给你一张截图(可能附带 OCR 文本)。
请用简洁的中文,按以下 JSON 结构返回分析结果,**只输出 JSON,不要解释**:
{
"title": "一句话标题,不超过 24 个字",
"summary": "2-3 句话总结这张截图的内容、要点或笑点",
"category": "从给定分类列表中选一个最贴切的名字;如果都不符合就填'其他'",
"tags": ["3-6 个能帮助检索的细分标签"],
"todos": [
{"title": "如果截图里出现'待看/待读/待办/想试试/记一下'的内容,抽成一条 todo", "kind": "待读|待看|待办|学习", "note": "可空"}
],
"suggestion": "可选:给用户的进一步行动建议或同类资源提示,可空"
}
要求:
- 标题要可读,不要复述"这是一张..."
- summary 不要超过 80 字。
- todos 没有可识别项时给空数组。"""
class OpenAICompatVLM(VLMProvider):
"""Calls /v1/chat/completions for every backend; the image is passed as a base64 data URL."""
name = "openai_compat"
def __init__(
self,
base_url: str,
api_key: str,
model: str,
timeout: float = 60.0,
) -> None:
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self.model = model
self.timeout = timeout
async def analyze(
self,
image_path: Path,
ocr_text: str,
categories: list[str],
allow_upload: bool,
) -> VLMResult:
"""Call the model and parse its structured JSON reply."""
prompt = (
f"可选分类:{', '.join(categories)}\n\n"
f"OCR 文本(可能不完整或为空):\n{ocr_text or '(无)'}"
)
content = await chat_completions(
base_url=self.base_url,
api_key=self.api_key,
model=self.model,
system_prompt=_SYSTEM_PROMPT,
user_text=prompt,
image_path=image_path if allow_upload else None,
allow_upload=allow_upload,
timeout=self.timeout,
json_mode=True,
)
parsed = safe_parse_json(content)
return _to_vlm_result(parsed)
def _to_vlm_result(data: dict[str, Any]) -> VLMResult:
    """JSON -> dataclass, tolerating missing or malformed fields."""
    todos_raw = data.get("todos") or []
    todos: list[dict[str, str]] = []
    if isinstance(todos_raw, list):
        for item in todos_raw:
            if isinstance(item, dict) and item.get("title"):
                todos.append(
                    {
                        "title": str(item.get("title", ""))[:512],
                        # `or ""` first: str(None) would otherwise give the literal "None".
                        "kind": str(item.get("kind") or "") or "待办",
                        "note": str(item.get("note", "") or ""),
                    }
                )
            elif isinstance(item, str):
                todos.append({"title": item, "kind": "待办", "note": ""})
    tags_raw = data.get("tags") or []
    if not isinstance(tags_raw, list):
        tags_raw = []
    return VLMResult(
        title=str(data.get("title", "") or "")[:128],
        summary=str(data.get("summary", "") or ""),
        category=str(data.get("category") or "") or None,
        tags=[str(t) for t in tags_raw if t][:8],
        todos=todos,
        suggestion=str(data.get("suggestion", "") or ""),
        raw=data,
    )
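As an illustration of the defensive coercion above, the sketch below feeds a deliberately malformed reply through a reduced version of the same idea. `AnalysisResult` and `coerce` are hypothetical names, and only a subset of the fields is modelled.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class AnalysisResult:
    """Hypothetical, reduced stand-in for the provider result type."""

    title: str
    category: str | None
    tags: list[str] = field(default_factory=list)
    todos: list[dict[str, str]] = field(default_factory=list)


def coerce(data: dict[str, Any]) -> AnalysisResult:
    """Turn loosely structured model output into a predictable shape."""
    todos: list[dict[str, str]] = []
    raw_todos = data.get("todos")
    for item in raw_todos if isinstance(raw_todos, list) else []:
        if isinstance(item, dict) and item.get("title"):
            todos.append({"title": str(item["title"])[:512], "kind": str(item.get("kind") or "待办")})
        elif isinstance(item, str) and item:
            todos.append({"title": item[:512], "kind": "待办"})
        # Any other item type (numbers, nested lists, ...) is silently dropped.
    raw_tags = data.get("tags")
    tags = [str(t) for t in raw_tags if t][:8] if isinstance(raw_tags, list) else []
    return AnalysisResult(
        title=str(data.get("title") or "")[:128],
        category=str(data.get("category") or "") or None,
        tags=tags,
        todos=todos,
    )


if __name__ == "__main__":
    messy = {"title": None, "tags": "oops", "todos": ["read this", {"title": "x", "kind": None}, 42]}
    print(coerce(messy))
```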
@@ -0,0 +1,76 @@
"""Shared schemas: status, statistics, settings."""
from __future__ import annotations
from typing import Any, Optional
from pydantic import BaseModel
class StatsResp(BaseModel):
total: int
pending_ocr: int
pending_ai: int
failed: int
by_category: list[dict[str, Any]]
by_date: list[dict[str, Any]]
class WatchFolderIn(BaseModel):
path: str
enabled: bool = True
recursive: bool = True
is_sensitive: bool = False
class WatchFolderOut(WatchFolderIn):
id: int
class Config:
from_attributes = True
class CategoryIn(BaseModel):
name: str
color: Optional[str] = None
prompt_hint: Optional[str] = None
class ProviderConfig(BaseModel):
"""Configuration for an OCR/VLM provider.
type: openai_compat / tesseract / anthropic / none
base_url, api_key, model and the rest are all optional; which ones apply depends on the provider type.
"""
type: str
base_url: Optional[str] = None
api_key: Optional[str] = None
model: Optional[str] = None
extra: dict[str, Any] = {}
class ProviderConfigOut(ProviderConfig):
"""Read model: api_key is always empty; only api_key_mask exposes a hint."""
api_key_mask: Optional[str] = None
class RecognitionModeIn(BaseModel):
"""Text recognition strategy: traditional OCR / vision AI / hybrid."""
mode: str # ocr | vision | hybrid
class ProviderTestResult(BaseModel):
"""Result of a provider connectivity test."""
ok: bool
message: str
detail: Optional[str] = None
latency_ms: Optional[int] = None
class TodoUpdate(BaseModel):
status: Optional[str] = None
title: Optional[str] = None
note: Optional[str] = None
@@ -0,0 +1,79 @@
"""Request/response models for the analysis job queue."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
class JobOut(BaseModel):
"""Details of a single job, including a summary of the associated screenshot."""
id: int
screenshot_id: int
kind: str
status: str
retries: int
last_error: Optional[str] = None
created_at: datetime
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
thumb_url: Optional[str] = None
path: Optional[str] = None
ai_title: Optional[str] = None
ai_status: Optional[str] = None
ocr_status: Optional[str] = None
class JobListResp(BaseModel):
items: list[JobOut]
total: int
page: int
size: int
class JobRetryIn(BaseModel):
"""Optional: retry only the given job ids; when omitted, retry every failed job."""
job_ids: Optional[list[int]] = None
@@ -0,0 +1,90 @@
"""Request/response models for screenshots."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class TagOut(BaseModel):
id: int
name: str
color: Optional[str] = None
class Config:
from_attributes = True
class CategoryOut(BaseModel):
id: int
name: str
color: Optional[str] = None
class Config:
from_attributes = True
class ScreenshotBrief(BaseModel):
"""For the card list: kept as small as possible."""
id: int
path: str
width: int
height: int
captured_at: datetime
thumb_url: Optional[str] = None
ai_title: Optional[str] = None
ai_status: str
ocr_status: str
is_favorite: bool = False
category: Optional[CategoryOut] = None
tags: list[TagOut] = []
class TodoBrief(BaseModel):
id: int
title: str
note: Optional[str] = None
kind: Optional[str] = None
status: str
created_at: datetime
completed_at: Optional[datetime] = None
screenshot_id: int
class Config:
from_attributes = True
class TodoListResp(BaseModel):
items: list[TodoBrief]
total: int
page: int
size: int
class ScreenshotDetail(ScreenshotBrief):
"""For the detail page: includes OCR and AI text."""
file_url: str
size: int
ocr_text: Optional[str] = None
ai_summary: Optional[str] = None
ai_suggestion: Optional[str] = None
todos: list[TodoBrief] = []
class ScreenshotListResp(BaseModel):
items: list[ScreenshotBrief]
total: int
page: int
size: int
class ScreenshotUpdate(BaseModel):
"""Fields the frontend is allowed to update."""
category_id: Optional[int] = None
is_favorite: Optional[bool] = None
is_hidden: Optional[bool] = None
tags: Optional[list[str]] = Field(default=None, description="标签名列表,自动新建")
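@@ -0,0 +1,484 @@
"""Analysis logic for a single screenshot: OCR -> VLM -> write back to the database.
Design notes:
- Never hold a SQLite write transaction during a long network call, to avoid `database is locked`.
- The flow is split into "short transaction (load config / mark status)" -> "no transaction (OCR/VLM network calls)"
-> "short transaction (write results back)".
"""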
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from sqlalchemy import select
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_file, path_from_storage, path_is_under
from app.models.category import Category, DEFAULT_CATEGORIES
from app.models.meta import ScreenshotMeta
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.tag import Tag
from app.models.todo import Todo, TodoStatus
from app.models.watch_folder import WatchFolder
from app.providers import (
RECOGNITION_MODES,
build_ocr_provider,
build_vision_ocr,
build_vlm_provider,
)
from app.providers.base import VLMResult
from app.schemas.common import ProviderConfig
from app.services.exif_utils import is_exif_location_tag
from app.services.settings_store import get_provider_config, get_setting
logger = get_logger(__name__)
@dataclass
class _PreparedContext:
"""Plain data exported from a short transaction, independent of any ORM session."""
path: Path
ocr_cfg: Optional[ProviderConfig]
vlm_cfg: Optional[ProviderConfig]
recognition_mode: str
category_names: list[str]
allow_upload: bool
exists: bool
async def analyze_screenshot_by_id(screenshot_id: int) -> None:
    """Public entry point: analyze one screenshot by id.

    Scheduled by the worker. The function manages its own series of short transactions.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return  # the screenshot has already been deleted
    if not ctx.exists:
        _persist_missing(screenshot_id)
        return
    ocr_provider = _safe_build(
        lambda c: build_ocr_provider(c, allow_upload=ctx.allow_upload),
        ctx.ocr_cfg if _use_traditional_ocr(ctx) else None,
        "OCR",
    )
    vlm_provider = _safe_build(build_vlm_provider, ctx.vlm_cfg, "VLM")

    # ---- Text recognition stage (runs outside any transaction) ----
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)

    # ---- VLM stage (outside any transaction) ----
    vlm_result: Optional[VLMResult] = None
    ai_status = ProcessStatus.SKIPPED.value
    vlm_error: Optional[Exception] = None
    if vlm_provider is not None:
        _mark_status(screenshot_id, ai=ProcessStatus.RUNNING.value)
        try:
            vlm_result = await vlm_provider.analyze(
                image_path=ctx.path,
                ocr_text=ocr_text,
                categories=ctx.category_names,
                allow_upload=ctx.allow_upload,
            )
            ai_status = ProcessStatus.DONE.value
        except Exception as exc:  # noqa: BLE001
            logger.warning("VLM 失败 #%d: %s", screenshot_id, exc)
            ai_status = ProcessStatus.FAILED.value
            vlm_error = exc

    # ---- Write-back stage (short transaction) ----
    _persist_result(
        screenshot_id=screenshot_id,
        ocr_text=ocr_text,
        ocr_status=ocr_status,
        ai_status=ai_status,
        vlm_result=vlm_result,
    )
    if vlm_error is not None:
        raise vlm_error  # let the worker decide whether to retry
async def analyze_ocr_only_by_id(screenshot_id: int) -> None:
    """Re-run only OCR / vision text recognition, leaving the AI analysis results untouched.

    Intended for screenshots with ai_status=done but ocr_status=failed.
    If OCR fails again, an exception is raised and the worker retries up to max_retries.
    """
    ctx = _prepare(screenshot_id)
    if ctx is None:
        return
    if not ctx.exists:
        _persist_missing(screenshot_id)
        raise RuntimeError("截图文件丢失")
    ocr_provider = _safe_build(
        lambda c: build_ocr_provider(c, allow_upload=ctx.allow_upload),
        ctx.ocr_cfg if _use_traditional_ocr(ctx) else None,
        "OCR",
    )
    ocr_text, ocr_status = await _extract_text(screenshot_id, ctx, ocr_provider)
    _persist_ocr_only(screenshot_id, ocr_text, ocr_status)
    if ocr_status == ProcessStatus.FAILED.value:
        raise RuntimeError("OCR 识别失败")
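The "short transaction, then no transaction, then short transaction" shape used by both entry points can be sketched with the standard-library `sqlite3` module. The `shots` table, `set_status` and `slow_analysis` are hypothetical names for this illustration; the project itself goes through its own `session_scope` helper.

```python
import asyncio
import os
import sqlite3
import tempfile


def set_status(db_path: str, shot_id: int, status: str) -> None:
    """Short transaction: open, write one row, commit, close."""
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute("UPDATE shots SET status = ? WHERE id = ?", (status, shot_id))
    finally:
        conn.close()


def read_status(db_path: str, shot_id: int) -> str:
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute("SELECT status FROM shots WHERE id = ?", (shot_id,)).fetchone()
        return row[0]
    finally:
        conn.close()


async def slow_analysis(shot_id: int) -> str:
    """Stand-in for an OCR/VLM network call."""
    await asyncio.sleep(0.01)
    return f"result-{shot_id}"


async def analyze(db_path: str, shot_id: int) -> str:
    set_status(db_path, shot_id, "running")  # short transaction 1
    result = await slow_analysis(shot_id)    # no transaction is open while waiting
    set_status(db_path, shot_id, "done")     # short transaction 2
    return result


def demo() -> tuple[str, str]:
    with tempfile.TemporaryDirectory() as tmp:
        db_path = os.path.join(tmp, "demo.db")
        conn = sqlite3.connect(db_path)
        with conn:
            conn.execute("CREATE TABLE shots (id INTEGER PRIMARY KEY, status TEXT)")
            conn.execute("INSERT INTO shots (id, status) VALUES (1, 'pending')")
        conn.close()
        result = asyncio.run(analyze(db_path, 1))
        return result, read_status(db_path, 1)


if __name__ == "__main__":
    print(demo())
```

Because no connection stays open across the `await`, another writer can take the database lock while the slow call is in flight.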
# ---------------- Short-transaction helpers ---------------- #
def _prepare(screenshot_id: int) -> Optional[_PreparedContext]:
"""Short transaction: load provider config, the category list and the sensitive-folder verdict."""
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return None
image_path = path_from_storage(shot.path)
exists = is_accessible_file(image_path)
ocr_cfg = _load_provider_config(session, KEY_OCR_PROVIDER)
vlm_cfg = _load_provider_config(session, KEY_VLM_PROVIDER)
mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
if mode not in RECOGNITION_MODES:
mode = DEFAULT_RECOGNITION_MODE
categories = _ensure_default_categories(session)
category_names = [c.name for c in categories]
allow_upload = not _is_sensitive(session, image_path)
return _PreparedContext(
path=image_path,
ocr_cfg=ocr_cfg,
vlm_cfg=vlm_cfg,
recognition_mode=mode,
category_names=category_names,
allow_upload=allow_upload,
exists=exists,
)
def _use_traditional_ocr(ctx: _PreparedContext) -> bool:
    """Whether the OCR-section config applies in ocr/hybrid mode (the vision type is excluded and handled separately)."""
    if ctx.recognition_mode not in ("ocr", "hybrid"):
        return False
    if ctx.ocr_cfg is None or ctx.ocr_cfg.type in ("", "none", "disabled", "vision"):
        return False
    return True
async def _extract_text(
    screenshot_id: int,
    ctx: _PreparedContext,
    ocr_provider,
) -> tuple[str, str]:
    """Extract text according to the recognition mode: traditional OCR / vision AI / hybrid."""
    ocr_text = ""
    ocr_status = ProcessStatus.SKIPPED.value
    mode = ctx.recognition_mode

    # 1) Traditional OCR (Tesseract / Paddle / HTTP)
    if ocr_provider is not None:
        _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
        try:
            ocr_text = await ocr_provider.recognize(ctx.path)
            ocr_status = ProcessStatus.DONE.value
        except Exception as exc:  # noqa: BLE001
            logger.warning("OCR 失败 #%d: %s", screenshot_id, exc)
            ocr_status = ProcessStatus.FAILED.value

    # 2) Vision-AI text recognition
    need_vision = mode == "vision" or (mode == "hybrid" and not ocr_text.strip())
    if mode == "ocr" and ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
        # The user picked "vision model" in the OCR section
        need_vision = True
    if need_vision:
        vision_cfg = _pick_vision_config(ctx)
        vision = _safe_build(
            lambda c: build_vision_ocr(c, allow_upload=ctx.allow_upload),
            vision_cfg,
            "VisionOCR",
        )
        if vision is not None:
            _mark_status(screenshot_id, ocr=ProcessStatus.RUNNING.value)
            try:
                ocr_text = await vision.recognize(ctx.path)
                ocr_status = ProcessStatus.DONE.value
            except Exception as exc:  # noqa: BLE001
                logger.warning("视觉识文失败 #%d: %s", screenshot_id, exc)
                if ocr_status != ProcessStatus.DONE.value:
                    ocr_status = ProcessStatus.FAILED.value
    return ocr_text, ocr_status
def _pick_vision_config(ctx: _PreparedContext) -> Optional[ProviderConfig]:
    """Choose the config for vision text recognition: the OCR section's vision entry if present, otherwise the VLM section."""
    if ctx.ocr_cfg and ctx.ocr_cfg.type == "vision":
        return ctx.ocr_cfg
    return ctx.vlm_cfg
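The mode handling spread across `_use_traditional_ocr` and `_extract_text` amounts to two small predicates. The sketch below restates them as pure functions so the combinations can be checked in isolation; the function names and the plain-string arguments are assumptions made for this illustration.

```python
from __future__ import annotations

# OCR-section types that never run as a traditional OCR provider.
NON_TRADITIONAL_TYPES = ("", "none", "disabled", "vision")


def uses_traditional_ocr(mode: str, ocr_type: str | None) -> bool:
    """Traditional OCR runs only in ocr/hybrid mode with a real OCR provider type."""
    return mode in ("ocr", "hybrid") and ocr_type is not None and ocr_type not in NON_TRADITIONAL_TYPES


def needs_vision(mode: str, ocr_type: str | None, ocr_text: str) -> bool:
    """Vision recognition runs in vision mode, as a hybrid fallback, or when picked as the OCR type."""
    if mode == "vision":
        return True
    if mode == "hybrid" and not ocr_text.strip():
        return True
    return mode == "ocr" and ocr_type == "vision"


if __name__ == "__main__":
    for mode, ocr_type, text in [
        ("ocr", "tesseract", "hello"),
        ("ocr", "vision", ""),
        ("hybrid", "tesseract", ""),
        ("hybrid", "tesseract", "hello"),
        ("vision", "tesseract", ""),
    ]:
        print(mode, ocr_type, repr(text), uses_traditional_ocr(mode, ocr_type), needs_vision(mode, ocr_type, text))
```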
def _mark_status(
screenshot_id: int,
ocr: Optional[str] = None,
ai: Optional[str] = None,
) -> None:
"""Short transaction: update the screenshot's OCR/AI status (typically to running) so the frontend can show progress."""
if ocr is None and ai is None:
return
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return
if ocr is not None:
shot.ocr_status = ocr
if ai is not None:
shot.ai_status = ai
def _persist_missing(screenshot_id: int) -> None:
"""Short transaction: mark the file as missing."""
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return
shot.ocr_status = ProcessStatus.FAILED.value
shot.ai_status = ProcessStatus.FAILED.value
meta = _get_or_create_meta(session, screenshot_id)
meta.ai_summary = "(文件丢失)"
def _persist_result(
screenshot_id: int,
ocr_text: str,
ocr_status: str,
ai_status: str,
vlm_result: Optional[VLMResult],
) -> None:
"""Short transaction: write the OCR/VLM results back to the DB, including meta/tags/category/todos."""
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return
shot.ocr_status = ocr_status
shot.ai_status = ai_status
meta = _get_or_create_meta(session, screenshot_id)
meta.ocr_text = ocr_text or None
if vlm_result is not None:
meta.ai_title = vlm_result.title or None
meta.ai_summary = vlm_result.summary or None
meta.ai_suggestion = vlm_result.suggestion or None
meta.ai_raw_json = json.dumps(vlm_result.raw, ensure_ascii=False)
categories = list(session.scalars(select(Category)).all())
category = _resolve_category(session, vlm_result.category, categories)
if category is not None:
shot.category_id = category.id
_sync_tags(session, shot, vlm_result.tags)
_sync_todos(session, shot, vlm_result.todos)
def _persist_ocr_only(screenshot_id: int, ocr_text: str, ocr_status: str) -> None:
"""Short transaction: write back only the OCR text and status, keeping existing AI fields."""
with session_scope() as session:
shot = session.get(Screenshot, screenshot_id)
if shot is None:
return
shot.ocr_status = ocr_status
meta = _get_or_create_meta(session, screenshot_id)
meta.ocr_text = ocr_text or None
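def enqueue_ocr_jobs(*, limit: int = 500) -> int:
"""Create OCR re-run jobs in bulk for screenshots where AI succeeded but OCR failed.
Screenshots that already have a pending/running ocr job are skipped to avoid duplicate enqueueing.
"""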
from app.models.job import Job, JobKind, JobStatus
active_status = (JobStatus.PENDING.value, JobStatus.RUNNING.value)
created = 0
with session_scope() as session:
# screenshot_ids that already have an active OCR job
busy_ids = set(
session.scalars(
select(Job.screenshot_id).where(
Job.kind == JobKind.OCR.value,
Job.status.in_(active_status),
)
).all()
)
shots = session.scalars(
select(Screenshot)
.where(
Screenshot.ocr_status == ProcessStatus.FAILED.value,
Screenshot.ai_status == ProcessStatus.DONE.value,
)
.order_by(Screenshot.id.asc())
.limit(limit)
).all()
for shot in shots:
if shot.id in busy_ids:
continue
session.add(
Job(
screenshot_id=shot.id,
kind=JobKind.OCR.value,
status=JobStatus.PENDING.value,
)
)
busy_ids.add(shot.id)
created += 1
return created
# ---------------- Internal helpers ---------------- #
def _load_provider_config(session, key: str) -> Optional[ProviderConfig]:
raw = get_provider_config(session, key)
if not raw:
return None
try:
return ProviderConfig(**raw)
except Exception as exc: # noqa: BLE001
logger.warning("Provider 配置 %s 解析失败: %s", key, exc)
return None
def _safe_build(builder, cfg: Optional[ProviderConfig], label: str):
if cfg is None:
return None
try:
return builder(cfg)
except Exception as exc: # noqa: BLE001
logger.warning("%s Provider 构造失败: %s", label, exc)
return None
def _is_sensitive(session, image_path: Path) -> bool:
"""Check whether the file lies inside a watch folder marked as sensitive."""
sensitive_dirs = session.scalars(
select(WatchFolder.path).where(WatchFolder.is_sensitive.is_(True))
).all()
child = str(image_path)
for d in sensitive_dirs:
if path_is_under(d, child):
return True
return False
def _get_or_create_meta(session, screenshot_id: int) -> ScreenshotMeta:
meta = session.get(ScreenshotMeta, screenshot_id)
if meta is None:
meta = ScreenshotMeta(screenshot_id=screenshot_id)
session.add(meta)
session.flush()
return meta
def ensure_default_categories() -> None:
"""Public: seed the default categories at startup."""
with session_scope() as session:
_ensure_default_categories(session)
def _ensure_default_categories(session) -> list[Category]:
"""Seed the default categories on first run and return the current list."""
existing = session.scalars(select(Category)).all()
if existing:
return list(existing)
for item in DEFAULT_CATEGORIES:
session.add(Category(**item))
session.flush()
return list(session.scalars(select(Category)).all())
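def _resolve_category(
    session,
    name: str | None,
    categories: list[Category],
) -> Optional[Category]:
    """Map the AI-provided category name to a Category, creating one if nothing matches."""
    normalized = (name or "").strip()
    if not normalized:
        # A blank name is a substring of every category name, so bail out early.
        return None
    # Prefer an exact match so a shorter category cannot shadow it via the substring check.
    for c in categories:
        if c.name == normalized:
            return c
    for c in categories:
        if c.name in normalized or normalized in c.name:
            return c
    new_cat = Category(name=normalized[:64], color=None, prompt_hint=None)
    session.add(new_cat)
    session.flush()
    categories.append(new_cat)
    return new_cat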
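The matching order used when resolving a category can be tried out on plain strings. This is a sketch of one possible policy (exact match, then substring match in either direction, then create); `resolve_category` and the sample names are hypothetical, and the exact-first ordering is an assumption of the sketch.

```python
from __future__ import annotations


def resolve_category(name: str | None, known: list[str]) -> str | None:
    """Return the matching category name, appending a new one when nothing fits."""
    normalized = (name or "").strip()
    if not normalized:
        return None
    if normalized in known:  # 1) exact match wins
        return normalized
    for candidate in known:  # 2) substring match in either direction
        if candidate in normalized or normalized in candidate:
            return candidate
    known.append(normalized[:64])  # 3) otherwise create a new category
    return known[-1]


if __name__ == "__main__":
    categories = ["Code", "Code review", "Other"]
    for raw in ["Code review", "Code snippets", "  ", "Travel"]:
        print(repr(raw), "->", resolve_category(raw, categories))
    print(categories)
```

Without the exact-match pass, "Code review" would resolve to "Code" simply because "Code" comes first in the list.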
def _sync_tags(session, screenshot: Screenshot, tag_names: list[str]) -> None:
"""Sync the many-to-many relation from the AI's tag names; EXIF location tags are kept and never overwritten."""
exif_tags = [t for t in (screenshot.tags or []) if is_exif_location_tag(t.name)]
exif_names = {t.name for t in exif_tags}
seen: set[str] = set(exif_names)
tag_objs: list[Tag] = list(exif_tags)
for raw_name in tag_names:
name = (raw_name or "").strip()[:64]
if not name or name in seen:
continue
seen.add(name)
tag = session.scalar(select(Tag).where(Tag.name == name))
if tag is None:
tag = Tag(name=name)
session.add(tag)
session.flush()
tag_objs.append(tag)
screenshot.tags = tag_objs
def _sync_todos(
session,
screenshot: Screenshot,
todos: list[dict[str, str]],
) -> None:
"""Replace this screenshot's unfinished todos with the AI output; todos the user completed or dropped are kept."""
existing = list(screenshot.todos)
for t in existing:
if t.status in (TodoStatus.DONE.value, TodoStatus.DROPPED.value):
continue
session.delete(t)
session.flush()
for item in todos:
title = (item.get("title") or "").strip()
if not title:
continue
session.add(
Todo(
screenshot_id=screenshot.id,
title=title[:512],
note=(item.get("note") or "")[:2000] or None,
kind=(item.get("kind") or "待办")[:32],
status=TodoStatus.PENDING.value,
)
)
session.flush()
@@ -0,0 +1,107 @@
"""Extract capture time and GPS location tags from image EXIF."""
from __future__ import annotations
from datetime import datetime
from fractions import Fraction
from pathlib import Path
from typing import Optional
from PIL import ExifTags, Image
from app.core.logger import get_logger
logger = get_logger(__name__)
# Prefix for EXIF location tags; they are preserved during re-analysis so the AI cannot overwrite them
EXIF_TAG_PREFIX = "地点:"
def _ratio_to_float(value) -> float:
"""EXIF rational -> float."""
if isinstance(value, tuple) and len(value) == 2:
num, den = value
return float(num) / float(den) if den else 0.0
if isinstance(value, Fraction):
return float(value)
return float(value)
def _dms_to_decimal(dms: tuple, ref: str) -> Optional[float]:
"""Degrees/minutes/seconds -> decimal degrees."""
try:
deg, minutes, seconds = dms
decimal = _ratio_to_float(deg) + _ratio_to_float(minutes) / 60 + _ratio_to_float(seconds) / 3600
if ref in ("S", "W"):
decimal = -decimal
return round(decimal, 6)
except (TypeError, ValueError, ZeroDivisionError):
return None
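The degrees/minutes/seconds conversion is plain arithmetic: decimal = deg + min/60 + sec/3600, negated for the southern and western hemispheres. A worked sketch follows, with a made-up coordinate and a hypothetical function name:

```python
from fractions import Fraction


def dms_to_decimal(dms: tuple, ref: str) -> float:
    """Convert (degrees, minutes, seconds) plus a hemisphere letter to decimal degrees."""
    deg, minutes, seconds = (float(v) for v in dms)
    decimal = deg + minutes / 60 + seconds / 3600
    if ref in ("S", "W"):
        decimal = -decimal
    return round(decimal, 6)


if __name__ == "__main__":
    # 31 deg 13 min 48.12 sec north, given as rationals the way EXIF stores them.
    latitude = dms_to_decimal((Fraction(31), Fraction(13), Fraction(4812, 100)), "N")
    # 121 deg 28 min 0 sec west.
    longitude = dms_to_decimal((121, 28, 0), "W")
    print(latitude, longitude)
```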
def extract_image_metadata(path: Path) -> tuple[Optional[datetime], list[str]]:
    """Read EXIF and return (capture time, list of location tags)."""
    captured: Optional[datetime] = None
    location_tags: list[str] = []
    try:
        with Image.open(path) as img:
            exif = img.getexif()
            if not exif:
                return None, []
            has_ifd = hasattr(exif, "get_ifd")
            # Capture time: prefer DateTimeOriginal. It and DateTimeDigitized live in
            # the Exif sub-IFD; only DateTime (306) sits in the base IFD.
            exif_ifd = exif.get_ifd(ExifTags.IFD.Exif) if has_ifd else {}
            for key in (36867, 36868, 306):  # DateTimeOriginal / DateTimeDigitized / DateTime
                raw = exif_ifd.get(key) or exif.get(key)
                if raw:
                    captured = _parse_exif_datetime(str(raw))
                    if captured:
                        break
            # GPS -> location tag
            gps_ifd = exif.get_ifd(ExifTags.IFD.GPSInfo) if has_ifd else None
            if gps_ifd:
                lat = _dms_to_decimal(
                    gps_ifd.get(2),
                    gps_ifd.get(1, "N"),
                )
                lon = _dms_to_decimal(
                    gps_ifd.get(4),
                    gps_ifd.get(3, "E"),
                )
                if lat is not None and lon is not None:
                    location_tags.append(f"{EXIF_TAG_PREFIX}{lat},{lon}")
            # Some devices write a readable place name (ImageDescription / XPComment)
            for key, val in exif.items():
                tag_name = ExifTags.TAGS.get(key, "")
                if tag_name in ("ImageDescription", "XPComment") and val:
                    text = str(val).strip()[:64]
                    if text and _looks_like_place(text):
                        location_tags.append(f"{EXIF_TAG_PREFIX}{text}")
    except Exception as exc:  # noqa: BLE001
        logger.debug("读取 EXIF 失败 %s: %s", path.name, exc)
    return captured, location_tags
def _parse_exif_datetime(raw: str) -> Optional[datetime]:
"""Parse an EXIF datetime string."""
for fmt in ("%Y:%m:%d %H:%M:%S", "%Y-%m-%d %H:%M:%S"):
try:
return datetime.strptime(raw.strip(), fmt)
except ValueError:
continue
return None
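def _looks_like_place(text: str) -> bool:
    """Rough check for whether a string looks like a place name (contains Chinese or a common address keyword)."""
    keywords = ("", "", "", "", "", "", "", "", "", "GPS")
    # Skip empty keywords: "" is a substring of every string and would make this always True.
    return any(k and k in text for k in keywords) or any("\u4e00" <= c <= "\u9fff" for c in text)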
def is_exif_location_tag(name: str) -> bool:
"""Whether this is a location tag written automatically from EXIF."""
return name.startswith(EXIF_TAG_PREFIX)
@@ -0,0 +1,147 @@
"""Ingest screenshot files from disk and queue them for analysis."""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Iterable, Optional
from PIL import Image
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.path_utils import (
is_accessible_dir,
is_accessible_file,
path_from_storage,
path_to_storage,
)
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.models.tag import Tag
from app.services.exif_utils import extract_image_metadata
from app.services.thumbnail import file_hash, generate_thumbnail, is_supported
logger = get_logger(__name__)
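def ingest_path(session: Session, path: Path) -> Optional[Screenshot]:
"""Ingest a single file. Returns the Screenshot (the existing row when the content is a duplicate), or None when the file is inaccessible, unsupported or unreadable."""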
if not is_accessible_file(path) or not path.is_file():
return None
if not is_supported(path):
return None
stored_path = path_to_storage(path)
try:
digest = file_hash(path)
except OSError as exc:
logger.warning("无法读取文件 %s: %s", path, exc)
return None
existing = session.scalar(select(Screenshot).where(Screenshot.file_hash == digest))
if existing:
# Same content renamed/moved: update the path
if existing.path != stored_path:
existing.path = stored_path
session.flush()
return existing
try:
with Image.open(path) as img:
width, height = img.size
except Exception as exc: # noqa: BLE001
logger.warning("无法读取图片尺寸 %s: %s", path, exc)
width, height = 0, 0
stat = path.stat()
captured_at = datetime.fromtimestamp(stat.st_mtime)
exif_time, location_tags = extract_image_metadata(path)
if exif_time is not None:
captured_at = exif_time
try:
thumb = generate_thumbnail(path)
thumb_path = thumb.as_posix()
except Exception as exc: # noqa: BLE001
logger.warning("生成缩略图失败 %s: %s", path, exc)
thumb_path = None
shot = Screenshot(
path=stored_path,
file_hash=digest,
width=width,
height=height,
size=stat.st_size,
captured_at=captured_at,
thumb_path=thumb_path,
ocr_status=ProcessStatus.PENDING.value,
ai_status=ProcessStatus.PENDING.value,
)
session.add(shot)
session.flush()
if location_tags:
_attach_location_tags(session, shot, location_tags)
job = Job(screenshot_id=shot.id, kind=JobKind.FULL.value, status=JobStatus.PENDING.value)
session.add(job)
logger.info("入库 #%d %s", shot.id, path.name)
return shot
def _attach_location_tags(session: Session, shot: Screenshot, tag_names: list[str]) -> None:
    """Attach EXIF location tags at ingest time."""
    tag_objs: list[Tag] = []
    seen: set[str] = set()
    for raw in tag_names:
        name = (raw or "").strip()[:64]
        # Skip repeats so the same Tag is not linked to the screenshot twice.
        if not name or name in seen:
            continue
        seen.add(name)
        tag = session.scalar(select(Tag).where(Tag.name == name))
        if tag is None:
            tag = Tag(name=name)
            session.add(tag)
            session.flush()
        tag_objs.append(tag)
    shot.tags = tag_objs
def ingest_directory(
session: Session,
root: Path | str,
recursive: bool = True,
) -> tuple[int, int]:
"""Walk a directory and ingest it. Returns (added, skipped). UNC network paths are supported."""
root_p = path_from_storage(str(root)) if isinstance(root, str) else root
if not is_accessible_dir(root_p):
return 0, 0
iterator: Iterable[Path]
if recursive:
iterator = (p for p in root_p.rglob("*") if p.is_file())
else:
iterator = (p for p in root_p.iterdir() if p.is_file())
added, skipped = 0, 0
for path in iterator:
if not is_supported(path):
continue
stored = path_to_storage(path)
before = session.scalar(
select(Screenshot.id).where(Screenshot.path == stored)
)
result = ingest_path(session, path)
if result is None:
skipped += 1
continue
if before is None:
added += 1
else:
skipped += 1
# Commit in batches to avoid one huge transaction
if (added + skipped) % 50 == 0:
session.commit()
session.commit()
return added, skipped
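Content-hash deduplication, as used during ingest, can be sketched with `hashlib` and an in-memory index instead of a database. The `ingest` function, its return labels and the `index` dict are hypothetical; the point is only that identical bytes map to one record whose path gets updated.

```python
import hashlib
import tempfile
from pathlib import Path


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large images are never read into memory at once."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def ingest(index: dict[str, str], path: Path) -> str:
    """Return 'added', 'moved' or 'unchanged'; index maps content hash -> stored path."""
    digest = file_sha256(path)
    stored = index.get(digest)
    if stored is None:
        index[digest] = str(path)
        return "added"
    if stored != str(path):
        index[digest] = str(path)  # same content under a new name: update the path only
        return "moved"
    return "unchanged"


def demo() -> list[str]:
    with tempfile.TemporaryDirectory() as tmp:
        first = Path(tmp) / "a.png"
        second = Path(tmp) / "b.png"
        first.write_bytes(b"same bytes")
        second.write_bytes(b"same bytes")
        index: dict[str, str] = {}
        return [ingest(index, first), ingest(index, first), ingest(index, second)]


if __name__ == "__main__":
    print(demo())
```

Distinguishing "moved" from "added" matters when counting new files, since a path that was never seen can still point at known content.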
@@ -0,0 +1,207 @@
"""Provider connectivity tests: OCR / vision AI."""
from __future__ import annotations
import asyncio
import base64
import time
from typing import Any
import httpx
from app.providers import build_ocr_provider
from app.schemas.common import ProviderConfig
class ProviderTestError(Exception):
"""Test failure carrying a user-readable message."""
async def test_provider_config(key: str, cfg: ProviderConfig) -> dict[str, Any]:
"""Test OCR or VLM provider connectivity; returns {ok, message, detail, latency_ms}."""
started = time.perf_counter()
try:
if cfg.type in ("", "none", "disabled"):
raise ProviderTestError("当前类型为「不使用」,无需测试")
if key == KEY_OCR:
message, detail = await _test_ocr(cfg)
elif key == KEY_VLM:
message, detail = await _test_vlm(cfg)
else:
raise ProviderTestError(f"未知配置键: {key}")
latency = int((time.perf_counter() - started) * 1000)
return {"ok": True, "message": message, "detail": detail, "latency_ms": latency}
except ProviderTestError as exc:
latency = int((time.perf_counter() - started) * 1000)
return {"ok": False, "message": str(exc), "detail": None, "latency_ms": latency}
except Exception as exc: # noqa: BLE001
latency = int((time.perf_counter() - started) * 1000)
return {
"ok": False,
"message": f"测试失败: {exc}",
"detail": repr(exc),
"latency_ms": latency,
}
KEY_OCR = "ocr_provider"
KEY_VLM = "vlm_provider"
# 1x1 white image used for the HTTP OCR / vision tests
_TINY_PNG_B64 = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
)
async def _test_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
if cfg.type == "tesseract":
return await _test_tesseract(cfg)
if cfg.type == "paddleocr":
return await _test_paddle(cfg)
if cfg.type == "http":
return await _test_http_ocr(cfg)
if cfg.type == "vision":
return await _test_openai_compat(cfg, label="视觉 OCR")
raise ProviderTestError(f"不支持的 OCR 类型: {cfg.type}")
async def _test_vlm(cfg: ProviderConfig) -> tuple[str, str | None]:
if cfg.type in ("openai_compat", "openai", "ollama", "glm", "minimax", "moonshot", "vision"):
return await _test_openai_compat(cfg, label="视觉 AI")
raise ProviderTestError(f"不支持的 VLM 类型: {cfg.type}")
async def _test_tesseract(cfg: ProviderConfig) -> tuple[str, str | None]:
provider = build_ocr_provider(cfg, allow_upload=True)
if provider is None:
raise ProviderTestError("无法构造 Tesseract Provider")
def _check() -> str:
import pytesseract
if cfg.extra.get("cmd"):
pytesseract.pytesseract.tesseract_cmd = cfg.extra["cmd"]
version = pytesseract.get_tesseract_version()
return str(version)
version = await asyncio.to_thread(_check)
return f"Tesseract 可用,版本 {version}", f"lang={cfg.extra.get('lang', 'chi_sim+eng')}"
async def _test_paddle(cfg: ProviderConfig) -> tuple[str, str | None]:
def _check() -> str:
try:
import paddleocr # noqa: F401
except ImportError as exc:
raise ProviderTestError(
"未安装 PaddleOCR,请执行: pip install paddleocr paddlepaddle"
) from exc
return "PaddleOCR 模块已安装"
detail = await asyncio.to_thread(_check)
provider = build_ocr_provider(cfg, allow_upload=True)
if provider is None:
raise ProviderTestError("无法构造 PaddleOCR Provider")
return "PaddleOCR 可用", detail
async def _test_http_ocr(cfg: ProviderConfig) -> tuple[str, str | None]:
if not cfg.base_url:
raise ProviderTestError("请填写 OCR API URL")
provider = build_ocr_provider(cfg, allow_upload=True)
if provider is None:
raise ProviderTestError("无法构造 HTTP OCR Provider")
# Write the tiny PNG to a temporary file, then call the provider
from pathlib import Path
import tempfile
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
tmp.write(base64.b64decode(_TINY_PNG_B64))
tmp_path = Path(tmp.name)
try:
text = await provider.recognize(tmp_path)
finally:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
preview = (text or "").strip()[:80] or "(空响应,但接口可达)"
return "HTTP OCR 接口可达", f"响应预览: {preview}"
async def _test_openai_compat(cfg: ProviderConfig, *, label: str) -> tuple[str, str | None]:
base_url = (cfg.base_url or "http://localhost:11434/v1").rstrip("/")
api_key = cfg.api_key or ""
model = cfg.model or "gpt-4o-mini"
timeout = float(cfg.extra.get("timeout", 30))
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# 1) Try /models (Ollama and OpenAI-compatible servers)
models_url = f"{base_url}/models"
async with httpx.AsyncClient(timeout=timeout) as client:
try:
resp = await client.get(models_url, headers=headers)
if resp.status_code == 200:
data = resp.json()
ids = _extract_model_ids(data)
if model and ids and model not in ids:
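return (
f"{label} service is reachable",
f"Connected to /models, but model '{model}' was not found. Available: {', '.join(ids[:8])}",
)
return f"{label} service is reachable", f"Connected to /models; target model: {model}"
except (httpx.HTTPError, ValueError):
# Unreachable endpoint or non-JSON /models body: fall through to the chat probe
pass
# 2) Minimal chat-completion liveness probe
chat_url = f"{base_url}/chat/completions"
payload = {
"model": model,
"messages": [{"role": "user", "content": "Reply with OK only"}],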
"max_tokens": 16,
"temperature": 0,
}
try:
resp = await client.post(chat_url, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"]
return f"{label} chat succeeded", f"Model {model} replied: {str(content).strip()[:60]}"
except httpx.HTTPStatusError as exc:
body = exc.response.text[:200]
raise ProviderTestError(
f"API returned {exc.response.status_code}: {body}"
) from exc
except httpx.HTTPError as exc:
raise ProviderTestError(f"Cannot connect to {base_url}: {exc}") from exc
def _extract_model_ids(data: Any) -> list[str]:
"""Extract the list of model ids from a /models response."""
if not isinstance(data, dict):
return []
items = data.get("data") or data.get("models") or []
ids: list[str] = []
if isinstance(items, list):
for item in items:
if isinstance(item, dict):
mid = item.get("id") or item.get("name") or item.get("model")
if mid:
ids.append(str(mid))
elif isinstance(item, str):
ids.append(item)
return ids
def merge_provider_api_key(cfg: ProviderConfig, existing: dict | None) -> ProviderConfig:
"""When testing with an empty api_key, fall back to the saved key."""
payload = cfg.model_dump()
if (not payload.get("api_key")) and isinstance(existing, dict):
payload["api_key"] = existing.get("api_key", "")
return ProviderConfig(**payload)
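The `/models` probe has to tolerate more than one response shape. The following is a minimal sketch of that normalization, assuming an OpenAI-style body keyed by `data` with `id` fields and an alternative body keyed by `models` with `name` fields. `extract_model_ids` is a standalone copy written for illustration, and both sample payloads are made up; real servers may return additional fields.

```python
from typing import Any


def extract_model_ids(data: Any) -> list[str]:
    """Standalone copy of the id-extraction logic, for illustration only."""
    if not isinstance(data, dict):
        return []
    # Accept either a "data" list or a "models" list.
    items = data.get("data") or data.get("models") or []
    ids: list[str] = []
    if isinstance(items, list):
        for item in items:
            if isinstance(item, dict):
                mid = item.get("id") or item.get("name") or item.get("model")
                if mid:
                    ids.append(str(mid))
            elif isinstance(item, str):
                ids.append(item)
    return ids


# Hypothetical sample payloads (assumptions, not captured server output).
openai_style = {"object": "list", "data": [{"id": "gpt-4o-mini"}, {"id": "gpt-4o"}]}
models_style = {"models": [{"name": "llava:7b"}, "plain-string-model"]}

print(extract_model_ids(openai_style))
print(extract_model_ids(models_style))
print(extract_model_ids(["not", "a", "dict"]))
```

Anything that is not a dict yields an empty list, which is why the caller only reports a missing model when the id list is non-empty.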
+67
@@ -0,0 +1,67 @@
"""Screenshot list search: FTS plus fuzzy substring matching (works for Chinese tags and titles)."""
from __future__ import annotations
from sqlalchemy import or_, select, text
from sqlalchemy.orm import Session
from app.models.meta import ScreenshotMeta
from app.models.screenshot import Screenshot
from app.models.tag import Tag
def fts_query_string(raw: str) -> str:
"""Turn user input into an FTS5 query string (prefix matching for both Chinese and English)."""
parts = [p for p in raw.replace("\n", " ").split() if p]
if not parts:
return raw
cleaned: list[str] = []
for p in parts:
p = p.replace('"', "").strip()
if not p:
continue
cleaned.append(f'"{p}"*')
return " ".join(cleaned)
def collect_search_ids(session: Session, q: str, *, limit: int = 5000) -> set[int]:
"""Combine FTS5 and LIKE substring search; return the set of matching screenshot ids."""
q = q.strip()
if not q:
return set()
ids: set[int] = set()
like = f"%{q}%"
# 1) FTS5 full-text index
try:
fts_sql = text(
"SELECT rowid FROM screenshots_fts WHERE screenshots_fts MATCH :q LIMIT :lim"
)
rows = session.execute(fts_sql, {"q": fts_query_string(q), "lim": limit}).fetchall()
ids.update(int(row[0]) for row in rows)
except Exception:
pass
# 2) Fuzzy substring match on OCR/AI text (so that "三花" matches "三花猫")
meta_ids = session.scalars(
select(ScreenshotMeta.screenshot_id).where(
or_(
ScreenshotMeta.ocr_text.ilike(like),
ScreenshotMeta.ai_title.ilike(like),
ScreenshotMeta.ai_summary.ilike(like),
ScreenshotMeta.ai_suggestion.ilike(like),
)
).limit(limit)
).all()
ids.update(int(i) for i in meta_ids)
# 3) Substring match on tag names
tag_ids = session.scalars(
select(Screenshot.id)
.join(Screenshot.tags)
.where(Tag.name.ilike(like))
.limit(limit)
).all()
ids.update(int(i) for i in tag_ids)
return ids
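To make the FTS5 half of the search concrete, here is a small sketch of the query-string builder on its own. It is a standalone copy for illustration, and the sample inputs are invented. Each whitespace-separated term is stripped of double quotes, wrapped in quotes, and given a trailing `*` so that FTS5 treats it as a prefix query.

```python
def fts_query_string(raw: str) -> str:
    """Quote each whitespace-separated term and append * for prefix matching."""
    parts = [p for p in raw.replace("\n", " ").split() if p]
    if not parts:
        return raw
    cleaned: list[str] = []
    for p in parts:
        # Drop embedded double quotes so the term cannot break out of its phrase.
        p = p.replace('"', "").strip()
        if p:
            cleaned.append(f'"{p}"*')
    return " ".join(cleaned)


print(fts_query_string("三花 cat"))     # "三花"* "cat"*
print(fts_query_string('say "hello"'))  # "say"* "hello"*
```

Prefix queries only match from the start of an indexed token, so depending on the tokenizer in use they may miss text in the middle of a longer run of characters. That is presumably why the function above also falls back to `LIKE` substring matching.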
+50
@@ -0,0 +1,50 @@
"""Read and write key-value settings."""
from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy.orm import Session
from app.models.setting import Setting
def get_setting(session: Session, key: str, default: Any = None) -> Any:
"""Read a setting and parse it as JSON."""
row = session.get(Setting, key)
if row is None:
return default
try:
return json.loads(row.value_json)
except json.JSONDecodeError:
return default
def set_setting(session: Session, key: str, value: Any) -> None:
"""Serialize to JSON and persist (upsert)."""
row = session.get(Setting, key)
payload = json.dumps(value, ensure_ascii=False)
if row is None:
session.add(Setting(key=key, value_json=payload))
else:
row.value_json = payload
session.flush()
def all_settings(session: Session) -> dict[str, Any]:
"""Return all settings, for frontend debugging or export."""
items: dict[str, Any] = {}
for row in session.query(Setting).all():
try:
items[row.key] = json.loads(row.value_json)
except json.JSONDecodeError:
items[row.key] = row.value_json
return items
def get_provider_config(session: Session, key: str) -> Optional[dict[str, Any]]:
"""Convenience reader for an OCR/VLM provider config dict."""
value = get_setting(session, key, None)
if isinstance(value, dict):
return value
return None
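The JSON key-value pattern used by these helpers can be tried without SQLAlchemy. The sketch below uses the standard-library `sqlite3` module and an in-memory database. The table layout (`settings` with `key` and `value_json` columns) is an assumption modelled on the names in this module, and the `ON CONFLICT` upsert replaces the get-then-add logic above, so it needs SQLite 3.24 or newer.

```python
import json
import sqlite3
from typing import Any

conn = sqlite3.connect(":memory:")
# Assumed schema for illustration; the real table is defined by the Setting model.
conn.execute("CREATE TABLE settings (key TEXT PRIMARY KEY, value_json TEXT NOT NULL)")


def set_setting(key: str, value: Any) -> None:
    """Serialize the value to JSON and upsert it under the given key."""
    payload = json.dumps(value, ensure_ascii=False)
    conn.execute(
        "INSERT INTO settings (key, value_json) VALUES (?, ?) "
        "ON CONFLICT(key) DO UPDATE SET value_json = excluded.value_json",
        (key, payload),
    )


def get_setting(key: str, default: Any = None) -> Any:
    """Return the parsed JSON value, or the default if missing or unparsable."""
    row = conn.execute(
        "SELECT value_json FROM settings WHERE key = ?", (key,)
    ).fetchone()
    if row is None:
        return default
    try:
        return json.loads(row[0])
    except json.JSONDecodeError:
        return default


set_setting("ocr_provider", {"type": "tesseract", "extra": {"lang": "chi_sim+eng"}})
set_setting("ocr_provider", {"type": "tesseract", "extra": {"lang": "eng"}})
print(get_setting("ocr_provider"))
print(get_setting("missing", default={}))
```

Writing the same key twice leaves a single row holding the latest value, which is the behaviour the upsert branch in `set_setting` is meant to provide.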
+48
@@ -0,0 +1,48 @@
"""Generate and cache thumbnails."""
from __future__ import annotations
import hashlib
from pathlib import Path
from PIL import Image
from app.core.config import settings
from app.core.path_utils import path_to_storage
SUPPORTED_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tif", ".tiff"}
def is_supported(path: Path) -> bool:
"""Whether the path has an image format we support."""
return path.suffix.lower() in SUPPORTED_EXTS
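def generate_thumbnail(image_path: Path, max_side: int | None = None) -> Path:
"""Generate a WebP thumbnail in the cache directory and return its path."""
max_side = max_side or settings.thumb_size
# Cache key = hash of file path + mtime + size limit; a changed source file yields a new thumbnail automatically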
stat = image_path.stat()
key = hashlib.md5(
f"{path_to_storage(image_path)}|{stat.st_mtime_ns}|{max_side}".encode("utf-8")
).hexdigest()
out = settings.thumb_dir / f"{key}.webp"
if out.exists():
return out
with Image.open(image_path) as img:
img = img.convert("RGB")
img.thumbnail((max_side, max_side), Image.LANCZOS)
img.save(out, format="WEBP", quality=80)
return out
def file_hash(image_path: Path, chunk: int = 1024 * 1024) -> str:
"""Compute the file's SHA-256, used as the dedup key."""
h = hashlib.sha256()
with open(image_path, "rb") as f:
while True:
data = f.read(chunk)
if not data:
break
h.update(data)
return h.hexdigest()
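The module uses two different hashes for two different jobs: a content hash for deduplication and a path-plus-mtime hash for the thumbnail cache. The sketch below contrasts them under simplified assumptions. `thumb_cache_key` and `sha256_of` are illustrative names, and `Path.as_posix()` stands in for the project's `path_to_storage` helper, whose exact behaviour is not shown in this diff.

```python
import hashlib
import tempfile
from pathlib import Path


def thumb_cache_key(image_path: Path, max_side: int) -> str:
    """Cache key derived from path, modification time, and target size."""
    stat = image_path.stat()
    raw = f"{image_path.as_posix()}|{stat.st_mtime_ns}|{max_side}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


def sha256_of(image_path: Path, chunk: int = 1024 * 1024) -> str:
    """Content hash, read in chunks so large files are not loaded at once."""
    h = hashlib.sha256()
    with open(image_path, "rb") as f:
        while data := f.read(chunk):
            h.update(data)
    return h.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    a = Path(tmp) / "a.png"
    b = Path(tmp) / "b.png"
    a.write_bytes(b"same bytes")
    b.write_bytes(b"same bytes")
    # Identical content hashes the same regardless of location.
    same_content = sha256_of(a) == sha256_of(b)
    # The cache key depends on the path, so the two copies get separate entries.
    different_keys = thumb_cache_key(a, 320) != thumb_cache_key(b, 320)
    # A different size limit also yields a different cache entry.
    size_matters = thumb_cache_key(a, 320) != thumb_cache_key(a, 640)
    print(same_content, different_keys, size_matters)
```

The content hash answers "have we seen these bytes before?", while the cache key answers "is the thumbnail on disk still valid for this file at this size?".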
+122
@@ -0,0 +1,122 @@
"""Watch the configured folders with watchdog.
NTFS events are occasionally unreliable for Chinese paths and OneDrive-synced drives, so PollingObserver is used by default.
"""
from __future__ import annotations
import asyncio
import threading
from pathlib import Path
from sqlalchemy import select
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers.polling import PollingObserver
from app.core.db import session_scope
from app.core.logger import get_logger
from app.core.path_utils import is_accessible_dir, path_from_storage
from app.models.watch_folder import WatchFolder
from app.services.ingest import ingest_path
from app.services.thumbnail import is_supported
logger = get_logger(__name__)
class _ScreenshotEventHandler(FileSystemEventHandler):
"""New file -> ingest -> trigger analysis."""
def __init__(self, loop: asyncio.AbstractEventLoop, notify) -> None: # noqa: ANN001
self._loop = loop
self._notify = notify
def on_created(self, event: FileSystemEvent) -> None:
if event.is_directory:
return
self._handle(Path(event.src_path))
def on_moved(self, event: FileSystemEvent) -> None:
if event.is_directory:
return
self._handle(Path(getattr(event, "dest_path", event.src_path)))
def _handle(self, path: Path) -> None:
if not is_supported(path):
return
# Wait for the write to finish (screenshot tools often create an empty file first, then fill it)
try:
self._wait_file_ready(path)
except FileNotFoundError:
return
with session_scope() as session:
shot = ingest_path(session, path)
if shot is not None:
asyncio.run_coroutine_threadsafe(self._notify(), self._loop)
@staticmethod
def _wait_file_ready(path: Path, retries: int = 10, interval: float = 0.3) -> None:
"""Poll until the file size is stable."""
import time
last = -1
for _ in range(retries):
if not path.exists():
raise FileNotFoundError(path)
size = path.stat().st_size
if size > 0 and size == last:
return
last = size
time.sleep(interval)
class WatcherService:
"""Manage the lifecycle of multiple watched folders."""
def __init__(self) -> None:
self._observer: PollingObserver | None = None
self._lock = threading.Lock()
self._loop: asyncio.AbstractEventLoop | None = None
self._notify_cb = None
def start(self, loop: asyncio.AbstractEventLoop, notify) -> None: # noqa: ANN001
"""Start watching based on the folder list stored in the database."""
with self._lock:
self._loop = loop
self._notify_cb = notify
self._stop_locked()
self._observer = PollingObserver(timeout=2.0)
handler = _ScreenshotEventHandler(loop, notify)
with session_scope() as session:
folders = session.scalars(
select(WatchFolder).where(WatchFolder.enabled.is_(True))
).all()
paths = [(f.path, f.recursive) for f in folders]
for path, recursive in paths:
p = path_from_storage(path)
if not is_accessible_dir(p):
logger.warning("Watch folder is missing or inaccessible, skipping: %s", path)
continue
logger.info("Watching %s (recursive=%s)", path, recursive)
self._observer.schedule(handler, str(p), recursive=recursive)
self._observer.start()
def reload(self) -> None:
"""Restart after the watch-folder list changes."""
if self._loop is None or self._notify_cb is None:
return
self.start(self._loop, self._notify_cb)
def stop(self) -> None:
with self._lock:
self._stop_locked()
def _stop_locked(self) -> None:
if self._observer is not None:
try:
self._observer.stop()
self._observer.join(timeout=3)
finally:
self._observer = None
watcher_service = WatcherService()
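The "wait until the file size is stable" step in the event handler can be exercised in isolation. The sketch below is a variation rather than a copy: `wait_until_stable` is a hypothetical name, and unlike `_wait_file_ready` it returns a boolean, so the caller can tell a stabilized file from one that ran out of retries.

```python
import tempfile
import time
from pathlib import Path


def wait_until_stable(path: Path, retries: int = 10, interval: float = 0.3) -> bool:
    """Return True once two consecutive polls see the same non-zero size."""
    last = -1
    for _ in range(retries):
        if not path.exists():
            raise FileNotFoundError(path)
        size = path.stat().st_size
        if size > 0 and size == last:
            return True
        last = size
        time.sleep(interval)
    return False  # never stabilized within the retry budget


with tempfile.TemporaryDirectory() as tmp:
    shot = Path(tmp) / "shot.png"
    shot.write_bytes(b"\x89PNG fake payload")
    ready = wait_until_stable(shot, retries=5, interval=0.01)

    empty = Path(tmp) / "empty.png"
    empty.touch()  # a zero-byte file never counts as ready
    never_ready = wait_until_stable(empty, retries=3, interval=0.01)
    print(ready, never_ready)
```

With the default arguments the wait is bounded at roughly `retries * interval` seconds (about 3 s), which keeps the watcher thread from blocking indefinitely on a file that is still being written.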
+237
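@@ -0,0 +1,237 @@
"""Async job scheduler: pulls jobs from the jobs table and runs them concurrently.
Transaction rules:
- The scheduling loop only uses short transactions to claim jobs and aggregate status.
- The actual OCR/VLM calls manage their own short transactions inside `analyze_screenshot_by_id`;
the worker layer never wraps them in a long transaction.
"""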
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import case, func, or_, select
from app.core.config import settings
from app.core.db import session_scope
from app.core.logger import get_logger
from app.models.job import Job, JobKind, JobStatus
from app.models.screenshot import ProcessStatus, Screenshot
from app.services.analyze import analyze_ocr_only_by_id, analyze_screenshot_by_id
logger = get_logger(__name__)
class AnalyzeWorker:
"""Single-instance background worker that drains pending rows from the jobs table."""
def __init__(self) -> None:
self._task: Optional[asyncio.Task] = None
self._event = asyncio.Event()
self._stop = False
self._semaphore = asyncio.Semaphore(settings.analyze_concurrency)
self._inflight: int = 0
self._lock = asyncio.Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
async def start(self) -> None:
"""Start the main loop."""
# On startup, reset RUNNING jobs left over from an interrupted run
with session_scope() as session:
running = session.scalars(
select(Job).where(Job.status == JobStatus.RUNNING.value)
).all()
for job in running:
job.status = JobStatus.PENDING.value
self._stop = False
self._loop = asyncio.get_running_loop()
self._task = asyncio.create_task(self._run(), name="analyze-worker")
self.notify()
async def stop(self) -> None:
self._stop = True
self._event.set()
if self._task is not None:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
def notify(self) -> None:
"""Notify the worker, from the event-loop thread, that new jobs are available."""
self._event.set()
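def notify_threadsafe(self) -> None:
"""Wake the worker from another thread (FastAPI BackgroundTasks / watcher thread).
asyncio.Event.set() is not thread-safe by itself, so the call must be scheduled
back onto the event-loop thread via loop.call_soon_threadsafe.
"""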
loop = self._loop
if loop is None or loop.is_closed():
return
loop.call_soon_threadsafe(self._event.set)
async def status(self) -> dict[str, int]:
"""Report queue status for the API (counts grouped on the indexed status column; suitable for large batches)."""
with session_scope() as session:
pending = session.scalar(
select(Job.id)
.where(Job.status == JobStatus.PENDING.value)
.limit(1)
)
rows = session.execute(
select(Job.status, func.count())
.group_by(Job.status)
).all()
counts = {st.value: 0 for st in JobStatus}
for status, cnt in rows:
counts[status] = int(cnt)
counts["inflight"] = self._inflight
counts["has_more"] = 1 if pending is not None else 0
return counts
def reset_stale_running(self, *, minutes: int = 5, reset_all: bool = False) -> int:
"""Reset jobs that have been RUNNING for a long time without progress back to PENDING."""
with session_scope() as session:
q = select(Job).where(Job.status == JobStatus.RUNNING.value)
if not reset_all:
cutoff = datetime.utcnow() - timedelta(minutes=max(minutes, 1))
q = q.where(Job.started_at.is_not(None), Job.started_at < cutoff)
stale = session.scalars(q).all()
for job in stale:
job.status = JobStatus.PENDING.value
job.started_at = None
count = len(stale)
if count:
logger.info("Reset %d RUNNING job(s) to PENDING", count)
self.notify()
return count
def retry_failed(self, job_ids: Optional[list[int]] = None) -> int:
"""Re-queue failed jobs."""
with session_scope() as session:
q = select(Job).where(Job.status == JobStatus.FAILED.value)
if job_ids:
q = q.where(Job.id.in_(job_ids))
failed = session.scalars(q).all()
for job in failed:
job.status = JobStatus.PENDING.value
job.retries = 0
job.last_error = None
job.started_at = None
job.finished_at = None
count = len(failed)
if count:
logger.info("Retrying %d failed job(s)", count)
self.notify()
return count
async def _run(self) -> None:
"""Main loop."""
idle_rounds = 0
while not self._stop:
job = self._claim_one()
if job is None:
idle_rounds += 1
# While idle, periodically clear zombie RUNNING rows so the DB does not show running jobs when inflight == 0
if idle_rounds >= 3 and self._inflight == 0:
idle_rounds = 0
if self.reset_stale_running(minutes=5):
continue
self._event.clear()
try:
await asyncio.wait_for(self._event.wait(), timeout=10)
except asyncio.TimeoutError:
pass
continue
idle_rounds = 0
await self._semaphore.acquire()
async with self._lock:
self._inflight += 1
asyncio.create_task(
self._process(job["id"], job["screenshot_id"], job["kind"])
)
def _claim_one(self) -> Optional[dict]:
"""Short transaction: claim one PENDING job; FULL takes priority over OCR re-runs."""
with session_scope() as session:
job = session.scalar(
select(Job)
.where(
Job.status == JobStatus.PENDING.value,
or_(Job.retries < settings.max_retries, Job.retries.is_(None)),
)
.order_by(
case(
(Job.kind == JobKind.FULL.value, 0),
(Job.kind == JobKind.VLM.value, 1),
else_=2,
),
Job.id.asc(),
)
.limit(1)
)
if job is None:
return None
job.status = JobStatus.RUNNING.value
job.started_at = datetime.utcnow()
session.flush()
return {
"id": job.id,
"screenshot_id": job.screenshot_id,
"kind": job.kind,
}
async def _process(self, job_id: int, screenshot_id: int, kind: str) -> None:
"""Run a single job; every DB write happens in a short transaction."""
try:
try:
if kind == JobKind.OCR.value:
await analyze_ocr_only_by_id(screenshot_id)
else:
await analyze_screenshot_by_id(screenshot_id)
self._finish(job_id, success=True, kind=kind)
except Exception as exc: # noqa: BLE001
logger.exception("Analysis failed #%d (%s): %s", screenshot_id, kind, exc)
self._finish(job_id, success=False, error=str(exc), kind=kind)
finally:
self._semaphore.release()
async with self._lock:
self._inflight -= 1
self.notify()
def _finish(
self,
job_id: int,
success: bool,
error: Optional[str] = None,
kind: str = JobKind.FULL.value,
) -> None:
"""Short transaction: record the job's final state in the jobs table."""
with session_scope() as session:
job = session.get(Job, job_id)
if job is None:
return
if success:
job.status = JobStatus.DONE.value
job.last_error = None
else:
job.retries = (job.retries or 0) + 1
if job.retries >= settings.max_retries:
job.status = JobStatus.FAILED.value
# A failed OCR re-run does not affect ai_status
if kind != JobKind.OCR.value:
shot = session.get(Screenshot, job.screenshot_id)
if shot is not None:
shot.ai_status = ProcessStatus.FAILED.value
else:
job.status = JobStatus.PENDING.value
job.last_error = (error or "")[:1000]
job.finished_at = datetime.utcnow()
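The retry bookkeeping in `_finish` is easier to follow as a pure state transition. The sketch below models it without a database. `JobState`, `finish`, the lowercase status strings, and `MAX_RETRIES = 3` are all assumptions made for illustration; the real limit comes from `settings.max_retries` and the real statuses from `JobStatus`.

```python
from dataclasses import dataclass
from typing import Optional

MAX_RETRIES = 3  # assumed value; the real limit is settings.max_retries


@dataclass
class JobState:
    status: str = "running"
    retries: int = 0
    last_error: Optional[str] = None


def finish(job: JobState, success: bool, error: Optional[str] = None) -> JobState:
    """Apply the done / re-queue / give-up transition to a finished attempt."""
    if success:
        job.status = "done"
        job.last_error = None
        return job
    job.retries += 1
    # Re-queue until the retry budget is spent, then mark the job as failed.
    job.status = "failed" if job.retries >= MAX_RETRIES else "pending"
    job.last_error = (error or "")[:1000]  # cap the stored error text
    return job


job = JobState()
history = []
for _ in range(MAX_RETRIES):
    job.status = "running"
    finish(job, success=False, error="timeout")
    history.append((job.retries, job.status))
print(history)  # [(1, 'pending'), (2, 'pending'), (3, 'failed')]
```

A job therefore gets `MAX_RETRIES` attempts in total, and only the final failure is surfaced as `failed`.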
worker = AnalyzeWorker()
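The cross-thread wake-up described in `notify_threadsafe` can be demonstrated with a few lines of standard-library code. This is a minimal sketch, not the worker itself: a plain thread stands in for the watcher or a background task, and `producer` is a hypothetical name.

```python
import asyncio
import threading


async def main() -> bool:
    loop = asyncio.get_running_loop()
    event = asyncio.Event()

    def producer() -> None:
        # Runs in a plain thread: hand the set() call back to the loop thread
        # instead of calling event.set() directly.
        loop.call_soon_threadsafe(event.set)

    thread = threading.Thread(target=producer)
    thread.start()
    # The coroutine sleeps here until the loop runs the scheduled event.set().
    await asyncio.wait_for(event.wait(), timeout=5)
    thread.join()
    return event.is_set()


woke_up = asyncio.run(main())
print(woke_up)
```

Calling `event.set()` straight from the thread might appear to work, but asyncio primitives are not thread-safe, so routing the call through `call_soon_threadsafe` is the documented way to signal a loop from outside it.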
+12
@@ -0,0 +1,12 @@
fastapi==0.115.0
uvicorn[standard]==0.30.6
sqlalchemy==2.0.34
pydantic==2.9.2
pydantic-settings==2.5.2
python-multipart==0.0.10
watchdog==5.0.2
Pillow==10.4.0
pytesseract==0.3.13
httpx==0.27.2
aiofiles==24.1.0
python-dotenv==1.0.1
+20
@@ -0,0 +1,20 @@
"""Development entry point: run `python run.py` to start the backend."""
from __future__ import annotations
import uvicorn
from app.core.config import settings
def main() -> None:
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
log_level="info",
)
if __name__ == "__main__":
main()
+12
@@ -0,0 +1,12 @@
<!doctype html>
<html lang="zh-CN" class="dark">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>snapAna · Screenshot Analysis</title>
</head>
<body class="bg-slate-950 text-slate-100">
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>
+2813
File diff suppressed because it is too large
+31
@@ -0,0 +1,31 @@
{
"name": "snapana-frontend",
"private": true,
"version": "0.1.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"preview": "vite preview"
},
"dependencies": {
"@tanstack/react-query": "^5.59.0",
"clsx": "^2.1.1",
"lucide-react": "^0.451.0",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-router-dom": "^6.26.2",
"react-window": "^1.8.10"
},
"devDependencies": {
"@types/react": "^18.3.10",
"@types/react-dom": "^18.3.0",
"@types/react-window": "^1.8.8",
"@vitejs/plugin-react": "^4.3.2",
"autoprefixer": "^10.4.20",
"postcss": "^8.4.47",
"tailwindcss": "^3.4.13",
"typescript": "^5.6.2",
"vite": "^5.4.8"
}
}
+6
@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
};
+68
@@ -0,0 +1,68 @@
import { NavLink, Route, Routes, Navigate } from "react-router-dom";
import { Home, Images, Shuffle, ListChecks, Settings as Cog, ListOrdered, Hash } from "lucide-react";
import HomePage from "@/pages/Home";
import LibraryPage from "@/pages/Library";
import ShufflePage from "@/pages/Shuffle";
import TodosPage from "@/pages/Todos";
import SettingsPage from "@/pages/Settings";
import QueuePage from "@/pages/Queue";
import TagsPage from "@/pages/Tags";
const navItems = [
{ to: "/", label: "Home", icon: Home },
{ to: "/library", label: "Library", icon: Images },
{ to: "/tags", label: "Tags", icon: Hash },
{ to: "/queue", label: "Queue", icon: ListOrdered },
{ to: "/shuffle", label: "Shuffle", icon: Shuffle },
{ to: "/todos", label: "Todos", icon: ListChecks },
{ to: "/settings", label: "Settings", icon: Cog },
];
export default function App() {
return (
<div className="flex h-full">
<aside className="flex w-56 shrink-0 flex-col border-r border-slate-800 bg-slate-950/80">
<div className="px-5 py-5">
<div className="text-lg font-semibold tracking-wide text-white">snapAna</div>
<div className="mt-1 text-xs text-slate-500"></div>
</div>
<nav className="flex flex-col gap-1 px-3">
{navItems.map((item) => (
<NavLink
key={item.to}
to={item.to}
end={item.to === "/"}
className={({ isActive }) =>
`flex items-center gap-2 rounded-md px-3 py-2 text-sm transition ${
isActive
? "bg-brand-600/20 text-white"
: "text-slate-400 hover:bg-slate-800/60 hover:text-white"
}`
}
>
<item.icon size={16} />
{item.label}
</NavLink>
))}
</nav>
<div className="mt-auto px-5 py-4 text-xs text-slate-500">
v0.1 ·
</div>
</aside>
<main className="flex-1 overflow-hidden">
<Routes>
<Route path="/" element={<HomePage />} />
<Route path="/library" element={<LibraryPage />} />
<Route path="/tags" element={<TagsPage />} />
<Route path="/queue" element={<QueuePage />} />
<Route path="/shuffle" element={<ShufflePage />} />
<Route path="/todos" element={<TodosPage />} />
<Route path="/settings" element={<SettingsPage />} />
<Route path="*" element={<Navigate to="/" replace />} />
</Routes>
</main>
</div>
);
}
+176
@@ -0,0 +1,176 @@
import type {
Category,
JobListResp,
ListQuery,
ListResp,
ProviderConfig,
ProviderConfigOut,
ProviderTestResult,
ScreenshotBrief,
ScreenshotDetail,
StatsResp,
Tag,
TagListResp,
TodoItem,
TodoListQuery,
TodoListResp,
WatchFolder,
} from "@/types";
const BASE = "";
async function request<T>(
path: string,
init?: RequestInit & { params?: Record<string, unknown> }
): Promise<T> {
const { params, ...rest } = init ?? {};
const url = new URL(BASE + path, window.location.origin);
if (params) {
Object.entries(params).forEach(([key, value]) => {
if (value === undefined || value === null || value === "") return;
url.searchParams.append(key, String(value));
});
}
const resp = await fetch(url.pathname + url.search, {
...rest,
// Spread `rest` first so the merged headers below are not overwritten by rest.headers
headers: {
"Content-Type": "application/json",
...(rest.headers ?? {}),
},
});
if (!resp.ok) {
let detail = resp.statusText;
try {
const data = await resp.json();
detail = data.detail ?? detail;
} catch {
/* ignore */
}
throw new Error(detail);
}
if (resp.status === 204) return undefined as T;
return (await resp.json()) as T;
}
export const api = {
listScreenshots: (query: ListQuery) =>
request<ListResp>("/api/screenshots", { params: query as Record<string, unknown> }),
getScreenshot: (id: number) => request<ScreenshotDetail>(`/api/screenshots/${id}`),
randomScreenshots: (params: { n?: number; category_id?: number }) =>
request<ScreenshotBrief[]>("/api/screenshots/random", {
params: params as Record<string, unknown>,
}),
stats: () => request<StatsResp>("/api/screenshots/stats"),
reanalyze: (id: number) =>
request<{ ok: boolean }>(`/api/screenshots/${id}/reanalyze`, { method: "POST" }),
reocr: (id: number) =>
request<{ ok: boolean; job_id?: number; message?: string }>(
`/api/screenshots/${id}/reocr`,
{ method: "POST" }
),
updateScreenshot: (id: number, payload: Partial<{ category_id: number | null; is_favorite: boolean; is_hidden: boolean; tags: string[] }>) =>
request<ScreenshotDetail>(`/api/screenshots/${id}`, {
method: "PATCH",
body: JSON.stringify(payload),
}),
deleteScreenshot: (id: number) =>
request<{ ok: boolean }>(`/api/screenshots/${id}`, { method: "DELETE" }),
listTodos: (params: TodoListQuery) =>
request<TodoListResp>("/api/todos", { params: params as Record<string, unknown> }),
todoSummary: () => request<Record<string, number>>("/api/todos/summary"),
updateTodo: (id: number, payload: Partial<{ status: string; title: string; note: string }>) =>
request<TodoItem>(`/api/todos/${id}`, {
method: "PATCH",
body: JSON.stringify(payload),
}),
deleteTodo: (id: number) =>
request<{ ok: boolean }>(`/api/todos/${id}`, { method: "DELETE" }),
listWatchFolders: () => request<WatchFolder[]>("/api/watch/folders"),
addWatchFolder: (payload: Omit<WatchFolder, "id">) =>
request<WatchFolder>("/api/watch/folders", {
method: "POST",
body: JSON.stringify(payload),
}),
updateWatchFolder: (id: number, payload: Omit<WatchFolder, "id">) =>
request<WatchFolder>(`/api/watch/folders/${id}`, {
method: "PATCH",
body: JSON.stringify(payload),
}),
deleteWatchFolder: (id: number) =>
request<{ ok: boolean }>(`/api/watch/folders/${id}`, { method: "DELETE" }),
importNow: (payload: Omit<WatchFolder, "id">) =>
request<{ ok: boolean }>("/api/watch/import", {
method: "POST",
body: JSON.stringify(payload),
}),
validateWatchPath: (path: string) =>
request<{ ok: boolean; path: string; sample_image_count: number; samples: string[]; message: string }>(
"/api/watch/validate-path",
{
method: "POST",
body: JSON.stringify({ path, enabled: true, recursive: true, is_sensitive: false }),
}
),
queueStatus: () => request<Record<string, number>>("/api/watch/queue"),
listJobs: (params: { status?: string; kind?: string; page?: number; size?: number }) =>
request<JobListResp>("/api/watch/jobs", {
params: params as Record<string, unknown>,
}),
retryFailedJobs: (jobIds?: number[]) =>
request<{ ok: boolean; count: number }>("/api/watch/jobs/retry-failed", {
method: "POST",
body: JSON.stringify(jobIds?.length ? { job_ids: jobIds } : {}),
}),
resetStaleJobs: (resetAll?: boolean) =>
request<{ ok: boolean; count: number }>("/api/watch/jobs/reset-stale", {
method: "POST",
params: { reset_all: resetAll ? true : undefined },
}),
enqueueOcrFailed: (limit?: number) =>
request<{ ok: boolean; count: number }>("/api/watch/jobs/enqueue-ocr-failed", {
method: "POST",
params: { limit: limit ?? 500 },
}),
listCategories: () => request<Category[]>("/api/settings/categories"),
createCategory: (payload: Omit<Category, "id">) =>
request<{ id: number }>("/api/settings/categories", {
method: "POST",
body: JSON.stringify(payload),
}),
updateCategory: (id: number, payload: Omit<Category, "id">) =>
request<{ ok: boolean }>(`/api/settings/categories/${id}`, {
method: "PATCH",
body: JSON.stringify(payload),
}),
deleteCategory: (id: number) =>
request<{ ok: boolean }>(`/api/settings/categories/${id}`, { method: "DELETE" }),
listTags: (params?: { q?: string; page?: number; size?: number; sort?: string }) =>
request<TagListResp>("/api/settings/tags", {
params: (params ?? {}) as Record<string, unknown>,
}),
getProvider: (key: "ocr_provider" | "vlm_provider") =>
request<ProviderConfigOut | null>(`/api/settings/providers/${key}`),
setProvider: (key: "ocr_provider" | "vlm_provider", payload: ProviderConfig) =>
request<{ ok: boolean }>(`/api/settings/providers/${key}`, {
method: "PUT",
body: JSON.stringify(payload),
}),
testProvider: (key: "ocr_provider" | "vlm_provider", payload: ProviderConfig) =>
request<ProviderTestResult>(`/api/settings/providers/${key}/test`, {
method: "POST",
body: JSON.stringify(payload),
}),
getRecognitionMode: () =>
request<{ mode: string; options: string[] }>("/api/settings/recognition-mode"),
setRecognitionMode: (mode: string) =>
request<{ ok: boolean; mode: string }>("/api/settings/recognition-mode", {
method: "PUT",
body: JSON.stringify({ mode }),
}),
};
+70
@@ -0,0 +1,70 @@
import { Star } from "lucide-react";
import type { ScreenshotBrief } from "@/types";
import { StatusBadge } from "./StatusBadge";
interface Props {
shot: ScreenshotBrief;
onOpen: (id: number) => void;
onToggleFav?: (shot: ScreenshotBrief) => void;
}
export function Card({ shot, onOpen, onToggleFav }: Props) {
const date = new Date(shot.captured_at).toLocaleString("zh-CN", {
hour12: false,
});
return (
<button
onClick={() => onOpen(shot.id)}
className="card-hover group relative flex w-full flex-col overflow-hidden rounded-xl border border-slate-800 bg-slate-900/40 text-left"
>
<div className="relative aspect-[4/3] w-full overflow-hidden bg-slate-800">
{shot.thumb_url ? (
<img
src={shot.thumb_url}
alt=""
loading="lazy"
className="h-full w-full object-cover transition group-hover:scale-[1.02]"
/>
) : (
<div className="flex h-full w-full items-center justify-center text-xs text-slate-500">
</div>
)}
{shot.category && (
<span
className="absolute left-2 top-2 inline-flex items-center rounded-full bg-black/55 px-2 py-0.5 text-[10px] text-white backdrop-blur"
style={{
borderLeft: `3px solid ${shot.category.color ?? "#6366f1"}`,
}}
>
{shot.category.name}
</span>
)}
<button
className={`absolute right-2 top-2 rounded-full p-1.5 transition ${
shot.is_favorite
? "bg-amber-400/90 text-slate-900"
: "bg-black/40 text-slate-300 opacity-0 group-hover:opacity-100"
}`}
onClick={(e) => {
e.stopPropagation();
onToggleFav?.(shot);
}}
aria-label="Favorite"
>
<Star size={14} fill={shot.is_favorite ? "currentColor" : "none"} />
</button>
</div>
<div className="flex flex-1 flex-col gap-1 px-3 py-2">
<div className="line-clamp-2 text-sm font-medium text-slate-100">
{shot.ai_title || "(no title generated)"}
</div>
<div className="mt-auto flex items-center justify-between text-[11px] text-slate-500">
<span>{date}</span>
<StatusBadge status={shot.ai_status} />
</div>
</div>
</button>
);
}
+81
@@ -0,0 +1,81 @@
import { useEffect, useMemo, useRef, useState } from "react";
import { FixedSizeGrid as Grid } from "react-window";
import type { ScreenshotBrief } from "@/types";
import { Card } from "./Card";
interface Props {
items: ScreenshotBrief[];
onOpen: (id: number) => void;
onToggleFav?: (shot: ScreenshotBrief) => void;
}
const CARD_W = 260;
const CARD_H = 280;
const GAP = 16;
export function CardGrid({ items, onOpen, onToggleFav }: Props) {
const containerRef = useRef<HTMLDivElement>(null);
const [size, setSize] = useState({ w: 0, h: 0 });
useEffect(() => {
const el = containerRef.current;
if (!el) return;
const observer = new ResizeObserver((entries) => {
for (const entry of entries) {
setSize({ w: entry.contentRect.width, h: entry.contentRect.height });
}
});
observer.observe(el);
return () => observer.disconnect();
}, []);
const colCount = useMemo(() => {
const c = Math.max(1, Math.floor((size.w + GAP) / (CARD_W + GAP)));
return c;
}, [size.w]);
const rowCount = Math.ceil(items.length / colCount);
const colWidth = (size.w - GAP) / Math.max(colCount, 1);
return (
<div ref={containerRef} className="relative h-full w-full">
{items.length === 0 ? (
<div className="flex h-full items-center justify-center text-sm text-slate-500">
</div>
) : (
<Grid
columnCount={colCount}
rowCount={rowCount}
columnWidth={colWidth}
rowHeight={CARD_H + GAP}
height={size.h}
width={size.w}
itemKey={({ rowIndex, columnIndex }) => {
const idx = rowIndex * colCount + columnIndex;
const item = items[idx];
return item ? `s-${item.id}` : `e-${rowIndex}-${columnIndex}`;
}}
>
{({ rowIndex, columnIndex, style }) => {
const idx = rowIndex * colCount + columnIndex;
const item = items[idx];
if (!item) return null;
return (
<div
style={{
...style,
paddingRight: GAP,
paddingBottom: GAP,
}}
>
<Card shot={item} onOpen={onOpen} onToggleFav={onToggleFav} />
</div>
);
}}
</Grid>
)}
</div>
);
}
+311
@@ -0,0 +1,311 @@
import { useEffect, useState } from "react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { Copy, ExternalLink, RefreshCw, Star, Trash2, X } from "lucide-react";
import { api } from "@/api/client";
import type { ScreenshotDetail } from "@/types";
import { StatusBadge } from "./StatusBadge";
interface Props {
id: number | null;
onClose: () => void;
}
export function DetailPanel({ id, onClose }: Props) {
const open = id !== null;
const qc = useQueryClient();
const detail = useQuery({
queryKey: ["screenshot", id],
queryFn: () => api.getScreenshot(id!),
enabled: open,
});
const update = useMutation({
mutationFn: (payload: Parameters<typeof api.updateScreenshot>[1]) =>
api.updateScreenshot(id!, payload),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["screenshot", id] });
qc.invalidateQueries({ queryKey: ["screenshots"] });
},
});
const reanalyze = useMutation({
mutationFn: () => api.reanalyze(id!),
onSuccess: () => qc.invalidateQueries({ queryKey: ["screenshot", id] }),
});
const reocr = useMutation({
mutationFn: () => api.reocr(id!),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["screenshot", id] });
qc.invalidateQueries({ queryKey: ["queue"] });
},
});
const remove = useMutation({
mutationFn: () => api.deleteScreenshot(id!),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["screenshots"] });
onClose();
},
});
const cats = useQuery({ queryKey: ["categories"], queryFn: api.listCategories });
const [tagInput, setTagInput] = useState("");
useEffect(() => {
setTagInput("");
}, [id]);
if (!open) return null;
const s = detail.data as ScreenshotDetail | undefined;
return (
<div className="fixed inset-0 z-50 flex justify-end">
<div className="absolute inset-0 bg-black/40" onClick={onClose} />
<div className="glass relative flex h-full w-full max-w-4xl flex-col overflow-hidden border-l border-slate-800">
<header className="flex items-center justify-between border-b border-slate-800 px-4 py-3">
<div className="flex items-center gap-2 text-sm text-slate-300">
<span className="rounded bg-slate-800 px-2 py-0.5 text-xs">#{id}</span>
{s && <StatusBadge status={s.ai_status} />}
{s?.category && (
<span
className="rounded-full px-2 py-0.5 text-xs"
style={{
backgroundColor: `${s.category.color ?? "#6366f1"}22`,
color: s.category.color ?? "#a5b4fc",
}}
>
{s.category.name}
</span>
)}
</div>
<button className="btn-ghost rounded p-1.5" onClick={onClose} aria-label="Close">
<X size={18} />
</button>
</header>
{s ? (
<div className="flex flex-1 overflow-hidden">
<div className="flex flex-1 items-center justify-center bg-black/40 p-4">
<img
src={s.file_url}
alt={s.ai_title ?? "screenshot"}
className="max-h-full max-w-full rounded-md object-contain shadow-2xl"
/>
</div>
<div className="w-[420px] shrink-0 overflow-y-auto border-l border-slate-800 px-4 py-4">
<div className="flex flex-wrap items-center gap-2">
<button
className={`btn ${s.is_favorite ? "btn-primary" : ""}`}
onClick={() => update.mutate({ is_favorite: !s.is_favorite })}
>
<Star size={14} /> {s.is_favorite ? "Favorited" : "Favorite"}
</button>
<button
className="btn"
onClick={() => reanalyze.mutate()}
disabled={reanalyze.isPending}
>
<RefreshCw size={14} />
</button>
{s.ocr_status === "failed" && s.ai_status === "done" && (
<button
className="btn border-brand-600/50 text-brand-300"
onClick={() => reocr.mutate()}
disabled={reocr.isPending}
>
<RefreshCw size={14} /> OCR
</button>
)}
<a
className="btn"
href={s.file_url}
target="_blank"
rel="noreferrer"
>
<ExternalLink size={14} />
</a>
<button
className="btn text-rose-300 hover:border-rose-500"
onClick={() => {
if (confirm("从库中移除(不删除原文件)?")) remove.mutate();
}}
>
<Trash2 size={14} />
</button>
</div>
<Section title="标题">
<p className="text-base font-medium text-white">
{s.ai_title || "(未生成标题)"}
</p>
</Section>
<Section title="AI 摘要">
<p className="whitespace-pre-line text-sm text-slate-300">
{s.ai_summary || "—"}
</p>
</Section>
{s.ai_suggestion && (
<Section title="AI 建议">
<p className="whitespace-pre-line text-sm text-amber-200/90">
{s.ai_suggestion}
</p>
</Section>
)}
<Section
title="OCR 文本"
right={
<button
className="btn-ghost rounded p-1"
onClick={() => navigator.clipboard.writeText(s.ocr_text ?? "")}
aria-label="复制 OCR 文本"
>
<Copy size={14} />
</button>
}
>
<pre className="max-h-64 overflow-auto whitespace-pre-wrap rounded-md border border-slate-800 bg-slate-900/60 p-2 text-xs text-slate-300">
{s.ocr_text || "(无)"}
</pre>
</Section>
<Section title="分类">
<select
className="input"
value={s.category?.id ?? ""}
onChange={(e) =>
update.mutate({
category_id: e.target.value ? Number(e.target.value) : null,
})
}
>
<option value=""></option>
{cats.data?.map((c) => (
<option key={c.id} value={c.id}>
{c.name}
</option>
))}
</select>
</Section>
<Section title="标签">
<div className="flex flex-wrap gap-1.5">
{s.tags.map((t) => (
<span key={t.id} className="chip">
#{t.name}
<button
className="text-slate-500 hover:text-rose-300"
onClick={() =>
update.mutate({
tags: s.tags.filter((x) => x.id !== t.id).map((x) => x.name),
})
}
>
<X size={12} />
</button>
</span>
))}
</div>
<div className="mt-2 flex gap-2">
<input
className="input"
placeholder="新增标签后回车"
value={tagInput}
onChange={(e) => setTagInput(e.target.value)}
onKeyDown={(e) => {
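// Ignore the Enter that confirms an IME composition (e.g. pinyin input).
if (e.key === "Enter" && !e.nativeEvent.isComposing && tagInput.trim()) {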
update.mutate({
tags: [
...s.tags.map((t) => t.name),
tagInput.trim(),
],
});
setTagInput("");
}
}}
/>
</div>
</Section>
{s.todos.length > 0 && (
<Section title="待办">
<ul className="flex flex-col gap-2">
{s.todos.map((t) => (
<li
key={t.id}
className="rounded-md border border-slate-800 bg-slate-900/60 p-2 text-sm"
>
<div className="flex items-center justify-between">
<span className="text-slate-100">{t.title}</span>
<span className="text-[10px] text-slate-500">
{t.kind} · {t.status}
</span>
</div>
{t.note && (
<div className="mt-1 text-xs text-slate-400">{t.note}</div>
)}
</li>
))}
</ul>
</Section>
)}
<Section title="元信息">
<div className="grid grid-cols-2 gap-y-1 text-xs text-slate-400">
<div></div>
<div className="text-slate-200">
{s.width}×{s.height}
</div>
<div></div>
<div className="text-slate-200">
{(s.size / 1024).toFixed(1)} KB
</div>
<div></div>
<div className="break-all text-slate-200">{s.path}</div>
<div></div>
<div className="text-slate-200">
{new Date(s.captured_at).toLocaleString("zh-CN", {
hour12: false,
})}
</div>
</div>
</Section>
</div>
</div>
) : (
<div className="flex flex-1 items-center justify-center text-slate-500">
</div>
)}
</div>
</div>
);
}
function Section({
title,
right,
children,
}: {
title: string;
right?: React.ReactNode;
children: React.ReactNode;
}) {
return (
<div className="mt-5">
<div className="mb-1.5 flex items-center justify-between">
<span className="text-xs font-medium uppercase tracking-wide text-slate-500">
{title}
</span>
{right}
</div>
{children}
</div>
);
}
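The tag editor above sends the complete list of tag names on every add or remove. A minimal sketch of that list arithmetic as pure functions follows; `addTag` and `removeTag` are hypothetical helper names that do not appear in the commit, and the duplicate check is an added assumption (the original appends the trimmed input unconditionally).

```typescript
// Hypothetical helpers, not part of the commit: they mirror the payloads
// built inline in DetailPanel's tag section.
interface Tag {
  id: number;
  name: string;
}

// Returns the tag-name list to send after the user submits `input`.
// Assumption: blank input and duplicates leave the list unchanged.
function addTag(tags: Tag[], input: string): string[] {
  const name = input.trim();
  const names = tags.map((t) => t.name);
  if (!name || names.includes(name)) return names;
  return [...names, name];
}

// Returns the tag-name list to send after removing the tag with `id`.
function removeTag(tags: Tag[], id: number): string[] {
  return tags.filter((t) => t.id !== id).map((t) => t.name);
}
```

Keeping these as pure functions would make the payload easy to unit-test without rendering the panel.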
+176
@@ -0,0 +1,176 @@
import { useQuery } from "@tanstack/react-query";
import { Link } from "react-router-dom";
import { Search, Star, X } from "lucide-react";
import { api } from "@/api/client";
import type { ListQuery } from "@/types";
interface Props {
query: ListQuery;
onChange: (next: ListQuery) => void;
}
export function FilterBar({ query, onChange }: Props) {
const cats = useQuery({ queryKey: ["categories"], queryFn: api.listCategories });
const tags = useQuery({
queryKey: ["tags", "top"],
queryFn: () => api.listTags({ size: 30, sort: "count_desc" }),
});
const patch = (next: Partial<ListQuery>) => onChange({ ...query, ...next, page: 1 });
return (
<aside className="flex w-64 shrink-0 flex-col gap-5 overflow-y-auto border-r border-slate-800 bg-slate-950/60 px-4 py-4">
<div>
<label className="mb-1 block text-xs font-medium text-slate-400"></label>
<div className="relative">
<Search
size={14}
className="pointer-events-none absolute left-2.5 top-1/2 -translate-y-1/2 text-slate-500"
/>
<input
className="input pl-7"
placeholder="OCR / 标题 / 标签"
value={query.q ?? ""}
onChange={(e) => patch({ q: e.target.value })}
/>
{query.q && (
<button
className="absolute right-2 top-1/2 -translate-y-1/2 text-slate-500 hover:text-white"
onClick={() => patch({ q: "" })}
aria-label="清除"
>
<X size={14} />
</button>
)}
</div>
</div>
<div>
<div className="mb-2 flex items-center justify-between">
<label className="text-xs font-medium text-slate-400"></label>
{query.category_id && (
<button
className="text-[11px] text-slate-500 hover:text-white"
onClick={() => patch({ category_id: undefined })}
>
</button>
)}
</div>
<div className="flex flex-wrap gap-1.5">
{cats.data?.map((c) => (
<button
key={c.id}
className={`chip ${query.category_id === c.id ? "chip-active" : ""}`}
onClick={() =>
patch({ category_id: query.category_id === c.id ? undefined : c.id })
}
>
<span
className="h-2 w-2 rounded-full"
style={{ background: c.color ?? "#64748b" }}
/>
{c.name}
</button>
))}
</div>
</div>
<div>
<label className="mb-1 block text-xs font-medium text-slate-400"></label>
<div className="flex flex-col gap-2">
<input
type="date"
className="input"
value={query.date_from ? query.date_from.slice(0, 10) : ""}
onChange={(e) =>
patch({ date_from: e.target.value ? `${e.target.value}T00:00:00` : undefined })
}
/>
<input
type="date"
className="input"
value={query.date_to ? query.date_to.slice(0, 10) : ""}
onChange={(e) =>
patch({ date_to: e.target.value ? `${e.target.value}T23:59:59` : undefined })
}
/>
</div>
</div>
<div>
<label className="mb-1 block text-xs font-medium text-slate-400"></label>
<select
className="input"
value={query.sort ?? "captured_desc"}
onChange={(e) => patch({ sort: e.target.value })}
>
<option value="captured_desc"></option>
<option value="captured_asc"></option>
<option value="imported_desc"></option>
<option value="imported_asc"></option>
<option value="title_asc"> AZ</option>
<option value="title_desc"> ZA</option>
<option value="size_desc"></option>
<option value="size_asc"></option>
</select>
</div>
<div>
<label className="mb-2 block text-xs font-medium text-slate-400"></label>
<div className="flex flex-wrap gap-1.5">
{["", "done", "pending", "running", "failed"].map((st) => (
<button
key={st || "all"}
className={`chip ${
(query.status ?? "") === st ? "chip-active" : ""
}`}
onClick={() => patch({ status: st || undefined })}
>
{st === ""
? "全部"
: st === "done"
? "已分析"
: st === "pending"
? "排队"
: st === "running"
? "分析中"
: "失败"}
</button>
))}
</div>
</div>
<button
className={`btn ${query.favorite ? "btn-primary" : ""}`}
onClick={() => patch({ favorite: query.favorite ? undefined : true })}
>
<Star size={14} />
</button>
{tags.data && tags.data.items.length > 0 && (
<div>
<div className="mb-2 flex items-center justify-between">
<label className="text-xs font-medium text-slate-400"></label>
<Link to="/tags" className="text-[11px] text-brand-400 hover:underline">
</Link>
</div>
<div className="flex flex-wrap gap-1.5">
{tags.data.items.map((t) => (
<button
key={t.id}
className={`chip ${query.tag === t.name ? "chip-active" : ""}`}
onClick={() => patch({ tag: query.tag === t.name ? undefined : t.name })}
>
#{t.name}
<span className="opacity-50">{t.count}</span>
</button>
))}
</div>
</div>
)}
</aside>
);
}
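The two date inputs in the filter bar expand a `YYYY-MM-DD` value into an inclusive day range and slice it back for display. A minimal sketch of that round trip, using hypothetical helper names (`dayStart`, `dayEnd`, `toInputValue`) that are not in the commit:

```typescript
// Hypothetical helpers, not part of the commit: they restate the inline
// expressions used by FilterBar's date inputs.

// "2026-05-22" -> "2026-05-22T00:00:00"; an empty input clears the filter.
function dayStart(date: string): string | undefined {
  return date ? `${date}T00:00:00` : undefined;
}

// "2026-05-22" -> "2026-05-22T23:59:59"; an empty input clears the filter.
function dayEnd(date: string): string | undefined {
  return date ? `${date}T23:59:59` : undefined;
}

// Converts a stored bound back to the value an <input type="date"> expects.
function toInputValue(iso?: string): string {
  return iso ? iso.slice(0, 10) : "";
}
```

The timestamps carry no timezone suffix, so how they are interpreted depends on the backend; the sketch only reproduces the string handling.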
+31
@@ -0,0 +1,31 @@
interface Props {
status: string;
}
const labelMap: Record<string, string> = {
pending: "等待中",
running: "分析中",
done: "完成",
failed: "失败",
skipped: "跳过",
};
const styleMap: Record<string, string> = {
pending: "bg-slate-700/60 text-slate-300",
running: "bg-amber-500/20 text-amber-300",
done: "bg-emerald-500/20 text-emerald-300",
failed: "bg-rose-500/20 text-rose-300",
skipped: "bg-slate-700/40 text-slate-400",
};
export function StatusBadge({ status }: Props) {
return (
<span
className={`inline-flex items-center rounded-full px-2 py-0.5 text-[10px] ${
styleMap[status] ?? styleMap.pending
}`}
>
{labelMap[status] ?? status}
</span>
);
}
+67
@@ -0,0 +1,67 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
color-scheme: dark;
}
html,
body,
#root {
height: 100%;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "PingFang SC",
"Microsoft YaHei", "Helvetica Neue", Helvetica, Arial, sans-serif;
}
::-webkit-scrollbar {
width: 10px;
height: 10px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background: rgba(148, 163, 184, 0.25);
border-radius: 999px;
}
::-webkit-scrollbar-thumb:hover {
background: rgba(148, 163, 184, 0.45);
}
.glass {
background: rgba(15, 23, 42, 0.7);
backdrop-filter: blur(12px);
-webkit-backdrop-filter: blur(12px);
}
.card-hover {
transition: transform 0.18s ease, box-shadow 0.18s ease, border-color 0.18s ease;
}
.card-hover:hover {
transform: translateY(-2px);
border-color: rgba(99, 102, 241, 0.6);
box-shadow: 0 12px 32px -16px rgba(99, 102, 241, 0.5);
}
.btn {
@apply inline-flex items-center gap-1.5 rounded-md border border-slate-700 bg-slate-800/60 px-3 py-1.5 text-sm text-slate-200 transition hover:border-brand-500 hover:bg-slate-800 hover:text-white;
}
.btn-primary {
@apply border-brand-500 bg-brand-600 text-white hover:bg-brand-500;
}
.btn-ghost {
@apply border-transparent bg-transparent text-slate-400 hover:text-white;
}
.input {
@apply w-full rounded-md border border-slate-700 bg-slate-900/70 px-3 py-2 text-sm text-slate-100 placeholder:text-slate-500 focus:border-brand-500 focus:outline-none;
}
.chip {
@apply inline-flex items-center gap-1 rounded-full border border-slate-700 bg-slate-800/60 px-2.5 py-0.5 text-xs text-slate-300;
}
.chip-active {
@apply border-brand-500 bg-brand-600/20 text-white;
}
+26
@@ -0,0 +1,26 @@
import React from "react";
import ReactDOM from "react-dom/client";
import { BrowserRouter } from "react-router-dom";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import App from "./App";
import "./index.css";
const queryClient = new QueryClient({
defaultOptions: {
queries: {
refetchOnWindowFocus: false,
staleTime: 15_000,
},
},
});
ReactDOM.createRoot(document.getElementById("root")!).render(
<React.StrictMode>
<QueryClientProvider client={queryClient}>
<BrowserRouter>
<App />
</BrowserRouter>
</QueryClientProvider>
</React.StrictMode>
);
+166
@@ -0,0 +1,166 @@
import { useQuery } from "@tanstack/react-query";
import { Link } from "react-router-dom";
import { useState } from "react";
import { Sparkles, Images, ListChecks, Loader2 } from "lucide-react";
import { api } from "@/api/client";
import { DetailPanel } from "@/components/DetailPanel";
import { Card } from "@/components/Card";
export default function HomePage() {
const stats = useQuery({ queryKey: ["stats"], queryFn: api.stats });
const daily = useQuery({
queryKey: ["random", 6],
queryFn: () => api.randomScreenshots({ n: 6 }),
staleTime: 60_000,
});
const todoSummary = useQuery({ queryKey: ["todo-summary"], queryFn: api.todoSummary });
const queue = useQuery({
queryKey: ["queue"],
queryFn: api.queueStatus,
refetchInterval: 5000,
});
const [openId, setOpenId] = useState<number | null>(null);
return (
<div className="h-full overflow-y-auto px-8 py-6">
<header className="mb-6 flex items-end justify-between">
<div>
<h1 className="flex items-center gap-2 text-2xl font-semibold text-white">
<Sparkles size={22} className="text-brand-400" />
</h1>
<p className="mt-1 text-sm text-slate-400">
AI
</p>
</div>
<div className="flex gap-2 text-sm">
<Link to="/library" className="btn btn-primary">
<Images size={14} />
</Link>
<Link to="/todos" className="btn">
<ListChecks size={14} />
</Link>
</div>
</header>
<section className="mb-6 grid grid-cols-4 gap-4">
<StatCard label="截图总数" value={stats.data?.total ?? 0} />
<StatCard
label="已分析"
value={stats.data?.by_status?.done ?? 0}
hint={`失败 ${stats.data?.by_status?.failed ?? 0} · 排队 ${stats.data?.by_status?.pending ?? 0}`}
/>
<StatCard
label="待办"
value={todoSummary.data?.pending ?? 0}
hint={`已完成 ${todoSummary.data?.done ?? 0}`}
/>
<StatCard
label="队列"
value={queue.data?.pending ?? 0}
hint={`运行中 ${queue.data?.running ?? 0} · 失败 ${queue.data?.failed ?? 0}`}
icon={(queue.data?.running ?? 0) > 0 ? <Loader2 className="animate-spin" size={14} /> : undefined}
to="/queue"
/>
</section>
<section className="mb-8">
<div className="mb-3 flex items-end justify-between">
<h2 className="text-lg font-semibold text-white"></h2>
<button
className="btn"
onClick={() => daily.refetch()}
disabled={daily.isFetching}
>
</button>
</div>
<div className="grid grid-cols-2 gap-4 md:grid-cols-3 xl:grid-cols-6">
{(daily.data ?? []).map((s) => (
<Card key={s.id} shot={s} onOpen={setOpenId} />
))}
{(daily.data?.length ?? 0) === 0 && !daily.isLoading && (
<div className="col-span-full rounded-md border border-dashed border-slate-700 p-8 text-center text-sm text-slate-500">
<Link to="/settings" className="mx-1 text-brand-400 underline">
</Link>
</div>
)}
</div>
</section>
{stats.data && stats.data.by_category.length > 0 && (
<section className="mb-8">
<h2 className="mb-3 text-lg font-semibold text-white"></h2>
<div className="grid grid-cols-2 gap-3 md:grid-cols-4">
{stats.data.by_category
.filter((c) => c.id)
.map((c) => (
<Link
key={c.id}
to="/library"
className="flex items-center justify-between rounded-lg border border-slate-800 bg-slate-900/50 px-4 py-3 transition hover:border-brand-500"
>
<span className="flex items-center gap-2 text-sm text-slate-200">
<span
className="h-3 w-3 rounded-full"
style={{ background: c.color ?? "#6366f1" }}
/>
{c.name}
</span>
<span className="text-sm font-medium text-slate-100">{c.count}</span>
</Link>
))}
</div>
</section>
)}
<DetailPanel id={openId} onClose={() => setOpenId(null)} />
</div>
);
}
function StatCard({
label,
value,
hint,
icon,
to,
}: {
label: string;
value: number;
hint?: string;
icon?: React.ReactNode;
to?: string;
}) {
const inner = (
<>
<div className="text-xs text-slate-500">{label}</div>
<div className="mt-1 flex items-center gap-2 text-2xl font-semibold text-white">
{value.toLocaleString()}
{icon}
</div>
{hint && <div className="mt-1 text-xs text-slate-500">{hint}</div>}
</>
);
if (to) {
return (
<Link
to={to}
className="block rounded-xl border border-slate-800 bg-slate-900/50 px-5 py-4 transition hover:border-brand-500"
>
{inner}
</Link>
);
}
return (
<div className="rounded-xl border border-slate-800 bg-slate-900/50 px-5 py-4">
{inner}
</div>
);
}
+88
@@ -0,0 +1,88 @@
import { useMemo, useState } from "react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useSearchParams } from "react-router-dom";
import { api } from "@/api/client";
import { CardGrid } from "@/components/CardGrid";
import { DetailPanel } from "@/components/DetailPanel";
import { FilterBar } from "@/components/FilterBar";
import type { ListQuery, ScreenshotBrief } from "@/types";
const PAGE_SIZE = 80;
function initialQuery(params: URLSearchParams): ListQuery {
return {
page: 1,
size: PAGE_SIZE,
tag: params.get("tag") ?? undefined,
q: params.get("q") ?? undefined,
};
}
export default function LibraryPage() {
const [searchParams] = useSearchParams();
const [query, setQuery] = useState<ListQuery>(() => initialQuery(searchParams));
const [openId, setOpenId] = useState<number | null>(null);
const qc = useQueryClient();
const listKey = useMemo(() => ["screenshots", query], [query]);
const list = useQuery({
queryKey: listKey,
queryFn: () => api.listScreenshots({ ...query, size: PAGE_SIZE }),
placeholderData: (prev) => prev,
});
const totalPages = Math.max(1, Math.ceil((list.data?.total ?? 0) / PAGE_SIZE));
const toggleFav = useMutation({
mutationFn: (shot: ScreenshotBrief) =>
api.updateScreenshot(shot.id, { is_favorite: !shot.is_favorite }),
onSuccess: () => qc.invalidateQueries({ queryKey: ["screenshots"] }),
});
return (
<div className="flex h-full">
<FilterBar query={query} onChange={setQuery} />
<div className="flex flex-1 flex-col overflow-hidden">
<header className="flex items-center justify-between border-b border-slate-800 px-5 py-3">
<div>
<h1 className="text-lg font-semibold text-white"></h1>
<p className="text-xs text-slate-500">
{list.data?.total ?? 0}
{list.isFetching && <span className="ml-2 text-brand-400"></span>}
</p>
</div>
<div className="flex items-center gap-2 text-sm">
<button
className="btn"
disabled={(query.page ?? 1) <= 1}
onClick={() => setQuery((q) => ({ ...q, page: Math.max(1, (q.page ?? 1) - 1) }))}
>
</button>
<span className="text-xs text-slate-400">
{query.page ?? 1} / {totalPages}
</span>
<button
className="btn"
disabled={(query.page ?? 1) >= totalPages}
onClick={() =>
setQuery((q) => ({ ...q, page: Math.min(totalPages, (q.page ?? 1) + 1) }))
}
>
</button>
</div>
</header>
<div className="flex-1 overflow-hidden p-4">
<CardGrid
items={list.data?.items ?? []}
onOpen={setOpenId}
onToggleFav={(shot) => toggleFav.mutate(shot)}
/>
</div>
</div>
<DetailPanel id={openId} onClose={() => setOpenId(null)} />
</div>
);
}
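The library page derives the page count from `total` and clamps the previous/next buttons to the valid range. A minimal sketch of that arithmetic, assuming hypothetical helper names `totalPages` and `stepPage` that are not in the commit:

```typescript
// Hypothetical helpers, not part of the commit: they restate the inline
// pagination arithmetic used by LibraryPage (and QueuePage).

// At least one page is always reported, even for an empty result set.
function totalPages(total: number, pageSize: number): number {
  return Math.max(1, Math.ceil(total / pageSize));
}

// Moves `page` by `delta` and clamps the result to [1, pages].
function stepPage(page: number, delta: number, pages: number): number {
  return Math.min(pages, Math.max(1, page + delta));
}
```

With a page size of 80, a total of 81 screenshots yields two pages, and stepping back from page 1 stays on page 1.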
+673
@@ -0,0 +1,673 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import {
ListOrdered,
Loader2,
RefreshCw,
RotateCcw,
AlertCircle,
Image as ImageIcon,
} from "lucide-react";
import { api } from "@/api/client";
import { DetailPanel } from "@/components/DetailPanel";
import type { JobItem } from "@/types";
const STATUS_TABS: { key: string; label: string }[] = [
{ key: "failed", label: "失败" },
{ key: "running", label: "运行中" },
{ key: "pending", label: "排队中" },
{ key: "done", label: "已完成" },
];
const PAGE_SIZE = 50;
/** Extract the file name from a path for display in the list. */
function basename(path?: string | null) {
if (!path) return "—";
const parts = path.replace(/\\/g, "/").split("/");
return parts[parts.length - 1] || path;
}
export default function QueuePage() {
const [status, setStatus] = useState("failed");
const [page, setPage] = useState(1);
const [openId, setOpenId] = useState<number | null>(null);
const qc = useQueryClient();
const queue = useQuery({
queryKey: ["queue"],
queryFn: api.queueStatus,
refetchInterval: 5000,
});
const list = useQuery({
queryKey: ["jobs", status, page],
queryFn: () => api.listJobs({ status, page, size: PAGE_SIZE }),
placeholderData: (prev) => prev,
});
const totalPages = Math.max(1, Math.ceil((list.data?.total ?? 0) / PAGE_SIZE));
const retryFailed = useMutation({
mutationFn: (jobIds?: number[]) => api.retryFailedJobs(jobIds),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["jobs"] });
qc.invalidateQueries({ queryKey: ["queue"] });
qc.invalidateQueries({ queryKey: ["stats"] });
},
});
const resetStale = useMutation({
mutationFn: (all?: boolean) => api.resetStaleJobs(all),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["jobs"] });
qc.invalidateQueries({ queryKey: ["queue"] });
},
});
const enqueueOcr = useMutation({
mutationFn: () => api.enqueueOcrFailed(500),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["jobs"] });
qc.invalidateQueries({ queryKey: ["queue"] });
},
});
const onTabChange = (key: string) => {
setStatus(key);
setPage(1);
};
return (
<div className="flex h-full flex-col px-8 py-6">
<header className="mb-4 flex flex-wrap items-end justify-between gap-4">
<div>
<h1 className="flex items-center gap-2 text-2xl font-semibold text-white">
<ListOrdered size={22} className="text-brand-400" />
</h1>
<p className="mt-1 text-xs text-slate-500">
OCR
</p>
</div>
<div className="flex flex-wrap gap-2">
{(queue.data?.ocr_retryable ?? 0) > 0 && (
<button
className="btn border-brand-600/50 text-brand-300"
disabled={enqueueOcr.isPending}
onClick={() => enqueueOcr.mutate()}
>
<RefreshCw size={14} className={enqueueOcr.isPending ? "animate-spin" : ""} />
OCR ({queue.data?.ocr_retryable})
</button>
)}
{(queue.data?.running ?? 0) > 0 && (queue.data?.inflight ?? 0) === 0 && (
<button
className="btn border-amber-600/50 text-amber-300"
disabled={resetStale.isPending}
onClick={() => resetStale.mutate(true)}
>
<RotateCcw size={14} />
({queue.data?.running})
</button>
)}
{(queue.data?.failed ?? 0) > 0 && (
<button
className="btn btn-primary"
disabled={retryFailed.isPending}
onClick={() => retryFailed.mutate(undefined)}
>
<RefreshCw size={14} className={retryFailed.isPending ? "animate-spin" : ""} />
({queue.data?.failed})
</button>
)}
</div>
</header>
<section className="mb-4 grid grid-cols-2 gap-3 md:grid-cols-6">
<QueueStat label="排队" value={queue.data?.pending ?? 0} />
<QueueStat
label="运行中"
value={queue.data?.running ?? 0}
hint={`实际并发 ${queue.data?.inflight ?? 0}`}
spinning={(queue.data?.inflight ?? 0) > 0}
/>
<QueueStat label="已完成" value={queue.data?.done ?? 0} />
<QueueStat label="失败" value={queue.data?.failed ?? 0} accent="text-red-400" />
<QueueStat
label="OCR 待补"
value={queue.data?.ocr_retryable ?? 0}
hint={`队列中 ${queue.data?.ocr_pending ?? 0}`}
accent="text-amber-300"
/>
<QueueStat label="Worker" value={queue.data?.inflight ?? 0} hint="inflight" />
</section>
<div className="mb-3 flex flex-wrap items-center gap-2">
{STATUS_TABS.map((t) => (
<button
key={t.key}
className={`chip ${status === t.key ? "chip-active" : ""}`}
onClick={() => onTabChange(t.key)}
>
{t.label}
<span className="opacity-60">{queue.data?.[t.key as keyof typeof queue.data] ?? 0}</span>
</button>
))}
<div className="ml-auto flex items-center gap-2 text-sm">
<button
className="btn"
disabled={page <= 1}
onClick={() => setPage((p) => Math.max(1, p - 1))}
>
</button>
<span className="text-xs text-slate-400">
{page} / {totalPages} · {list.data?.total ?? 0}
</span>
<button
className="btn"
disabled={page >= totalPages}
onClick={() => setPage((p) => Math.min(totalPages, p + 1))}
>
</button>
</div>
</div>
<div className="flex-1 overflow-y-auto">
{list.isLoading && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">
<Loader2 className="mr-2 animate-spin" size={16} />
</div>
)}
{!list.isLoading && (list.data?.items.length ?? 0) === 0 && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">
</div>
)}
<ul className="grid gap-3">
{(list.data?.items ?? []).map((job) => (
<JobRow
key={job.id}
job={job}
onOpen={() => setOpenId(job.screenshot_id)}
onRetry={() => retryFailed.mutate([job.id])}
retrying={retryFailed.isPending}
/>
))}
</ul>
{status === "pending" && (list.data?.total ?? 0) > PAGE_SIZE && (
<p className="mt-4 text-center text-xs text-slate-500">
Worker id
</p>
)}
</div>
<DetailPanel id={openId} onClose={() => setOpenId(null)} />
</div>
);
}
function QueueStat({
label,
value,
hint,
spinning,
accent,
}: {
label: string;
value: number;
hint?: string;
spinning?: boolean;
accent?: string;
}) {
return (
<div className="rounded-lg border border-slate-800 bg-slate-900/50 px-4 py-3">
<div className="text-xs text-slate-500">{label}</div>
<div className={`mt-1 flex items-center gap-2 text-xl font-semibold ${accent ?? "text-white"}`}>
{value.toLocaleString()}
{spinning && <Loader2 className="animate-spin text-brand-400" size={14} />}
</div>
{hint && <div className="mt-0.5 text-[10px] text-slate-500">{hint}</div>}
</div>
);
}
function JobRow({
job,
onOpen,
onRetry,
retrying,
}: {
job: JobItem;
onOpen: () => void;
onRetry: () => void;
retrying: boolean;
}) {
const statusColor =
job.status === "failed"
? "text-red-400"
: job.status === "running"
? "text-amber-400"
: job.status === "done"
? "text-emerald-400"
: "text-slate-400";
return (
<li className="rounded-lg border border-slate-800 bg-slate-900/50 p-4">
<div className="flex gap-4">
<button
type="button"
className="h-16 w-24 shrink-0 overflow-hidden rounded-md border border-slate-700 bg-slate-950"
onClick={onOpen}
>
{job.thumb_url ? (
<img src={job.thumb_url} alt="" className="h-full w-full object-cover" />
) : (
<div className="flex h-full items-center justify-center text-slate-600">
<ImageIcon size={20} />
</div>
)}
</button>
<div className="min-w-0 flex-1">
<div className="mb-1 flex flex-wrap items-center gap-2 text-xs">
<span className="font-mono text-slate-500">#{job.id}</span>
<span className="rounded bg-slate-800 px-1.5 py-0.5 text-[10px] uppercase text-slate-400">
{job.kind}
</span>
<span className={`font-medium ${statusColor}`}>{job.status}</span>
<span className="text-slate-500"> {job.retries}</span>
{job.ai_status && (
<span className="text-slate-500">AI {job.ai_status}</span>
)}
{job.ocr_status && (
<span className={job.ocr_status === "failed" ? "text-red-400" : "text-slate-500"}>
OCR {job.ocr_status}
</span>
)}
</div>
<div className="truncate text-sm font-medium text-slate-100">
{job.ai_title || basename(job.path)}
</div>
<div className="truncate text-xs text-slate-500" title={job.path ?? undefined}>
{job.path}
</div>
{job.last_error && (
<div className="mt-2 flex gap-2 rounded-md border border-red-900/50 bg-red-950/30 px-3 py-2 text-xs text-red-200">
<AlertCircle size={14} className="mt-0.5 shrink-0 text-red-400" />
<pre className="whitespace-pre-wrap break-all font-mono leading-relaxed">
{job.last_error}
</pre>
</div>
)}
<div className="mt-2 flex flex-wrap gap-2 text-[10px] text-slate-500">
{job.created_at && <span> {fmtTime(job.created_at)}</span>}
{job.started_at && <span> {fmtTime(job.started_at)}</span>}
{job.finished_at && <span> {fmtTime(job.finished_at)}</span>}
</div>
</div>
<div className="flex shrink-0 flex-col gap-2">
<button className="btn" onClick={onOpen}>
</button>
{job.status === "failed" && (
<button className="btn" disabled={retrying} onClick={onRetry}>
</button>
)}
</div>
</div>
</li>
);
}
function fmtTime(iso: string) {
return new Date(iso).toLocaleString("zh-CN", { hour12: false });
}
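The `basename` helper in the queue page normalizes Windows separators before splitting, and falls back to the whole path when the last segment is empty. The block below is a standalone copy of that function for experimentation, with the edge cases spelled out in comments; only the comments are added.

```typescript
// Standalone copy of the queue page's basename helper.
function basename(path?: string | null): string {
  // Missing or empty paths render as a dash placeholder.
  if (!path) return "—";
  // Treat backslashes as separators so Windows and UNC paths split correctly.
  const parts = path.replace(/\\/g, "/").split("/");
  // A trailing separator leaves an empty last segment; fall back to the input.
  return parts[parts.length - 1] || path;
}
```

For example, `basename("D:\\Pictures\\shot.png")` returns `"shot.png"`, while `basename("/mnt/nas/")` returns the input unchanged because the last segment is empty.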
+689
@@ -0,0 +1,689 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { FolderPlus, FolderOpen, RefreshCw, ShieldAlert, Trash2, Plug } from "lucide-react";
import { api } from "@/api/client";
import type {
Category,
ProviderConfig,
ProviderConfigOut,
ProviderTestResult,
RecognitionMode,
WatchFolder,
} from "@/types";
import { RECOGNITION_MODE_LABELS } from "@/types";
export default function SettingsPage() {
return (
<div className="h-full overflow-y-auto px-8 py-6">
<h1 className="mb-6 text-2xl font-semibold text-white"></h1>
<div className="flex flex-col gap-8">
<WatchFolderSection />
<RecognitionModeSection />
<ProviderSection
k="ocr_provider"
title="OCR 引擎"
desc="传统文字识别。在「传统 OCR / 混合」模式下使用;也可单独选「视觉模型识文」作为 OCR 引擎。"
defaults={{ type: "tesseract", extra: { lang: "chi_sim+eng" } }}
/>
<ProviderSection
k="vlm_provider"
title="视觉 AI 模型"
desc="多模态大模型:在「视觉 AI / 混合」模式下负责识文与分类摘要;支持 Ollama / GLM / MiniMax / OpenAI 等 OpenAI 兼容接口。"
defaults={{
type: "openai_compat",
base_url: "http://localhost:11434/v1",
model: "qwen2.5vl:7b",
extra: {},
}}
/>
<CategorySection />
</div>
</div>
);
}
function RecognitionModeSection() {
const qc = useQueryClient();
const cur = useQuery({
queryKey: ["recognition-mode"],
queryFn: api.getRecognitionMode,
});
const [mode, setMode] = useState<RecognitionMode>("hybrid");
useEffect(() => {
if (cur.data?.mode && cur.data.mode in RECOGNITION_MODE_LABELS) {
setMode(cur.data.mode as RecognitionMode);
}
}, [cur.data]);
const save = useMutation({
mutationFn: () => api.setRecognitionMode(mode),
onSuccess: () => qc.invalidateQueries({ queryKey: ["recognition-mode"] }),
});
return (
<section className="rounded-xl border border-slate-800 bg-slate-900/40 p-5">
<header className="mb-1 flex items-center justify-between">
<h2 className="text-lg font-semibold text-white"></h2>
<button
className="btn btn-primary"
onClick={() => save.mutate()}
disabled={save.isPending}
>
</button>
</header>
<p className="mb-4 text-xs text-slate-500">
OCR AIOCR AI AI
</p>
<div className="flex flex-wrap gap-2">
{(Object.keys(RECOGNITION_MODE_LABELS) as RecognitionMode[]).map((m) => (
<button
key={m}
type="button"
className={`chip ${mode === m ? "chip-active" : ""}`}
onClick={() => setMode(m)}
>
{RECOGNITION_MODE_LABELS[m]}
</button>
))}
</div>
<p className="mt-3 text-xs text-slate-500">
{mode === "ocr" && "仅使用下方 OCR 引擎提取文字;视觉 AI 仍可用于分类/摘要(若已配置)。"}
{mode === "vision" && "使用下方「视觉 AI 模型」从图片识文,不走 Tesseract 等传统 OCR。"}
{mode === "hybrid" && "先用 OCR 引擎识文,再交给视觉 AI 联合分析;OCR 失败时会自动尝试视觉识文。"}
</p>
</section>
);
}
function WatchFolderSection() {
const qc = useQueryClient();
const folders = useQuery({ queryKey: ["watch-folders"], queryFn: api.listWatchFolders });
const [draft, setDraft] = useState<Omit<WatchFolder, "id">>({
path: "",
enabled: true,
recursive: true,
is_sensitive: false,
});
const add = useMutation({
mutationFn: () => api.addWatchFolder(draft),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["watch-folders"] });
// Functional update: keep checkbox changes made while the request was in flight.
setDraft((d) => ({ ...d, path: "" }));
},
});
const update = useMutation({
mutationFn: ({ id, payload }: { id: number; payload: Omit<WatchFolder, "id"> }) =>
api.updateWatchFolder(id, payload),
onSuccess: () => qc.invalidateQueries({ queryKey: ["watch-folders"] }),
});
const remove = useMutation({
mutationFn: (id: number) => api.deleteWatchFolder(id),
onSuccess: () => qc.invalidateQueries({ queryKey: ["watch-folders"] }),
});
const reimport = useMutation({
mutationFn: (folder: WatchFolder) =>
api.importNow({
path: folder.path,
enabled: folder.enabled,
recursive: folder.recursive,
is_sensitive: folder.is_sensitive,
}),
});
const [pathTestMsg, setPathTestMsg] = useState<{ ok: boolean; text: string } | null>(null);
const testPath = useMutation({
mutationFn: () => api.validateWatchPath(draft.path),
onSuccess: (data) => {
setPathTestMsg({
ok: true,
text: `${data.message}(规范化路径: ${data.path}`,
});
},
onError: (err: Error) => {
setPathTestMsg({ ok: false, text: err.message });
},
});
return (
<section className="rounded-xl border border-slate-800 bg-slate-900/40 p-5">
<header className="mb-3 flex items-center justify-between">
<h2 className="text-lg font-semibold text-white"></h2>
<span className="text-xs text-slate-500">
UNC \\NAS\\\\
</span>
</header>
<div className="mb-4 flex flex-wrap gap-2">
<input
className="input min-w-[280px] flex-1"
placeholder="D:/Pictures 或 \\JIULUGNAS\personal_folder\Photos\..."
value={draft.path}
onChange={(e) => {
setDraft({ ...draft, path: e.target.value });
setPathTestMsg(null);
}}
/>
<label className="flex items-center gap-1 text-xs text-slate-400">
<input
type="checkbox"
checked={draft.recursive}
onChange={(e) => setDraft({ ...draft, recursive: e.target.checked })}
/>
</label>
<label className="flex items-center gap-1 text-xs text-slate-400">
<input
type="checkbox"
checked={draft.is_sensitive}
onChange={(e) => setDraft({ ...draft, is_sensitive: e.target.checked })}
/>
</label>
<button
className="btn"
disabled={!draft.path || testPath.isPending}
onClick={() => testPath.mutate()}
>
<Plug size={14} />
</button>
<button
className="btn btn-primary"
disabled={!draft.path || add.isPending}
onClick={() => add.mutate()}
>
<FolderPlus size={14} />
</button>
</div>
{pathTestMsg && (
<p
className={`mb-3 text-xs ${pathTestMsg.ok ? "text-emerald-400" : "text-rose-400"}`}
>
{pathTestMsg.text}
</p>
)}
<ul className="flex flex-col divide-y divide-slate-800">
{(folders.data ?? []).map((f) => (
<li key={f.id} className="flex items-center gap-2 py-2 text-sm">
<FolderOpen size={14} className="text-slate-500" />
<span className="flex-1 break-all text-slate-200">{f.path}</span>
{f.is_sensitive && (
<span className="inline-flex items-center gap-1 rounded bg-amber-500/20 px-2 py-0.5 text-[10px] text-amber-300">
<ShieldAlert size={10} />
</span>
)}
<label className="flex items-center gap-1 text-xs text-slate-400">
<input
type="checkbox"
checked={f.enabled}
onChange={(e) =>
update.mutate({
id: f.id,
payload: { ...f, enabled: e.target.checked },
})
}
/>
</label>
<label className="flex items-center gap-1 text-xs text-slate-400">
<input
type="checkbox"
checked={f.recursive}
onChange={(e) =>
update.mutate({
id: f.id,
payload: { ...f, recursive: e.target.checked },
})
}
/>
Recursive
</label>
<button
className="btn"
onClick={() => reimport.mutate(f)}
disabled={reimport.isPending}
>
<RefreshCw size={14} /> Re-import
</button>
<button
className="btn text-rose-300 hover:border-rose-500"
onClick={() => {
if (confirm(`Delete watch folder ${f.path}?`)) remove.mutate(f.id);
}}
>
<Trash2 size={14} />
</button>
</li>
))}
{(folders.data?.length ?? 0) === 0 && (
<li className="py-4 text-center text-xs text-slate-500">
No watch folders yet
</li>
)}
</ul>
</section>
);
}
function ProviderSection({
k,
title,
desc,
defaults,
}: {
k: "ocr_provider" | "vlm_provider";
title: string;
desc: string;
defaults: ProviderConfig;
}) {
const qc = useQueryClient();
const cur = useQuery({ queryKey: ["provider", k], queryFn: () => api.getProvider(k) });
const [draft, setDraft] = useState<ProviderConfig>(defaults);
const mask: string | null | undefined = (cur.data as ProviderConfigOut | null | undefined)
?.api_key_mask;
useEffect(() => {
if (cur.data) {
// Sync only the user-editable fields so derived fields such as api_key_mask do not leak into the draft
const { type, base_url, api_key, model, extra } = cur.data;
setDraft({
...defaults,
type: type || defaults.type,
base_url: base_url ?? defaults.base_url ?? null,
api_key: api_key ?? "",
model: model ?? defaults.model ?? null,
extra: { ...defaults.extra, ...(extra ?? {}) },
});
}
}, [cur.data]);
const save = useMutation({
mutationFn: () => api.setProvider(k, draft),
onSuccess: () => qc.invalidateQueries({ queryKey: ["provider", k] }),
});
const [testResult, setTestResult] = useState<ProviderTestResult | null>(null);
const testConn = useMutation({
mutationFn: () => api.testProvider(k, draft),
onSuccess: (data) => setTestResult(data),
onError: (err: Error) =>
setTestResult({ ok: false, message: err.message, detail: null, latency_ms: null }),
});
return (
<section className="rounded-xl border border-slate-800 bg-slate-900/40 p-5">
<header className="mb-1 flex items-center justify-between">
<h2 className="text-lg font-semibold text-white">{title}</h2>
<div className="flex gap-2">
<button
className="btn"
onClick={() => testConn.mutate()}
disabled={testConn.isPending || draft.type === "none"}
>
<Plug size={14} /> {testConn.isPending ? "Testing…" : "Test connection"}
</button>
<button
className="btn btn-primary"
onClick={() => save.mutate()}
disabled={save.isPending}
>
Save
</button>
</div>
</header>
<p className="mb-4 text-xs text-slate-500">{desc}</p>
{testResult && (
<div
className={`mb-4 rounded-md border px-3 py-2 text-xs ${
testResult.ok
? "border-emerald-500/40 bg-emerald-500/10 text-emerald-300"
: "border-rose-500/40 bg-rose-500/10 text-rose-300"
}`}
>
<div className="font-medium">{testResult.message}</div>
{testResult.detail && (
<div className="mt-1 opacity-80">{testResult.detail}</div>
)}
{testResult.latency_ms != null && (
<div className="mt-1 opacity-60">Latency: {testResult.latency_ms} ms</div>
)}
</div>
)}
<div className="grid gap-3 md:grid-cols-2">
<Field label="Provider type">
{k === "ocr_provider" ? (
<select
className="input"
value={draft.type}
onChange={(e) => {
setDraft({ ...draft, type: e.target.value });
setTestResult(null);
}}
>
<option value="tesseract">Tesseract</option>
<option value="paddleocr">PaddleOCR</option>
<option value="http">HTTP API OCR</option>
<option value="vision">OpenAI-compatible vision model</option>
<option value="none">Disabled</option>
</select>
) : (
<select
className="input"
value={draft.type}
onChange={(e) => {
setDraft({ ...draft, type: e.target.value });
setTestResult(null);
}}
>
<option value="openai_compat">OpenAI-compatible</option>
<option value="none">Disabled</option>
</select>
)}
</Field>
{(k === "vlm_provider" ||
(k === "ocr_provider" &&
(draft.type === "vision" || draft.type === "http"))) && (
<VisionApiFields
draft={draft}
setDraft={setDraft}
mask={mask}
showModel={draft.type !== "http" || k === "vlm_provider"}
urlLabel={draft.type === "http" && k === "ocr_provider" ? "OCR API URL" : "Base URL"}
urlPlaceholder={
draft.type === "http" && k === "ocr_provider"
? "https://your-ocr-service/recognize"
: "https://api.openai.com/v1"
}
/>
)}
{k === "ocr_provider" && draft.type === "tesseract" && (
<>
<Field label="Language (lang)">
<input
className="input"
value={String(draft.extra.lang ?? "chi_sim+eng")}
onChange={(e) =>
setDraft({
...draft,
extra: { ...draft.extra, lang: e.target.value },
})
}
/>
</Field>
<Field label="Tesseract path (optional)">
<input
className="input"
placeholder="C:/Program Files/Tesseract-OCR/tesseract.exe"
value={String(draft.extra.cmd ?? "")}
onChange={(e) =>
setDraft({
...draft,
extra: { ...draft.extra, cmd: e.target.value },
})
}
/>
</Field>
</>
)}
{k === "ocr_provider" && draft.type === "paddleocr" && (
<Field label="Language (lang)">
<input
className="input"
placeholder="ch / en / ..."
value={String(draft.extra.lang ?? "ch")}
onChange={(e) =>
setDraft({
...draft,
extra: { ...draft.extra, lang: e.target.value },
})
}
/>
</Field>
)}
{k === "ocr_provider" && draft.type === "http" && (
<Field label="Response text field (text_path)">
<input
className="input"
placeholder="text or data.text"
value={String(draft.extra.text_path ?? "text")}
onChange={(e) =>
setDraft({
...draft,
extra: { ...draft.extra, text_path: e.target.value },
})
}
/>
</Field>
)}
</div>
</section>
);
}
/** URL, Model, and API Key form fields shared by the vision and HTTP OCR providers */
function VisionApiFields({
draft,
setDraft,
mask,
showModel = true,
urlLabel = "Base URL",
urlPlaceholder = "https://api.openai.com/v1",
}: {
draft: ProviderConfig;
setDraft: (v: ProviderConfig) => void;
mask?: string | null;
showModel?: boolean;
urlLabel?: string;
urlPlaceholder?: string;
}) {
return (
<>
<Field label={urlLabel}>
<input
className="input"
placeholder={urlPlaceholder}
value={draft.base_url ?? ""}
onChange={(e) => setDraft({ ...draft, base_url: e.target.value })}
/>
</Field>
{showModel && (
<Field label="Model">
<input
className="input"
placeholder="gpt-4o-mini / glm-4v-flash / qwen2.5vl:7b"
value={draft.model ?? ""}
onChange={(e) => setDraft({ ...draft, model: e.target.value })}
/>
</Field>
)}
<Field label="API Key (leave blank to keep the current value)">
<input
className="input"
type="password"
placeholder={mask ? `Configured: ${mask}` : "sk-..."}
value={draft.api_key ?? ""}
onChange={(e) => setDraft({ ...draft, api_key: e.target.value })}
/>
</Field>
</>
);
}
function CategorySection() {
const qc = useQueryClient();
const cats = useQuery({ queryKey: ["categories"], queryFn: api.listCategories });
const [draft, setDraft] = useState<Omit<Category, "id">>({
name: "",
color: "#6366f1",
prompt_hint: "",
});
const create = useMutation({
mutationFn: () => api.createCategory(draft),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["categories"] });
setDraft({ name: "", color: "#6366f1", prompt_hint: "" });
},
});
const update = useMutation({
mutationFn: ({ id, payload }: { id: number; payload: Omit<Category, "id"> }) =>
api.updateCategory(id, payload),
onSuccess: () => qc.invalidateQueries({ queryKey: ["categories"] }),
});
const remove = useMutation({
mutationFn: (id: number) => api.deleteCategory(id),
onSuccess: () => qc.invalidateQueries({ queryKey: ["categories"] }),
});
return (
<section className="rounded-xl border border-slate-800 bg-slate-900/40 p-5">
<header className="mb-3 flex items-center justify-between">
<h2 className="text-lg font-semibold text-white">Categories</h2>
<span className="text-xs text-slate-500">Categories and prompt hints are passed to the AI</span>
</header>
<div className="mb-4 grid grid-cols-[1fr_120px_2fr_auto] gap-2">
<input
className="input"
placeholder="Category name"
value={draft.name}
onChange={(e) => setDraft({ ...draft, name: e.target.value })}
/>
<input
className="input"
type="color"
value={draft.color ?? "#6366f1"}
onChange={(e) => setDraft({ ...draft, color: e.target.value })}
/>
<input
className="input"
placeholder="Prompt hint (optional)"
value={draft.prompt_hint ?? ""}
onChange={(e) => setDraft({ ...draft, prompt_hint: e.target.value })}
/>
<button
className="btn btn-primary"
disabled={!draft.name || create.isPending}
onClick={() => create.mutate()}
>
Add
</button>
</div>
<ul className="flex flex-col divide-y divide-slate-800">
{(cats.data ?? []).map((c) => (
<CategoryRow
key={c.id}
cat={c}
onSave={(payload) => update.mutate({ id: c.id, payload })}
onDelete={() => {
if (confirm(`Delete category ${c.name}?`)) remove.mutate(c.id);
}}
/>
))}
</ul>
</section>
);
}
function CategoryRow({
cat,
onSave,
onDelete,
}: {
cat: Category;
onSave: (p: Omit<Category, "id">) => void;
onDelete: () => void;
}) {
const [editing, setEditing] = useState(false);
const [draft, setDraft] = useState<Omit<Category, "id">>({
name: cat.name,
color: cat.color ?? "#6366f1",
prompt_hint: cat.prompt_hint ?? "",
});
useEffect(() => {
setDraft({
name: cat.name,
color: cat.color ?? "#6366f1",
prompt_hint: cat.prompt_hint ?? "",
});
}, [cat]);
if (editing) {
return (
<li className="grid grid-cols-[1fr_120px_2fr_auto] gap-2 py-2">
<input
className="input"
value={draft.name}
onChange={(e) => setDraft({ ...draft, name: e.target.value })}
/>
<input
className="input"
type="color"
value={draft.color ?? "#6366f1"}
onChange={(e) => setDraft({ ...draft, color: e.target.value })}
/>
<input
className="input"
value={draft.prompt_hint ?? ""}
onChange={(e) => setDraft({ ...draft, prompt_hint: e.target.value })}
/>
<div className="flex gap-1">
<button
className="btn btn-primary"
onClick={() => {
onSave(draft);
setEditing(false);
}}
>
Save
</button>
<button className="btn" onClick={() => setEditing(false)}>
Cancel
</button>
</div>
</li>
);
}
return (
<li className="grid grid-cols-[1fr_120px_2fr_auto] items-center gap-2 py-2 text-sm">
<span className="flex items-center gap-2 text-slate-100">
<span
className="h-3 w-3 rounded-full"
style={{ background: cat.color ?? "#6366f1" }}
/>
{cat.name}
</span>
<span className="text-xs text-slate-500">{cat.color}</span>
<span className="text-xs text-slate-400">{cat.prompt_hint || "—"}</span>
<div className="flex gap-1">
<button className="btn" onClick={() => setEditing(true)}>
Edit
</button>
<button
className="btn text-rose-300 hover:border-rose-500"
onClick={onDelete}
>
<Trash2 size={14} />
</button>
</div>
</li>
);
}
function Field({ label, children }: { label: string; children: React.ReactNode }) {
return (
<label className="block">
<span className="mb-1 block text-xs font-medium text-slate-400">{label}</span>
{children}
</label>
);
}
+103
@@ -0,0 +1,103 @@
import { useQuery } from "@tanstack/react-query";
import { useState } from "react";
import { Shuffle, ExternalLink, Star } from "lucide-react";
import { api } from "@/api/client";
import { DetailPanel } from "@/components/DetailPanel";
import { StatusBadge } from "@/components/StatusBadge";
export default function ShufflePage() {
const [openId, setOpenId] = useState<number | null>(null);
const random = useQuery({
queryKey: ["random-one"],
queryFn: () => api.randomScreenshots({ n: 1 }),
});
const shot = random.data?.[0];
return (
<div className="flex h-full flex-col px-8 py-6">
<header className="mb-4 flex items-center justify-between">
<div>
<h1 className="flex items-center gap-2 text-2xl font-semibold text-white">
<Shuffle size={20} className="text-brand-400" /> Shuffle
</h1>
<p className="text-xs text-slate-500"></p>
</div>
<button
className="btn btn-primary"
onClick={() => random.refetch()}
disabled={random.isFetching}
>
<Shuffle size={14} /> Shuffle again
</button>
</header>
{!shot ? (
<div className="flex flex-1 items-center justify-center text-slate-500">
{random.isLoading ? "Loading…" : "No screenshots yet"}
</div>
) : (
<div className="flex flex-1 gap-6 overflow-hidden">
<div className="flex flex-1 items-center justify-center rounded-xl border border-slate-800 bg-black/30 p-4">
<img
src={`/api/screenshots/${shot.id}/file`}
alt={shot.ai_title ?? "screenshot"}
className="max-h-full max-w-full rounded-md shadow-2xl"
/>
</div>
<aside className="w-[360px] shrink-0 overflow-y-auto rounded-xl border border-slate-800 bg-slate-900/50 px-5 py-5">
<div className="mb-3 flex items-center gap-2">
<StatusBadge status={shot.ai_status} />
{shot.category && (
<span
className="rounded-full px-2 py-0.5 text-xs"
style={{
backgroundColor: `${shot.category.color ?? "#6366f1"}22`,
color: shot.category.color ?? "#a5b4fc",
}}
>
{shot.category.name}
</span>
)}
{shot.is_favorite && (
<span className="inline-flex items-center gap-1 rounded-full bg-amber-400/20 px-2 py-0.5 text-xs text-amber-300">
<Star size={12} /> Favorite
</span>
)}
</div>
<h2 className="text-xl font-semibold text-white">
{shot.ai_title || "(no title generated)"}
</h2>
<p className="mt-1 text-xs text-slate-500">
{new Date(shot.captured_at).toLocaleString("zh-CN", { hour12: false })}
</p>
<div className="mt-4 flex gap-2">
<button className="btn" onClick={() => setOpenId(shot.id)}>
View details
</button>
<a
className="btn"
href={`/api/screenshots/${shot.id}/file`}
target="_blank"
rel="noreferrer"
>
<ExternalLink size={14} /> Open original
</a>
</div>
{shot.tags.length > 0 && (
<div className="mt-4 flex flex-wrap gap-1.5">
{shot.tags.map((t) => (
<span key={t.id} className="chip">
#{t.name}
</span>
))}
</div>
)}
</aside>
</div>
)}
<DetailPanel id={openId} onClose={() => setOpenId(null)} />
</div>
);
}
+115
@@ -0,0 +1,115 @@
import { useQuery } from "@tanstack/react-query";
import { useMemo, useState } from "react";
import { Link } from "react-router-dom";
import { Hash, Search } from "lucide-react";
import { api } from "@/api/client";
const PAGE_SIZE = 80;
const SORT_OPTIONS = [
{ value: "count_desc", label: "Most used" },
{ value: "count_asc", label: "Least used" },
{ value: "name_asc", label: "Name A→Z" },
{ value: "name_desc", label: "Name Z→A" },
];
export default function TagsPage() {
const [q, setQ] = useState("");
const [search, setSearch] = useState("");
const [sort, setSort] = useState("count_desc");
const [page, setPage] = useState(1);
const list = useQuery({
queryKey: ["tags", search, sort, page],
queryFn: () => api.listTags({ q: search || undefined, sort, page, size: PAGE_SIZE }),
placeholderData: (prev) => prev,
});
const totalPages = Math.max(1, Math.ceil((list.data?.total ?? 0) / PAGE_SIZE));
const onSearch = () => {
setSearch(q.trim());
setPage(1);
};
const items = useMemo(() => list.data?.items ?? [], [list.data?.items]);
return (
<div className="flex h-full flex-col px-8 py-6">
<header className="mb-4 flex flex-wrap items-end justify-between gap-4">
<div>
<h1 className="flex items-center gap-2 text-2xl font-semibold text-white">
<Hash size={22} className="text-brand-400" /> Tags
</h1>
<p className="mt-1 text-xs text-slate-500">
{list.data?.total ?? 0} tags · click a tag to view its screenshots
</p>
</div>
<div className="flex flex-wrap items-center gap-2">
<div className="relative">
<Search
size={14}
className="pointer-events-none absolute left-2.5 top-1/2 -translate-y-1/2 text-slate-500"
/>
<input
className="input w-56 pl-7"
placeholder="Search tag names…"
value={q}
onChange={(e) => setQ(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && onSearch()}
/>
</div>
<button className="btn" onClick={onSearch}>
Search
</button>
<select className="input w-36" value={sort} onChange={(e) => { setSort(e.target.value); setPage(1); }}>
{SORT_OPTIONS.map((o) => (
<option key={o.value} value={o.value}>
{o.label}
</option>
))}
</select>
</div>
</header>
<div className="mb-3 flex items-center justify-end gap-2 text-sm">
<button className="btn" disabled={page <= 1} onClick={() => setPage((p) => Math.max(1, p - 1))}>
Previous
</button>
<span className="text-xs text-slate-400">
{page} / {totalPages}
</span>
<button
className="btn"
disabled={page >= totalPages}
onClick={() => setPage((p) => Math.min(totalPages, p + 1))}
>
Next
</button>
</div>
<div className="flex-1 overflow-y-auto">
{list.isLoading && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">Loading…</div>
)}
{!list.isLoading && items.length === 0 && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">No tags found</div>
)}
<div className="flex flex-wrap gap-2">
{items.map((t) => (
<Link
key={t.id}
to={`/library?tag=${encodeURIComponent(t.name)}`}
className="chip hover:border-brand-500"
>
#{t.name}
<span className="opacity-50">{t.count}</span>
</Link>
))}
</div>
</div>
</div>
);
}
+177
@@ -0,0 +1,177 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import { ListChecks, Check, X, Image as ImageIcon, Search } from "lucide-react";
import { api } from "@/api/client";
import { DetailPanel } from "@/components/DetailPanel";
import type { TodoListQuery } from "@/types";
const STATUS_TABS: { key: string; label: string }[] = [
{ key: "pending", label: "To do" },
{ key: "doing", label: "In progress" },
{ key: "done", label: "Done" },
{ key: "dropped", label: "Dropped" },
];
const PAGE_SIZE = 30;
export default function TodosPage() {
const [status, setStatus] = useState<string>("pending");
const [qInput, setQInput] = useState("");
const [query, setQuery] = useState<TodoListQuery>({ page: 1, size: PAGE_SIZE });
const [openId, setOpenId] = useState<number | null>(null);
const qc = useQueryClient();
const summary = useQuery({ queryKey: ["todo-summary"], queryFn: api.todoSummary });
const list = useQuery({
queryKey: ["todos", status, query],
queryFn: () => api.listTodos({ ...query, status, size: PAGE_SIZE }),
placeholderData: (prev) => prev,
});
const totalPages = Math.max(1, Math.ceil((list.data?.total ?? 0) / PAGE_SIZE));
const update = useMutation({
mutationFn: ({ id, payload }: { id: number; payload: Parameters<typeof api.updateTodo>[1] }) =>
api.updateTodo(id, payload),
onSuccess: () => {
qc.invalidateQueries({ queryKey: ["todos"] });
qc.invalidateQueries({ queryKey: ["todo-summary"] });
},
});
const onTabChange = (key: string) => {
setStatus(key);
setQuery((prev) => ({ ...prev, page: 1 }));
};
const onSearch = () => {
setQuery((prev) => ({ ...prev, q: qInput.trim() || undefined, page: 1 }));
};
return (
<div className="flex h-full flex-col px-8 py-6">
<header className="mb-4 flex flex-wrap items-end justify-between gap-4">
<div>
<h1 className="flex items-center gap-2 text-2xl font-semibold text-white">
<ListChecks size={22} className="text-brand-400" /> To-dos
</h1>
<p className="text-xs text-slate-500">
Extracted from screenshots by AI · {list.data?.total ?? 0} items
</p>
</div>
<div className="flex items-center gap-2">
<div className="relative">
<Search
size={14}
className="pointer-events-none absolute left-2.5 top-1/2 -translate-y-1/2 text-slate-500"
/>
<input
className="input w-52 pl-7"
placeholder="Search title/note…"
value={qInput}
onChange={(e) => setQInput(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && onSearch()}
/>
</div>
<button className="btn" onClick={onSearch}>
Search
</button>
</div>
</header>
<div className="mb-3 flex flex-wrap items-center gap-2">
{STATUS_TABS.map((t) => (
<button
key={t.key}
className={`chip ${status === t.key ? "chip-active" : ""}`}
onClick={() => onTabChange(t.key)}
>
{t.label}
<span className="opacity-60">{summary.data?.[t.key] ?? 0}</span>
</button>
))}
<div className="ml-auto flex items-center gap-2 text-sm">
<button
className="btn"
disabled={(query.page ?? 1) <= 1}
onClick={() => setQuery((prev) => ({ ...prev, page: Math.max(1, (prev.page ?? 1) - 1) }))}
>
Previous
</button>
<span className="text-xs text-slate-400">
{query.page ?? 1} / {totalPages}
</span>
<button
className="btn"
disabled={(query.page ?? 1) >= totalPages}
onClick={() =>
setQuery((prev) => ({ ...prev, page: Math.min(totalPages, (prev.page ?? 1) + 1) }))
}
>
Next
</button>
</div>
</div>
<div className="flex-1 overflow-y-auto">
{(list.data?.items.length ?? 0) === 0 && !list.isLoading && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">
No to-dos here
</div>
)}
{list.isLoading && (
<div className="flex h-40 items-center justify-center text-sm text-slate-500">Loading…</div>
)}
<ul className="grid gap-3 lg:grid-cols-2">
{(list.data?.items ?? []).map((t) => (
<li
key={t.id}
className="group rounded-lg border border-slate-800 bg-slate-900/50 p-4 transition hover:border-brand-500"
>
<div className="mb-1 flex items-center justify-between text-[10px] text-slate-500">
<span>{t.kind ?? "To-do"}</span>
<span>{new Date(t.created_at).toLocaleString("zh-CN", { hour12: false })}</span>
</div>
<div className="text-sm font-medium text-slate-100">{t.title}</div>
{t.note && (
<div className="mt-1 line-clamp-3 text-xs text-slate-400">{t.note}</div>
)}
<div className="mt-3 flex items-center gap-2">
<button className="btn" onClick={() => setOpenId(t.screenshot_id)}>
<ImageIcon size={14} /> View screenshot
</button>
{status !== "done" && (
<button
className="btn"
onClick={() => update.mutate({ id: t.id, payload: { status: "done" } })}
>
<Check size={14} /> Done
</button>
)}
{status !== "dropped" && (
<button
className="btn btn-ghost"
onClick={() => update.mutate({ id: t.id, payload: { status: "dropped" } })}
>
<X size={14} /> Drop
</button>
)}
{status !== "pending" && (
<button
className="btn btn-ghost"
onClick={() => update.mutate({ id: t.id, payload: { status: "pending" } })}
>
Move back to to-do
</button>
)}
</div>
</li>
))}
</ul>
</div>
<DetailPanel id={openId} onClose={() => setOpenId(null)} />
</div>
);
}
+157
@@ -0,0 +1,157 @@
export interface Tag {
id: number;
name: string;
color?: string | null;
}
export interface Category {
id: number;
name: string;
color?: string | null;
prompt_hint?: string | null;
}
export interface ScreenshotBrief {
id: number;
path: string;
width: number;
height: number;
captured_at: string;
thumb_url?: string | null;
ai_title?: string | null;
ai_status: string;
ocr_status: string;
is_favorite: boolean;
category?: Category | null;
tags: Tag[];
}
export interface TodoItem {
id: number;
title: string;
note?: string | null;
kind?: string | null;
status: string;
created_at: string;
completed_at?: string | null;
screenshot_id: number;
}
export interface ScreenshotDetail extends ScreenshotBrief {
file_url: string;
size: number;
ocr_text?: string | null;
ai_summary?: string | null;
ai_suggestion?: string | null;
todos: TodoItem[];
}
export interface ListResp {
items: ScreenshotBrief[];
total: number;
page: number;
size: number;
}
export interface WatchFolder {
id: number;
path: string;
enabled: boolean;
recursive: boolean;
is_sensitive: boolean;
}
export interface ProviderConfig {
type: string;
base_url?: string | null;
api_key?: string | null;
model?: string | null;
extra: Record<string, unknown>;
}
/** When a provider is read, the backend also returns api_key_mask, which the UI shows as a hint. */
export interface ProviderConfigOut extends ProviderConfig {
api_key_mask?: string | null;
}
export interface ProviderTestResult {
ok: boolean;
message: string;
detail?: string | null;
latency_ms?: number | null;
}
export interface StatsResp {
total: number;
by_status: Record<string, number>;
by_category: { id: number; name: string; color?: string | null; count: number }[];
by_month: { month: string; count: number }[];
queue: Record<string, number>;
}
export type RecognitionMode = "ocr" | "vision" | "hybrid";
export const RECOGNITION_MODE_LABELS: Record<RecognitionMode, string> = {
ocr: "Traditional OCR",
vision: "Vision AI text recognition",
hybrid: "Hybrid (OCR + vision AI)",
};
export interface ListQuery {
q?: string;
category_id?: number;
tag?: string;
date_from?: string;
date_to?: string;
favorite?: boolean;
status?: string;
sort?: string;
page?: number;
size?: number;
}
export interface JobItem {
id: number;
screenshot_id: number;
kind: string;
status: string;
retries: number;
last_error?: string | null;
created_at: string;
started_at?: string | null;
finished_at?: string | null;
thumb_url?: string | null;
path?: string | null;
ai_title?: string | null;
ai_status?: string | null;
ocr_status?: string | null;
}
export interface TodoListResp {
items: TodoItem[];
total: number;
page: number;
size: number;
}
export interface TagListResp {
items: (Tag & { count: number })[];
total: number;
page: number;
size: number;
}
export interface TodoListQuery {
status?: string;
kind?: string;
q?: string;
page?: number;
size?: number;
}
export interface JobListResp {
items: JobItem[];
total: number;
page: number;
size: number;
}
+23
@@ -0,0 +1,23 @@
import type { Config } from "tailwindcss";
const config: Config = {
content: ["./index.html", "./src/**/*.{ts,tsx}"],
darkMode: "class",
theme: {
extend: {
colors: {
brand: {
50: "#eef2ff",
100: "#e0e7ff",
400: "#818cf8",
500: "#6366f1",
600: "#4f46e5",
700: "#4338ca",
},
},
},
},
plugins: [],
};
export default config;
+24
@@ -0,0 +1,24 @@
{
"compilerOptions": {
"target": "ES2020",
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"moduleResolution": "Bundler",
"jsx": "react-jsx",
"strict": true,
"noUnusedLocals": false,
"noUnusedParameters": false,
"noFallthroughCasesInSwitch": true,
"resolveJsonModule": true,
"allowSyntheticDefaultImports": true,
"esModuleInterop": true,
"skipLibCheck": true,
"isolatedModules": true,
"useDefineForClassFields": true,
"baseUrl": ".",
"paths": {
"@/*": ["src/*"]
}
},
"include": ["src"]
}
+21
@@ -0,0 +1,21 @@
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import path from "node:path";
export default defineConfig({
plugins: [react()],
resolve: {
alias: {
"@": path.resolve(__dirname, "src"),
},
},
server: {
port: 5173,
proxy: {
"/api": {
target: "http://127.0.0.1:8765",
changeOrigin: true,
},
},
},
});
+56
@@ -0,0 +1,56 @@
# snapAna one-step startup script (PowerShell)
# Starts the backend (uvicorn @ 8765) and the frontend (vite @ 5173) together
# Usage: run .\start-dev.ps1 from the repository root
[CmdletBinding()]
param(
[switch]$InstallDeps # Pass this switch to reinstall dependencies
)
$ErrorActionPreference = "Stop"
$root = Split-Path -Parent $MyInvocation.MyCommand.Definition
$backend = Join-Path $root "backend"
$frontend = Join-Path $root "frontend"
Write-Host "[snapAna] root = $root"
# 1. Backend virtual environment
if (!(Test-Path (Join-Path $backend ".venv"))) {
Write-Host "[snapAna] Creating Python virtual environment..."
& python -m venv (Join-Path $backend ".venv")
}
$venvPython = Join-Path $backend ".venv\Scripts\python.exe"
if ($InstallDeps -or !(Test-Path (Join-Path $backend ".venv\Lib\site-packages\fastapi"))) {
Write-Host "[snapAna] Installing backend dependencies..."
& $venvPython -m pip install --upgrade pip
& $venvPython -m pip install -r (Join-Path $backend "requirements.txt")
}
# 2. Frontend dependencies
if ($InstallDeps -or !(Test-Path (Join-Path $frontend "node_modules"))) {
Write-Host "[snapAna] Installing frontend dependencies..."
Push-Location $frontend
npm install
Pop-Location
}
# 3. Start the backend as a separate process
Write-Host "[snapAna] Starting backend at http://127.0.0.1:8765 ..."
$backendProc = Start-Process -FilePath $venvPython `
-ArgumentList "run.py" `
-WorkingDirectory $backend `
-PassThru `
-WindowStyle Normal
# 4. Start the frontend (it occupies the current console so its logs stay visible)
Write-Host "[snapAna] Starting frontend at http://127.0.0.1:5173 ..."
Push-Location $frontend
try {
npm run dev
}
finally {
Pop-Location
Write-Host "[snapAna] Stopping backend (PID $($backendProc.Id))..."
Stop-Process -Id $backendProc.Id -Force -ErrorAction SilentlyContinue
}