"""设置接口:Provider 配置、分类、Tag。""" from __future__ import annotations from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy import func, select from sqlalchemy.orm import Session from app.api.deps import db_session from app.models.category import Category from app.models.setting import ( DEFAULT_RECOGNITION_MODE, KEY_OCR_PROVIDER, KEY_RECOGNITION_MODE, KEY_VLM_PROVIDER, ) from app.models.screenshot import Screenshot from app.models.tag import Tag from app.providers import RECOGNITION_MODES from app.schemas.common import ( CategoryIn, ProviderConfig, ProviderConfigOut, ProviderTestResult, RecognitionModeIn, ) from app.services.provider_test import merge_provider_api_key, test_provider_config from app.services.settings_store import all_settings, get_setting, set_setting router = APIRouter(prefix="/api/settings", tags=["settings"]) @router.get("") def get_all(session: Session = Depends(db_session)) -> dict: """返回所有非敏感设置。api_key 字段做脱敏。""" raw = all_settings(session) for key in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER): cfg = raw.get(key) if isinstance(cfg, dict) and cfg.get("api_key"): cfg["api_key_mask"] = _mask(cfg["api_key"]) cfg["api_key"] = "" return raw def _mask(value: str) -> str: if not value: return "" if len(value) <= 6: return "*" * len(value) return value[:3] + "*" * (len(value) - 6) + value[-3:] @router.get("/providers/{key}", response_model=ProviderConfigOut | None) def get_provider( key: str, session: Session = Depends(db_session), ) -> ProviderConfigOut | None: """读取 Provider 配置:api_key 明文不外传,只给一个掩码用于 UI 提示。""" if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER): raise HTTPException(400, "key must be ocr_provider or vlm_provider") raw = get_setting(session, key, None) if not raw: return None mask = _mask(raw.get("api_key", "") or "") return ProviderConfigOut( type=raw.get("type", ""), base_url=raw.get("base_url"), api_key="", api_key_mask=mask or None, model=raw.get("model"), extra=raw.get("extra", {}) or {}, ) 
@router.put("/providers/{key}")
def put_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> dict:
    """Persist a provider configuration.

    If the request carries an empty ``api_key``, the previously stored key is
    kept. This lets the UI resubmit the form without knowing the plaintext key.

    Raises:
        HTTPException: 400 if *key* is not one of the provider keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    existing = get_setting(session, key, None)
    payload = cfg.model_dump()
    # Empty api_key from the client means "unchanged": reuse the stored one.
    if (not payload.get("api_key")) and isinstance(existing, dict):
        payload["api_key"] = existing.get("api_key", "")
    set_setting(session, key, payload)
    session.commit()
    return {"ok": True}


@router.post("/providers/{key}/test", response_model=ProviderTestResult)
async def test_provider(
    key: str,
    cfg: ProviderConfig,
    session: Session = Depends(db_session),
) -> ProviderTestResult:
    """Check connectivity of an OCR / vision-AI provider.

    The current form values are tested. ``merge_provider_api_key`` combines
    them with the stored configuration, so ``api_key`` may be left empty to
    reuse the saved key.

    Raises:
        HTTPException: 400 if *key* is not one of the provider keys.
    """
    if key not in (KEY_OCR_PROVIDER, KEY_VLM_PROVIDER):
        raise HTTPException(400, "key must be ocr_provider or vlm_provider")
    existing = get_setting(session, key, None)
    merged = merge_provider_api_key(cfg, existing if isinstance(existing, dict) else None)
    result = await test_provider_config(key, merged)
    return ProviderTestResult(**result)


@router.get("/recognition-mode")
def get_recognition_mode(session: Session = Depends(db_session)) -> dict:
    """Return the text-recognition strategy and the allowed options.

    A stored value outside ``RECOGNITION_MODES`` is reported as
    ``DEFAULT_RECOGNITION_MODE``.
    """
    mode = get_setting(session, KEY_RECOGNITION_MODE, DEFAULT_RECOGNITION_MODE)
    if mode not in RECOGNITION_MODES:
        mode = DEFAULT_RECOGNITION_MODE
    return {"mode": mode, "options": list(RECOGNITION_MODES)}


@router.put("/recognition-mode")
def put_recognition_mode(
    payload: RecognitionModeIn,
    session: Session = Depends(db_session),
) -> dict:
    """Save the text-recognition strategy.

    Raises:
        HTTPException: 400 if ``payload.mode`` is not in ``RECOGNITION_MODES``.
    """
    if payload.mode not in RECOGNITION_MODES:
        raise HTTPException(400, f"mode must be one of {RECOGNITION_MODES}")
    set_setting(session, KEY_RECOGNITION_MODE, payload.mode)
    session.commit()
    return {"ok": True, "mode": payload.mode}


@router.get("/categories")
def list_categories(session: Session = Depends(db_session)) -> list[dict]:
    """List all categories ordered by id."""
    rows = session.scalars(select(Category).order_by(Category.id)).all()
    return [
        {"id": c.id, "name": c.name, "color": c.color, "prompt_hint": c.prompt_hint}
        for c in rows
    ]


def _category_name_taken(
    session: Session, name: str, exclude_id: Optional[int] = None
) -> bool:
    """Return True if a category other than *exclude_id* already uses *name*."""
    stmt = select(Category.id).where(Category.name == name)
    if exclude_id is not None:
        stmt = stmt.where(Category.id != exclude_id)
    return session.scalar(stmt.limit(1)) is not None


@router.post("/categories")
def create_category(
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Create a category and return its new id.

    Raises:
        HTTPException: 400 if a category with the same name already exists.
    """
    if _category_name_taken(session, payload.name):
        raise HTTPException(400, "category exists")
    cat = Category(name=payload.name, color=payload.color, prompt_hint=payload.prompt_hint)
    session.add(cat)
    session.commit()
    session.refresh(cat)
    return {"id": cat.id}


