5c028d7952
包含 FastAPI 后端、React 前端、队列/OCR/标签/待办等完整功能。 Co-authored-by: Cursor <cursoragent@cursor.com>
221 lines
7.2 KiB
Python
221 lines
7.2 KiB
Python
"""设置接口:Provider 配置、分类、Tag。"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.deps import db_session
|
|
from app.models.category import Category
|
|
from app.models.setting import (
|
|
DEFAULT_RECOGNITION_MODE,
|
|
KEY_OCR_PROVIDER,
|
|
KEY_RECOGNITION_MODE,
|
|
KEY_VLM_PROVIDER,
|
|
)
|
|
from app.models.screenshot import Screenshot
|
|
from app.models.tag import Tag
|
|
from app.providers import RECOGNITION_MODES
|
|
from app.schemas.common import (
|
|
CategoryIn,
|
|
ProviderConfig,
|
|
ProviderConfigOut,
|
|
ProviderTestResult,
|
|
RecognitionModeIn,
|
|
)
|
|
from app.services.provider_test import merge_provider_api_key, test_provider_config
|
|
from app.services.settings_store import all_settings, get_setting, set_setting
|
|
|
|
|
|
# Every endpoint in this module is mounted under /api/settings and grouped as "settings".
router = APIRouter(tags=["settings"], prefix="/api/settings")
|
|
|
|
|
|
@router.get("")
def get_all(session: Session = Depends(db_session)) -> dict:
    """Return every stored setting, with provider API keys masked.

    For the OCR and VLM provider entries, a non-empty ``api_key`` is blanked
    and an ``api_key_mask`` (see ``_mask``) is added for UI display.

    The masking is applied to copies: the mapping returned by ``all_settings``
    and its nested provider dicts are never mutated. They may be live state
    owned by the settings store or the ORM session.
    """
    settings = dict(all_settings(session))
    for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        cfg = settings.get(key)
        if isinstance(cfg, dict) and cfg.get("api_key"):
            masked = dict(cfg)  # shallow copy so the stored config keeps its real key
            masked["api_key_mask"] = _mask(cfg["api_key"])
            masked["api_key"] = ""
            settings[key] = masked
    return settings
|
|
|
|
|
|
def _mask(value: str) -> str:
|
|
if not value:
|
|
return ""
|
|
if len(value) <= 6:
|
|
return "*" * len(value)
|
|
return value[:3] + "*" * (len(value) - 6) + value[-3:]
|
|
|
|
|
|
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
def get_provider(
    key: str,
    session: Session = Depends(db_session),
) -> ProviderConfigOut | None:
    """Read a provider config without exposing the plaintext API key.

    Returns None when nothing usable is stored under ``key``. Otherwise returns
    the config with ``api_key`` blanked and ``api_key_mask`` set to a masked
    hint (or None when no key is stored).

    Raises:
        HTTPException(400): ``key`` is not one of the two provider setting keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    raw = get_setting(session, key, None)
    # put_provider / test_provider only trust dict-shaped stored values; do the
    # same here so a malformed value reads as "not configured" instead of
    # crashing on .get() with AttributeError.
    if not isinstance(raw, dict) or not raw:
        return None
    mask = _mask(raw.get("api_key", "") or "")
    return ProviderConfigOut(
        type=raw.get("type", ""),
        base_url=raw.get("base_url"),
        api_key="",
        api_key_mask=mask or None,
        model=raw.get("model"),
        extra=raw.get("extra", {}) or {},
    )
|
|
|
|
|
|
@router.put("/providers/{key}")
def put_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> dict:
    """Save a provider config under ``key`` and commit.

    If the submitted ``api_key`` is empty and a dict config is already stored,
    the previously saved key is kept. As a result, an existing key cannot be
    cleared through this endpoint.

    Raises:
        HTTPException(400): ``key`` is not one of the two provider setting keys.
    """
    allowed_keys = (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER)
    if key not in allowed_keys:
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    previous = get_setting(session, key, None)
    data = cfg.model_dump()
    keep_saved_key = not data.get("api_key") and isinstance(previous, dict)
    if keep_saved_key:
        data["api_key"] = previous.get("api_key", "")
    set_setting(session, key, data)
    session.commit()
    return {"ok": True}
|
|
|
|
|
|
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> ProviderTestResult:
    """Check connectivity of the OCR / vision-AI provider described by the form.

    The submitted ``cfg`` is used, and ``api_key`` may be left empty to reuse
    the key already saved under ``key`` (merged via ``merge_provider_api_key``).
    This endpoint itself does not save or commit anything.

    Raises:
        HTTPException(400): ``key`` is not one of the two provider setting keys.
    """
    valid_keys = (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER)
    if key not in valid_keys:
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    stored = get_setting(session, key, None)
    saved_cfg = stored if isinstance(stored, dict) else None
    candidate = merge_provider_api_key(cfg, saved_cfg)
    outcome = await test_provider_config(key, candidate)
    return ProviderTestResult(**outcome)
|
|
|
|
|
|
@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
    """Read the text-recognition strategy (ocr / vision / hybrid).

    A stored value that is not in RECOGNITION_MODES is reported as
    DEFAULT_RECOGNITION_MODE. The response also lists all allowed options.
    """
    stored = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
    current = stored if stored in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE
    return {"mode": current, "options": list(RECOGNITION_MODES)}
