Files

221 lines
7.2 KiB
Python
Raw Permalink Normal View History

"""设置接口:Provider 配置、分类、Tag。"""
from __future__ import annotations
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.api.deps import db_session
from app.models.category import Category
from app.models.setting import (
DEFAULT_RECOGNITION_MODE,
KEY_OCR_PROVIDER,
KEY_RECOGNITION_MODE,
KEY_VLM_PROVIDER,
)
from app.models.screenshot import Screenshot
from app.models.tag import Tag
from app.providers import RECOGNITION_MODES
from app.schemas.common import (
CategoryIn,
ProviderConfig,
ProviderConfigOut,
ProviderTestResult,
RecognitionModeIn,
)
from app.services.provider_test import merge_provider_api_key, test_provider_config
from app.services.settings_store import all_settings, get_setting, set_setting
# Router for the settings endpoints below: every route is mounted under the
# "/api/settings" prefix and grouped under the "settings" OpenAPI tag.
router = APIRouter(prefix="/api/settings", tags=["settings"])
@router.get("")
def get_all(session: Session = Depends(db_session)) -> dict:
    """Return every stored setting, with provider API keys redacted.

    For the OCR and VLM provider entries, a non-empty ``api_key`` is replaced
    by ``""`` and a display-only ``api_key_mask`` (see ``_mask``) is added.

    Redaction is applied to *copies* of the provider dicts. The original code
    blanked ``api_key`` in place on whatever ``all_settings`` returned; if those
    objects are shared with the settings store / session (not visible here),
    that would corrupt the stored secret.
    """
    raw = dict(all_settings(session))
    for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        cfg = raw.get(key)
        if isinstance(cfg, dict) and cfg.get("api_key"):
            redacted = dict(cfg)  # shallow copy: never touch the store's object
            redacted["api_key_mask"] = _mask(cfg["api_key"])
            redacted["api_key"] = ""
            raw[key] = redacted
    return raw
def _mask(value: str) -> str:
if not value:
return ""
if len(value) <= 6:
return "*" * len(value)
return value[:3] + "*" * (len(value) - 6) + value[-3:]
@router.get("/providers/{key}", response_model=ProviderConfigOut | None)
def get_provider(
    key: str,
    session: Session = Depends(db_session),
) -> ProviderConfigOut | None:
    """Read one provider config without exposing the plaintext API key.

    Args:
        key: must be ``KEY_OCR_PROVIDER`` or ``KEY_VLM_PROVIDER``.

    Returns:
        A ``ProviderConfigOut`` whose ``api_key`` is always ``""`` and whose
        ``api_key_mask`` is a masked hint for the UI (``None`` when no key is
        stored), or ``None`` when nothing usable is stored.

    Raises:
        HTTPException(400): for any other ``key``.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    raw = get_setting(session, key, None)
    # Missing, empty, or non-dict values count as "not configured". Calling
    # .get() on e.g. a stray string would otherwise raise AttributeError (500).
    # put_provider/test_provider apply the same isinstance check.
    if not raw or not isinstance(raw, dict):
        return None
    mask = _mask(raw.get("api_key", "") or "")
    return ProviderConfigOut(
        type=raw.get("type", ""),
        base_url=raw.get("base_url"),
        api_key="",
        api_key_mask=mask or None,
        model=raw.get("model"),
        extra=raw.get("extra", {}) or {},
    )
@router.put("/providers/{key}")
def put_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> dict:
    """Persist a provider configuration and commit.

    If the client sends no new ``api_key`` (empty string), the previously
    stored key is kept. This lets the UI save other fields without re-entering
    the secret.

    Raises:
        HTTPException(400): when ``key`` is not an OCR/VLM provider key.
    """
    allowed_keys = (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER)
    if key not in allowed_keys:
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")

    previous = get_setting(session, key, None)
    data = cfg.model_dump()

    keep_stored_secret = isinstance(previous, dict) and not data.get("api_key")
    if keep_stored_secret:
        data["api_key"] = previous.get("api_key", "")

    set_setting(session, key, data)
    session.commit()
    return {"ok": True}
@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> ProviderTestResult:
    """Test connectivity of an OCR / vision-AI provider using the submitted form.

    ``cfg.api_key`` may be left blank to reuse the saved key. The stored config
    (only when it is a dict) is handed to ``merge_provider_api_key`` for that
    purpose before ``test_provider_config`` runs.

    Raises:
        HTTPException(400): when ``key`` is not an OCR/VLM provider key.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")

    stored = get_setting(session, key, None)
    saved_cfg = stored if isinstance(stored, dict) else None
    outcome = await test_provider_config(key, merge_provider_api_key(cfg, saved_cfg))
    return ProviderTestResult(**outcome)
@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
    """Return the text-recognition strategy (ocr / vision / hybrid) and all options.

    A stored value that is not in ``RECOGNITION_MODES`` is reported as
    ``DEFAULT_RECOGNITION_MODE`` instead.
    """
    stored = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
    mode = stored if stored in RECOGNITION_MODES else DEFAULT_RECOGNITION_MODE
    return {"mode": mode, "options": list(RECOGNITION_MODES)}
