Initial commit: snapAna 截图智能整理工具
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,220 @@
|
||||
"""设置接口:Provider 配置、分类、Tag。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import db_session
|
||||
from app.models.category import Category
|
||||
from app.models.setting import (
|
||||
DEFAULT_RECOGNITION_MODE,
|
||||
KEY_OCR_PROVIDER,
|
||||
KEY_RECOGNITION_MODE,
|
||||
KEY_VLM_PROVIDER,
|
||||
)
|
||||
from app.models.screenshot import Screenshot
|
||||
from app.models.tag import Tag
|
||||
from app.providers import RECOGNITION_MODES
|
||||
from app.schemas.common import (
|
||||
CategoryIn,
|
||||
ProviderConfig,
|
||||
ProviderConfigOut,
|
||||
ProviderTestResult,
|
||||
RecognitionModeIn,
|
||||
)
|
||||
from app.services.provider_test import merge_provider_api_key, test_provider_config
|
||||
from app.services.settings_store import all_settings, get_setting, set_setting
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_all(session: Session = Depends(db_session)) -> dict:
|
||||
"""返回所有非敏感设置。api_key 字段做脱敏。"""
|
||||
raw = all_settings(session)
|
||||
for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
cfg = raw.get(key)
|
||||
if isinstance(cfg, dict) and cfg.get("api_key"):
|
||||
cfg["api_key_mask"] = _mask(cfg["api_key"])
|
||||
cfg["api_key"] = ""
|
||||
return raw
|
||||
|
||||
|
||||
def _mask(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) <= 6:
|
||||
return "*" * len(value)
|
||||
return value[:3] + "*" * (len(value) - 6) + value[-3:]
|
||||
|
||||
|
||||
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
|
||||
def get_provider(
|
||||
key: str,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ProviderConfigOut | None:
|
||||
"""读取 Provider 配置:api_key 明文不外传,只给一个掩码用于 UI 提示。"""
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
raw = get_setting(session, key, None)
|
||||
if not raw:
|
||||
return None
|
||||
mask = _mask(raw.get("api_key", "") or "")
|
||||
return ProviderConfigOut(
|
||||
type=raw.get("type", ""),
|
||||
base_url=raw.get("base_url"),
|
||||
api_key="",
|
||||
api_key_mask=mask or None,
|
||||
model=raw.get("model"),
|
||||
extra=raw.get("extra", {}) or {},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/providers/{key}")
|
||||
def put_provider(
|
||||
key: str,
|
||||
cfg: ProviderConfig,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
# 如果客户端没有传新的 api_key(空字符串),保留旧值
|
||||
existing = get_setting(session, key, None)
|
||||
payload = cfg.model_dump()
|
||||
if (not payload.get("api_key")) and isinstance(existing, dict):
|
||||
payload["api_key"] = existing.get("api_key", "")
|
||||
set_setting(session, key, payload)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
|
||||
async def test_provider(
|
||||
key: str,
|
||||
cfg: ProviderConfig,
|
||||
session: Session = Depends(db_session),
|
||||
) -> ProviderTestResult:
|
||||
"""测试 OCR / 视觉 AI Provider 连通性(使用当前表单配置,api_key 可留空沿用已保存值)。"""
|
||||
if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
|
||||
raise HTTPException(400, "key must be ocr_provider or vlm_provider")
|
||||
existing = get_setting(session, key, None)
|
||||
merged = merge_provider_api_key(cfg, existing if isinstance(existing, dict) else None)
|
||||
result = await test_provider_config(key, merged)
|
||||
return ProviderTestResult(**result)
|
||||
|
||||
|
||||
@router.get("/recognition-mode")
|
||||
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
|
||||
"""读取文字识别策略:ocr / vision / hybrid。"""
|
||||
mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
|
||||
if mode not in RECOGNITION_MODES:
|
||||
mode = DEFAULT_RECOGNITION_MODE
|
||||
return {"mode": mode, "options": list(RECOGNITION_MODES)}
|
||||
|
||||
|
||||
@router.put("/recognition-mode")
|
||||
def put_recognition_mode(
|
||||
payload: RecognitionModeIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
"""保存文字识别策略。"""
|
||||
if payload.mode not in RECOGNITION_MODES:
|
||||
raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
|
||||
set_setting(session, KEY_RECOGNITION_MODE, payload.mode)
|
||||
session.commit()
|
||||
return {"ok": True, "mode": payload.mode}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
|
||||
rows = session.scalars(select(Category).order_by(Category.id)).all()
|
||||
return [
|
||||
{"id": c.id, "name": c.name, "color": c.color, "prompt_hint": c.prompt_hint}
|
||||
for c in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/categories")
|
||||
def create_category(
|
||||
payload: CategoryIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
exists = session.scalar(select(Category).where(Category.name == payload.name))
|
||||
if exists is not None:
|
||||
raise HTTPException(400, "category exists")
|
||||
cat = Category(name=payload.name, color=payload.color, prompt_hint=payload.prompt_hint)
|
||||
session.add(cat)
|
||||
session.commit()
|
||||
session.refresh(cat)
|
||||
return {"id": cat.id}
|
||||
|
||||
|
||||
@router.patch("/categories/{cat_id}")
|
||||
def update_category(
|
||||
cat_id: int,
|
||||
payload: CategoryIn,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
cat = session.get(Category, cat_id)
|
||||
if cat is None:
|
||||
raise HTTPException(404, "category not found")
|
||||
cat.name = payload.name
|
||||
cat.color = payload.color
|
||||
cat.prompt_hint = payload.prompt_hint
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/categories/{cat_id}")
|
||||
def delete_category(
|
||||
cat_id: int,
|
||||
session: Session = Depends(db_session),
|
||||
) -> dict:
|
||||
cat = session.get(Category, cat_id)
|
||||
if cat is None:
|
||||
raise HTTPException(404, "category not found")
|
||||
session.delete(cat)
|
||||
session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/tags")
|
||||
def list_tags(
|
||||
session: Session = Depends(db_session),
|
||||
q: Optional[str] = Query(None, description="标签名关键词"),
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(200, ge=1, le=500),
|
||||
sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
|
||||
) -> dict:
|
||||
"""标签列表(含使用次数),支持搜索与分页。"""
|
||||
base = select(Tag.id)
|
||||
if q:
|
||||
base = base.where(Tag.name.ilike(f"%{q.strip()}%"))
|
||||
total = session.scalar(select(func.count()).select_from(base.subquery())) or 0
|
||||
|
||||
stmt = (
|
||||
select(Tag.id, Tag.name, Tag.color, func.count(Screenshot.id))
|
||||
.join(Tag.screenshots, isouter=True)
|
||||
.group_by(Tag.id)
|
||||
)
|
||||
if q:
|
||||
stmt = stmt.where(Tag.name.ilike(f"%{q.strip()}%"))
|
||||
|
||||
if sort == "count_asc":
|
||||
stmt = stmt.order_by(func.count(Screenshot.id).asc())
|
||||
elif sort == "name_asc":
|
||||
stmt = stmt.order_by(Tag.name.asc())
|
||||
elif sort == "name_desc":
|
||||
stmt = stmt.order_by(Tag.name.desc())
|
||||
else:
|
||||
stmt = stmt.order_by(func.count(Screenshot.id).desc())
|
||||
|
||||
rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
|
||||
items = [
|
||||
{"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)} for r in rows
|
||||
]
|
||||
return {"items": items, "total": int(total), "page": page, "size": size}
|
||||
Reference in New Issue
Block a user