@router.patch("/categories/{cat_id}")
def update_category(
    cat_id: int,
    payload: CategoryIn,
    session: Session = Depends(db_session),
) -> dict:
    """Overwrite name, color and prompt hint of an existing category.

    Raises:
        HTTPException: 404 if the category does not exist; 400 if another
            category already uses the requested name (same rule as creation).
    """
    cat = session.get(Category, cat_id)
    if cat is None:
        raise HTTPException(404, "category not found")
    # Check before mutating so autoflush cannot push a half-updated row.
    if _category_name_taken(session, payload.name, exclude_id=cat_id):
        raise HTTPException(400, "category exists")
    cat.name = payload.name
    cat.color = payload.color
    cat.prompt_hint = payload.prompt_hint
    session.commit()
    return {"ok": True}


@router.delete("/categories/{cat_id}")
def delete_category(
    cat_id: int,
    session: Session = Depends(db_session),
) -> dict:
    """Delete a category.

    Raises:
        HTTPException: 404 if the category does not exist.
    """
    cat = session.get(Category, cat_id)
    if cat is None:
        raise HTTPException(404, "category not found")
    session.delete(cat)
    session.commit()
    return {"ok": True}


@router.get("/tags")
def list_tags(
    session: Session = Depends(db_session),
    q: Optional[str] = Query(None, description="标签名关键词"),
    page: int = Query(1, ge=1),
    size: int = Query(200, ge=1, le=500),
    sort: str = Query(
        "count_desc", description="count_desc|count_asc|name_asc|name_desc"
    ),
) -> dict:
    """List tags with their screenshot usage counts, with search and paging.

    Args:
        q: Optional case-insensitive substring of the tag name. Surrounding
            whitespace is ignored; ``%`` and ``_`` are matched literally.
        page: 1-based page number.
        size: Page size (1..500).
        sort: One of ``count_desc`` (default, also used for unknown values),
            ``count_asc``, ``name_asc``, ``name_desc``.

    Returns:
        ``{"items": [{"id", "name", "color", "count"}, ...], "total", "page",
        "size"}``, where ``total`` is the number of tags matching *q*.
    """
    usage = func.count(Screenshot.id)

    # Build the name filter once and reuse it for both the count and page query.
    name_filter = None
    keyword = q.strip() if q else ""
    if keyword:
        # Escape LIKE wildcards so user input is matched literally.
        escaped = (
            keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        name_filter = Tag.name.ilike(f"%{escaped}%", escape="\\")

    base = select(Tag.id)
    if name_filter is not None:
        base = base.where(name_filter)
    total = session.scalar(select(func.count()).select_from(base.subquery())) or 0

    stmt = (
        select(Tag.id, Tag.name, Tag.color, usage)
        .join(Tag.screenshots, isouter=True)
        .group_by(Tag.id)
    )
    if name_filter is not None:
        stmt = stmt.where(name_filter)

    primary_order = {
        "count_asc": usage.asc(),
        "name_asc": Tag.name.asc(),
        "name_desc": Tag.name.desc(),
    }.get(sort, usage.desc())
    # Tag.id as tiebreaker: without a total order, OFFSET/LIMIT pages can
    # repeat or skip tags that share the same count or name.
    stmt = stmt.order_by(primary_order, Tag.id.asc())

    rows = session.execute(stmt.offset((page - 1) * size).limit(size)).all()
    items = [
        {"id": r[0], "name": r[1], "color": r[2], "count": int(r[3] or 0)}
        for r in rows
    ]
    return {"items": items, "total": int(total), "page": page, "size": size}