|
|
|
|
|
|
@router.put("/recognition-mode")
def put_recognition_mode(
    payload: RecognitionModeIn,
    session: Session = Depends(db_session),
) -> dict:
    """Persist the text-recognition strategy and echo it back.

    Raises:
        HTTPException(400): ``payload.mode`` is not in RECOGNITION_MODES.
    """
    chosen = payload.mode
    if chosen not in RECOGNITION_MODES:
        raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
    set_setting(session, KEY_RECOGNITION_MODE, chosen)
    session.commit()
    return {"ok": True, "mode": chosen}
|
|
|
|
|
|
@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
    """List every category ordered by id as ``{id, name, color, prompt_hint}`` dicts."""
    categories = session.scalars(select(Category).order_by(Category.id)).all()
    fields = ("id", "name", "color", "prompt_hint")
    return [{field: getattr(cat, field) for field in fields} for cat in categories]
|
|
|
|
|
|
@router.post("/categories")
def create_category(
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Create a category and return its id.

    Raises:
        HTTPException(400): a category with the same name already exists.
    """
    duplicate_query = select(Category).where(Category.name == payload.name)
    if session.scalar(duplicate_query) is not None:
        raise HTTPException(400, "category exists")
    category = Category(
        name=payload.name,
        color=payload.color,
        prompt_hint=payload.prompt_hint,
    )
    session.add(category)
    session.commit()
    # Reload attributes (including id) after the commit expired them.
    session.refresh(category)
    return {"id": category.id}
|
|
|
|
|
|
@router.patch("/categories/{cat_id}")
def update_category(
    cat_id: int,
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Overwrite name, color and prompt_hint of category ``cat_id``.

    Raises:
        HTTPException(404): no category with that id.
        HTTPException(400): a different category already uses ``payload.name``.
            This is the same rule and message that ``create_category`` applies.
    """
    cat = session.get(Category, cat_id)
    if cat is None:
        raise HTTPException(404, "category not found")
    # Creating a duplicate name is refused, so renaming into one must be too.
    # Check before mutating ``cat`` so the query's autoflush cannot push the
    # rename to the database first.
    clash = session.scalar(
        select(Category.id).where(Category.name == payload.name, Category.id != cat_id)
    )
    if clash is not None:
        raise HTTPException(400, "category exists")
    cat.name = payload.name
    cat.color = payload.color
    cat.prompt_hint = payload.prompt_hint
    session.commit()
    return {"ok": True}
|
|
|
|
|
|
@router.delete("/categories/{cat_id}")
def delete_category(
    cat_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete category ``cat_id`` and commit.

    Raises:
        HTTPException(404): no category with that id.
    """
    target = session.get(Category, cat_id)
    if target is not None:
        session.delete(target)
        session.commit()
        return {"ok": True}
    raise HTTPException(404, "category not found")
|
|
|
|
|
|
@router.get("/tags")
def list_tags(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="标签名关键词"),
    page: int = Query(1, ge=1),
    size: int = Query(200, ge=1, le=500),
    sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
) -> dict:
    """List tags with their screenshot usage count, with search, sorting and paging.

    Args:
        q: case-insensitive substring filter on the tag name (ILIKE). The
            keyword is stripped; a blank keyword applies no filter. The LIKE
            wildcards ``%`` and ``_`` in the keyword are matched literally.
        page: 1-based page number.
        size: page size.
        sort: count_desc (default, also used for unrecognised values),
            count_asc, name_asc or name_desc. Ties are broken by tag id so
            OFFSET/LIMIT paging is deterministic.

    Returns:
        ``{"items": [...], "total": N, "page": page, "size": size}``. Each item
        is ``{"id", "name", "color", "count"}`` and ``total`` counts matching
        tags before paging.
    """
    keyword = q.strip() if q else ""
    name_filter = None
    if keyword:
        # Escape LIKE metacharacters (using "/" as the escape char, as
        # SQLAlchemy's own autoescape does) so "50%" or "a_b" match literally.
        escaped = keyword.replace("/", "//").replace("%", "/%").replace("_", "/_")
        name_filter = Tag.name.ilike(f"%{escaped}%", escape="/")

    base = select(Tag.id)
    if name_filter is not None:
        base = base.where(name_filter)
    total = session.scalar(select(func.count()).select_from(base.subquery())) or 0

    usage = func.count(Screenshot.id)
    stmt = (
        select(Tag.id, Tag.name, Tag.color, usage)
        .join(Tag.screenshots, isouter=True)
        .group_by(Tag.id)
    )
    if name_filter is not None:
        stmt = stmt.where(name_filter)

    if sort == "count_asc":
        primary = usage.asc()
    elif sort == "name_asc":
        primary = Tag.name.asc()
    elif sort == "name_desc":
        primary = Tag.name.desc()
    else:
        primary = usage.desc()
    # Without a total order, tags sharing a count may be repeated or skipped
    # across pages; the id tie-breaker makes the order unique.
    stmt = stmt.order_by(primary, Tag.id.asc())

    rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
    items = [
        {"id": tag_id, "name": name, "color": color, "count": int(count or 0)}
        for tag_id, name, color, count in rows
    ]
    return {"items": items, "total": int(total), "page": page, "size": size}
|