@router.put("/recognition-mode")
def put_recognition_mode(
    payload: RecognitionModeIn,
    session: Session = Depends(db_session),
) -> dict:
    """Persist the text-recognition strategy and commit.

    Raises:
        HTTPException(400): when ``payload.mode`` is not in ``RECOGNITION_MODES``.
    """
    chosen = payload.mode
    if chosen in RECOGNITION_MODES:
        set_setting(session, KEY_RECOGNITION_MODE, chosen)
        session.commit()
        return {"ok": True, "mode": chosen}
    raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
    """List all categories ordered by id, as plain dicts."""
    query = select(Category).order_by(Category.id)
    return [
        {
            "id": category.id,
            "name": category.name,
            "color": category.color,
            "prompt_hint": category.prompt_hint,
        }
        for category in session.scalars(query)
    ]
@router.post("/categories")
def create_category(
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Create a category and return ``{"id": <new id>}``.

    Raises:
        HTTPException(400): when a category with the same name already exists.
    """
    duplicate = session.scalar(select(Category).where(Category.name == payload.name))
    if duplicate is not None:
        raise HTTPException(400, "category exists")

    new_category = Category(
        name=payload.name,
        color=payload.color,
        prompt_hint=payload.prompt_hint,
    )
    session.add(new_category)
    session.commit()
    # Reload the row after commit so its ``id`` attribute is populated.
    session.refresh(new_category)
    return {"id": new_category.id}
@router.patch("/categories/{cat_id}")
def update_category(
    cat_id: int,
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Overwrite a category's name, color and prompt hint, then commit.

    Raises:
        HTTPException(404): when no category has id ``cat_id``.
        HTTPException(400): when a *different* category already uses
            ``payload.name``. This is the same uniqueness rule that
            ``create_category`` enforces; previously a rename could bypass it.
    """
    cat = session.get(Category, cat_id)
    if cat is None:
        raise HTTPException(404, "category not found")
    # Check before assigning, so no pending rename is autoflushed by this query.
    clash = session.scalar(
        select(Category.id)
        .where(Category.name == payload.name, Category.id != cat_id)
        .limit(1)
    )
    if clash is not None:
        raise HTTPException(400, "category exists")
    cat.name = payload.name
    cat.color = payload.color
    cat.prompt_hint = payload.prompt_hint
    session.commit()
    return {"ok": True}
@router.delete("/categories/{cat_id}")
def delete_category(
    cat_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete the category with id ``cat_id`` and commit.

    Raises:
        HTTPException(404): when no such category exists.
    """
    # NOTE(review): what happens to rows referencing this category depends on
    # the model's FK / cascade configuration, which is not visible here.
    target = session.get(Category, cat_id)
    if target is None:
        raise HTTPException(404, "category not found")
    session.delete(target)
    session.commit()
    return {"ok": True}
def _escape_like(term: str) -> str:
    """Escape ``\\``, ``%`` and ``_`` so *term* matches literally in a LIKE pattern."""
    return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


@router.get("/tags")
def list_tags(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="标签名关键词"),
    page: int = Query(1, ge=1),
    size: int = Query(200, ge=1, le=500),
    sort: str = Query("count_desc", description="count_desc|count_asc|name_asc|name_desc"),
) -> dict:
    """List tags with their usage counts, with search and pagination.

    Args:
        q: case-insensitive substring of the tag name. It is matched literally:
            ``%`` and ``_`` are escaped rather than treated as wildcards.
            Blank / whitespace-only means "no filter".
        page, size: 1-based page number and page size (LIMIT/OFFSET).
        sort: ``count_desc`` (default, also used for unknown values),
            ``count_asc``, ``name_asc`` or ``name_desc``.

    Returns:
        ``{"items": [{"id", "name", "color", "count"}, ...], "total", "page", "size"}``.
        ``total`` counts all matching tags, not just this page.
    """
    usage = func.count(Screenshot.id)
    stmt = (
        select(Tag.id, Tag.name, Tag.color, usage)
        .join(Tag.screenshots, isouter=True)
        # Group by every selected column so databases without PK
        # functional-dependency inference accept the query.
        .group_by(Tag.id, Tag.name, Tag.color)
    )
    total_stmt = select(func.count(Tag.id))

    term = q.strip() if q else ""
    if term:
        name_filter = Tag.name.ilike(f"%{_escape_like(term)}%", escape="\\")
        stmt = stmt.where(name_filter)
        total_stmt = total_stmt.where(name_filter)

    primary_order = {
        "count_asc": usage.asc(),
        "name_asc": Tag.name.asc(),
        "name_desc": Tag.name.desc(),
    }.get(sort, usage.desc())
    # Tag.id makes the order total. Without it, tags tied on count/name can
    # shuffle between LIMIT/OFFSET pages (duplicates or gaps across pages).
    stmt = stmt.order_by(primary_order, Tag.id.asc())

    total = session.scalar(total_stmt) or 0
    rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
    items = [
        {"id": tag_id, "name": name, "color": color, "count": int(count or 0)}
        for tag_id, name, color, count in rows
    ]
    return {"items": items, "total": int(total), "page": page, "size": size}