Files

557 lines
18 KiB
Python
Raw Permalink Normal View History

2026-06-12 16:04:03 +08:00
"""dataClean FastAPI 入口"""
import logging
import os
import threading
2026-06-12 16:04:03 +08:00
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional, List
from fastapi import FastAPI, Depends, HTTPException, Query, Body, Security, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, Response
2026-06-12 16:04:03 +08:00
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from config import settings
from database import init_db, get_db, SessionLocal
from scheduler import init_scheduler, stop_scheduler, get_scheduler, get_task_lock
from app.taxonomy import bootstrap_taxonomy, list_taxonomy
from app.rss_client import rss_client, RSSKeeperClient
from app.ai_client import ai_client, AIClient
from app import task_progress
2026-06-12 16:04:03 +08:00
from app.summarizer import fetch_and_summarize
from app.tagger import tag_articles
from app.deduplicator import deduplicate_articles
from app.scorer import score_articles
from app.brief import generate_daily_brief
from app.settings_manager import (
init_default_settings,
list_settings,
get_setting,
set_setting,
reset_settings,
apply_db_settings_to_config,
)
from models import EnrichedArticle, DailyBrief, Taxonomy, DuplicateGroup, AppSetting
# Resolve the configured log level name; unknown names fall back to INFO.
_log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Bearer-token auth scheme. auto_error=False lets verify_token decide what to
# do when the header is missing (auth is only enforced when API_TOKEN is set).
security_scheme = HTTPBearer(auto_error=False)
def _get_allowed_origins() -> List[str]:
    """Parse ``settings.CORS_ALLOWED_ORIGINS`` into a list of origins.

    The setting is a comma-separated string. Surrounding whitespace is
    stripped from each entry and blank entries are dropped.

    Returns:
        The configured origins, or an empty list when the setting is unset or
        empty (the caller decides what an empty list means for CORS).
    """
    configured = settings.CORS_ALLOWED_ORIGINS
    if not configured:
        return []
    origins: List[str] = []
    for part in configured.split(","):
        cleaned = part.strip()
        if cleaned:
            origins.append(cleaned)
    return origins
def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security_scheme)):
    """FastAPI dependency that enforces the optional API bearer token.

    When ``settings.API_TOKEN`` is empty/unset, authentication is disabled and
    ``None`` is returned.

    Returns:
        The presented token string when it matches, or ``None`` when auth is off.

    Raises:
        HTTPException 401: no (parsable) Authorization bearer header was sent.
        HTTPException 403: the scheme is not Bearer or the token does not match.
    """
    import secrets

    token = settings.API_TOKEN
    if not token:
        return None
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="缺少 Authorization 请求头",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # HTTP auth schemes are case-insensitive, and HTTPBearer passes the scheme
    # through in the client's original casing ("bearer" is valid).
    scheme_ok = credentials.scheme.lower() == "bearer"
    # Constant-time comparison so response timing does not leak how many
    # leading characters of the token were correct. Compare bytes because
    # compare_digest rejects non-ASCII str inputs.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"), str(token).encode("utf-8")
    )
    if not (scheme_ok and token_ok):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无效的 API Token",
        )
    return credentials.credentials
def _run_task_background(task_key: str, trigger: str, fn) -> bool:
    """Run ``fn(db)`` on a daemon thread and return immediately.

    The shared task lock is acquired non-blockingly on the calling (request)
    thread. If it is already held, nothing is started and ``False`` is returned
    so the caller can answer 409. Acquiring here, before the thread starts,
    leaves no check-then-act window between "is a task running?" and "start
    one". Ownership of the lock then passes to the worker thread. The worker
    always releases it, even if opening the DB session or reporting progress
    fails.

    Args:
        task_key: Key under which progress is reported via ``task_progress``.
        trigger: Label stored with the progress record (e.g. ``"manual"``).
        fn: Callable invoked with a fresh ``SessionLocal()`` session.

    Returns:
        True if the worker was started, False if the lock was busy.
    """
    lock = get_task_lock()
    if not lock.acquire(blocking=False):
        return False  # lock busy: caller responds with 409

    def _worker():
        db = None
        try:
            # Session creation and the first progress update are inside the
            # try so a failure there is reported and cannot leak the lock.
            db = SessionLocal()
            task_progress.update_progress(
                task_key, status="running", trigger=trigger,
                stage="初始化", current=0, total=0, message=None,
            )
            fn(db)
            task_progress.update_progress(
                task_key, status="success", stage="完成", message="任务执行成功"
            )
        except Exception as exc:
            logger.error("后台任务 %s 失败: %s", task_key, exc, exc_info=True)
            task_progress.update_progress(
                task_key, status="error", stage="失败", message=str(exc)[:500]
            )
        finally:
            try:
                if db is not None:
                    db.close()
            finally:
                lock.release()

    try:
        threading.Thread(target=_worker, name=f"task-{task_key}", daemon=True).start()
    except BaseException:
        # The worker never ran, so it cannot release the lock; do it here.
        lock.release()
        raise
    return True
2026-06-12 16:04:03 +08:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialise storage/settings, run the scheduler.

    Startup:
      1. ``init_db()`` prepares the database.
      2. Default settings are seeded and DB-stored settings are applied over the
         global ``settings``. This step is best-effort: a failure is logged
         (with traceback) and startup continues.
      3. The scheduler is started.

    Shutdown: the scheduler is stopped. This happens even if the application
    exits the lifespan with an exception.
    """
    logger.info("启动 dataClean 服务")
    init_db()
    db = SessionLocal()
    try:
        # Seed default settings rows, then let DB values override the config.
        init_default_settings(db)
        apply_db_settings_to_config(db)
        # NOTE: taxonomy bootstrap is left to the scheduler's background job so
        # that a synchronous LLM call does not delay service readiness
        # (its progress is visible in the frontend).
    except Exception as exc:
        logger.exception("启动初始化失败: %s", exc)
    finally:
        db.close()
    init_scheduler()
    try:
        yield
    finally:
        stop_scheduler()
app = FastAPI(
    title="dataClean",
    description="RSS 数据清洗、摘要、分类、打分与简报生成服务",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: with explicit origins configured, only those are allowed and
# credentials are enabled. With none configured, any origin is allowed but
# credentials are disabled, because the "*" wildcard must not be combined
# with allow_credentials=True.
_allowed_origins = _get_allowed_origins()
_cors_options = {
    "allow_origins": _allowed_origins or ["*"],
    "allow_credentials": bool(_allowed_origins),
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# ---------- Pydantic 模型 ----------
# Response schema for one enriched article, validated directly from ORM rows
# (from_attributes=True). Field declaration order is the JSON output order.
class ArticleOut(BaseModel):
    id: int
    rk_article_id: int  # NOTE(review): presumably the article id in rssKeeper — confirm
    title: str
    link: str
    feed_title: str
    category: str
    tags: List[str]
    heat_score: float
    importance_score: float
    duplication_score: float
    composite_score: float  # list_articles sorts by this, descending
    ai_summary: str
    is_representative: bool  # representative of its duplicate group
    published_at: Optional[datetime]  # required key, may be None

    model_config = ConfigDict(from_attributes=True)
# One page of articles plus the number of rows matching the filters.
class ArticleListOut(BaseModel):
    items: List[ArticleOut]
    total: int  # count of all matches before offset/limit are applied
# Response schema for a daily brief, validated from ORM rows.
class BriefOut(BaseModel):
    id: int
    brief_date: str  # NOTE(review): written as "YYYY-MM-DD" by the task endpoints — confirm for all rows
    total_articles: int
    unique_articles: int
    by_category: dict
    markdown_path: str
    model_config = ConfigDict(from_attributes=True)
# Response schema for a taxonomy entry (category or tag), validated from ORM rows.
class TaxonomyOut(BaseModel):
    id: int
    name: str
    kind: str  # "category" and "tag" are the kinds queried elsewhere in this module
    description: str
    keywords: List[str]
    weight: float
    created_by_ai: bool
    model_config = ConfigDict(from_attributes=True)
# Response schema for one stored application setting.
class SettingOut(BaseModel):
    key: str
    value: str
    description: str
    is_sensitive: bool
    is_masked: bool  # presumably True when value was masked for display — confirm in settings_manager
    updated_at: Optional[str]  # no default: key is required but may be None
# Request body for updating a single setting.
class SettingUpdate(BaseModel):
    value: str
# Request body for updating several settings at once.
class BatchSettingsUpdate(BaseModel):
    settings: dict  # setting key -> new value; value types are not constrained here
# Dashboard counters returned by the stats endpoint.
class StatsOut(BaseModel):
    total_articles: int
    today_articles: int
    ai_summarized: int
    categories: int
    tags: int
    duplicate_groups: int
    briefs: int
    next_jobs: dict  # scheduler job id -> next run time as ISO string, or None
# Result of probing one upstream dependency.
class ConnectionTestResult(BaseModel):
    name: str
    status: str  # "ok" or "error" as produced by the connectivity test endpoint
    latency_ms: Optional[float] = None  # set only on success
    error: Optional[str] = None  # truncated error text on failure
# Combined connectivity report for rssKeeper and the LLM API.
class ConnectionTestResponse(BaseModel):
    rss_keeper: ConnectionTestResult
    llm: ConnectionTestResult
2026-06-12 16:04:03 +08:00
# ---------- 健康检查 ----------
@app.get("/health")
def health():
    """Liveness probe: a static payload showing the process is serving."""
    payload = {"status": "ok", "service": "dataClean"}
    return payload
# ---------- 文章接口 ----------
@app.get("/api/articles", response_model=ArticleListOut)
def list_articles(
    date: Optional[str] = Query(None, description="日期 YYYY-MM-DD"),
    category: Optional[str] = Query(None),
    tag: Optional[str] = Query(None),
    representative_only: bool = Query(False, description="仅返回重复组代表文章"),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """List enriched articles ordered by composite score (highest first).

    All filters are optional and combined with AND:
      - ``date``: ``fetched_at`` falls within that calendar day (naive bounds);
        a malformed date yields HTTP 400.
      - ``category``: exact category match.
      - ``tag``: the article's JSON ``tags`` array contains this exact element.
      - ``representative_only``: keep duplicate-group representatives and
        articles that belong to no duplicate group.

    Returns:
        ``{"items": [...], "total": n}`` where ``total`` counts all matches
        before ``offset``/``limit`` are applied.
    """
    from sqlalchemy import func, literal_column, or_, select

    query = db.query(EnrichedArticle)
    if date:
        try:
            day = datetime.strptime(date, "%Y-%m-%d")
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="日期格式错误,应为 YYYY-MM-DD") from exc
        next_day = day + timedelta(days=1)
        query = query.filter(EnrichedArticle.fetched_at >= day, EnrichedArticle.fetched_at < next_day)
    if category:
        query = query.filter(EnrichedArticle.category == category)
    if tag:
        # Exact element match through SQLite's json_each table-valued function
        # in a correlated EXISTS. On a generic JSON column, `.contains([tag])`
        # compiles to LIKE '%["tag"]%' over the serialized list, which only
        # matches arrays whose sole element is `tag`.
        tag_values = func.json_each(EnrichedArticle.tags).table_valued("value")
        query = query.filter(
            select(literal_column("1"))
            .select_from(tag_values)
            .where(tag_values.c.value == tag)
            .exists()
        )
    if representative_only:
        query = query.filter(
            or_(
                EnrichedArticle.is_representative.is_(True),
                EnrichedArticle.duplicate_group_id.is_(None),
            )
        )
    total = query.count()
    items = query.order_by(EnrichedArticle.composite_score.desc()).offset(offset).limit(limit).all()
    return {"items": items, "total": total}
@app.get("/api/articles/{article_id}", response_model=ArticleOut)
def get_article(article_id: int, db: Session = Depends(get_db)):
    """Fetch a single enriched article by primary key; 404 when absent."""
    found = db.query(EnrichedArticle).filter_by(id=article_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="文章不存在")
# ---------- 简报接口 ----------
@app.get("/api/briefs", response_model=List[BriefOut])
def list_briefs(
    limit: int = Query(30, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """Return up to ``limit`` daily briefs, most recent ``brief_date`` first."""
    newest_first = DailyBrief.brief_date.desc()
    briefs = db.query(DailyBrief).order_by(newest_first).limit(limit).all()
    return briefs
@app.get("/api/briefs/{date}", response_model=BriefOut)
def get_brief(date: str, db: Session = Depends(get_db)):
    """Look up the brief whose ``brief_date`` equals ``date``; 404 when missing."""
    found = db.query(DailyBrief).filter_by(brief_date=date).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="简报不存在")
@app.post("/api/briefs/{date}/regenerate")
def regenerate_brief(date: str, db: Session = Depends(get_db), _=Depends(verify_token)):
    """Synchronously regenerate (force) the brief for ``date`` (YYYY-MM-DD).

    Brief generation is serialized with the background tasks through the same
    task lock. This endpoint does not race the scheduled or manual
    "generate_daily_brief" task writing the same brief.

    Raises:
        HTTPException 400: ``date`` is not in YYYY-MM-DD format.
        HTTPException 409: another task currently holds the task lock.
        HTTPException 500: generation failed (detail is the error text).
    """
    try:
        datetime.strptime(date, "%Y-%m-%d")
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="日期格式错误,应为 YYYY-MM-DD") from exc

    lock = get_task_lock()
    if not lock.acquire(blocking=False):
        raise HTTPException(status_code=409, detail="已有任务正在执行,请稍后再试")
    try:
        data = generate_daily_brief(db, date_str=date, force=True)
    except Exception as exc:
        logger.exception("重新生成简报失败: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        lock.release()
    return {"message": "简报已重新生成", "data": data}
# ---------- 分类体系接口 ----------
@app.get("/api/taxonomy", response_model=List[TaxonomyOut])
def get_taxonomy(kind: Optional[str] = Query(None), db: Session = Depends(get_db)):
    """List taxonomy entries, optionally restricted to one ``kind``."""
    entries = list_taxonomy(db, kind=kind)
    return entries
@app.post("/api/taxonomy/bootstrap")
def trigger_taxonomy_bootstrap(
    force: bool = False,
    _=Depends(verify_token),
):
    """Start taxonomy bootstrap in the background and return at once.

    The background job fails (and is reported as an error in task progress)
    when ``bootstrap_taxonomy`` returns a falsy value. Responds 409 if another
    task holds the task lock.
    """
    progress_key = "bootstrap_taxonomy"

    def _bootstrap(session):
        if bootstrap_taxonomy(session, force=force):
            return
        raise RuntimeError("taxonomy 已存在或初始化失败,请检查日志")

    started = _run_task_background(progress_key, "manual", _bootstrap)
    if not started:
        raise HTTPException(status_code=409, detail="已有任务正在执行,请稍后再试")
    return {"message": "taxonomy 初始化已开始", "task_key": progress_key}
2026-06-12 16:04:03 +08:00
# ---------- 手动触发任务接口(后台执行,立即返回,前端轮询进度) ----------
2026-06-12 16:04:03 +08:00
@app.post("/api/tasks/summarize")
def task_summarize(_=Depends(verify_token)):
    """Kick off fetch-and-summarize (last 24h, up to 200 items) in the background.

    Returns the progress key immediately; 409 if another task holds the lock.
    """
    progress_key = "summarize"

    def _summarize(session):
        fetch_and_summarize(session, hours=24, limit=200)

    if _run_task_background(progress_key, "manual", _summarize):
        return {"message": "摘要任务已开始", "task_key": progress_key}
    raise HTTPException(status_code=409, detail="已有任务正在执行,请稍后再试")
2026-06-12 16:04:03 +08:00
@app.post("/api/tasks/tag-score-dedup")
def task_tag_score_dedup(_=Depends(verify_token)):
    """Kick off tag -> dedup -> score in the background; 409 if the lock is busy.

    Deduplication is scoped to the current UTC calendar day.
    """
    progress_key = "tag_score_dedup"

    def _pipeline(session):
        tag_articles(session)
        deduplicate_articles(
            session, date_str=datetime.now(timezone.utc).strftime("%Y-%m-%d")
        )
        score_articles(session, update_duplication=True)

    if _run_task_background(progress_key, "manual", _pipeline):
        return {"message": "分类/去重/打分任务已开始", "task_key": progress_key}
    raise HTTPException(status_code=409, detail="已有任务正在执行,请稍后再试")
2026-06-12 16:04:03 +08:00
@app.post("/api/tasks/brief")
def task_brief(_=Depends(verify_token)):
    """Kick off (forced) generation of today's UTC brief in the background.

    Returns the progress key immediately; 409 if another task holds the lock.
    """
    progress_key = "generate_daily_brief"

    def _generate(session):
        utc_today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        generate_daily_brief(session, date_str=utc_today, force=True)

    if _run_task_background(progress_key, "manual", _generate):
        return {"message": "简报生成任务已开始", "task_key": progress_key}
    raise HTTPException(status_code=409, detail="已有任务正在执行,请稍后再试")
@app.get("/api/tasks/progress")
def get_task_progress(_=Depends(verify_token)):
    """Snapshot of every task's live progress (polled by the frontend)."""
    snapshot = task_progress.get_progress()
    return snapshot
@app.post("/api/tasks/progress/reset")
def reset_task_progress(task_key: str = Query(...), _=Depends(verify_token)):
    """Reset the displayed progress of ``task_key`` back to idle."""
    task_progress.reset_progress(task_key)
    reply = {"message": "已重置"}
    return reply
# ---------- 接口连通性测试 ----------
@app.post("/api/test-connection", response_model=ConnectionTestResponse)
def test_connection(_=Depends(verify_token)):
    """Probe rssKeeper and the LLM API, reporting status and latency for each.

    Both probes use a 10s timeout (and no retries for the LLM) so the check
    stays fast. A failing probe reports status "error" with up to 200
    characters of the exception text; it never raises.
    """
    import time

    def _probe(label, call):
        outcome = {"name": label, "status": "error", "latency_ms": None, "error": None}
        try:
            started = time.monotonic()
            call()
            outcome = {
                "name": label,
                "status": "ok",
                "latency_ms": round((time.monotonic() - started) * 1000, 1),
                "error": None,
            }
        except Exception as exc:
            outcome["error"] = str(exc)[:200]
        return outcome

    def _ping_rss_keeper():
        # Dedicated short-timeout client instead of the shared one.
        probe_client = RSSKeeperClient(base_url=settings.RSSKEEPER_BASE_URL, timeout=10)
        probe_client._get("/api/v1/external/feeds", params={"limit": 1})

    def _ping_llm():
        probe_ai = AIClient(timeout=10, max_retries=0)
        probe_ai.chat_completion(
            system_prompt="You are a connectivity test.",
            user_prompt="Reply with exactly: ok",
            temperature=0.0,
        )

    return {
        "rss_keeper": _probe("rssKeeper", _ping_rss_keeper),
        "llm": _probe("LLM", _ping_llm),
    }
2026-06-12 16:04:03 +08:00
# ---------- 配置管理接口 ----------
@app.get("/api/settings", response_model=List[SettingOut])
def get_settings(db: Session = Depends(get_db), _=Depends(verify_token)):
    """List all stored settings, with sensitive values masked."""
    rows = list_settings(db, mask_sensitive=True)
    return rows
@app.put("/api/settings/{key}")
def update_setting(
    key: str,
    body: SettingUpdate,
    db: Session = Depends(get_db),
    _=Depends(verify_token),
):
    """Persist one setting; 400 when ``set_setting`` rejects it.

    The stored value only takes effect after a service restart.
    """
    if set_setting(db, key, body.value):
        return {"message": "配置已保存,重启服务后生效"}
    raise HTTPException(status_code=400, detail="无效的配置项")
@app.put("/api/settings")
def update_settings_batch(
    body: BatchSettingsUpdate,
    db: Session = Depends(get_db),
    _=Depends(verify_token),
):
    """Persist several settings; 400 listing every key ``set_setting`` rejected.

    Keys are applied one at a time. NOTE(review): accepted keys are presumably
    already persisted even when the request ends in 400 — confirm against
    ``set_setting``'s commit behaviour.
    """
    rejected = []
    for name, new_value in body.settings.items():
        accepted = set_setting(db, name, new_value)
        if not accepted:
            rejected.append(name)
    if not rejected:
        return {"message": "配置已保存,重启服务后生效"}
    raise HTTPException(status_code=400, detail=f"以下配置项无效: {', '.join(rejected)}")
@app.post("/api/settings/reset")
def reset_all_settings(db: Session = Depends(get_db), _=Depends(verify_token)):
    """Reset every setting to its environment default (effective after restart)."""
    reset_settings(db)
    notice = "配置已重置为环境变量默认值,重启服务后生效"
    return {"message": notice}
# ---------- 仪表盘统计接口 ----------
@app.get("/api/stats", response_model=StatsOut)
def get_stats(db: Session = Depends(get_db)):
    """Return dashboard counters plus each scheduler job's next run time.

    "Today" is the current UTC calendar day, expressed as naive datetime bounds
    compared against ``EnrichedArticle.fetched_at``.
    """
    utc_day = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    window_start = datetime.strptime(utc_day, "%Y-%m-%d")
    window_end = window_start + timedelta(days=1)

    def _count(model, *conditions):
        q = db.query(model)
        if conditions:
            q = q.filter(*conditions)
        return q.count()

    total_articles = _count(EnrichedArticle)
    today_articles = _count(
        EnrichedArticle,
        EnrichedArticle.fetched_at >= window_start,
        EnrichedArticle.fetched_at < window_end,
    )
    ai_summarized = _count(EnrichedArticle, EnrichedArticle.ai_summary != "")
    categories = _count(Taxonomy, Taxonomy.kind == "category")
    tags = _count(Taxonomy, Taxonomy.kind == "tag")
    duplicate_groups = _count(DuplicateGroup)
    briefs = _count(DailyBrief)

    next_jobs = {
        job.id: (job.next_run_time.isoformat() if job.next_run_time else None)
        for job in get_scheduler().get_jobs()
    }
    return {
        "total_articles": total_articles,
        "today_articles": today_articles,
        "ai_summarized": ai_summarized,
        "categories": categories,
        "tags": tags,
        "duplicate_groups": duplicate_groups,
        "briefs": briefs,
        "next_jobs": next_jobs,
    }
# ---------- 静态文件托管(生产环境 SPA) ----------
2026-06-12 16:04:03 +08:00
static_dir = os.path.join(os.path.dirname(__file__), "static")
if not os.path.isdir(static_dir):
    # For local builds, frontend/dist can serve as the static source instead.
    frontend_dist = os.path.join(os.path.dirname(__file__), "frontend", "dist")
    if os.path.isdir(frontend_dist):
        static_dir = frontend_dist


def resolve_static_file(base_dir: str, rel_path: str) -> Optional[str]:
    """Map a request-relative path to an existing file inside ``base_dir``.

    Guards against path traversal: ``..`` segments, absolute paths (e.g. a URL
    like ``//etc/passwd`` yields ``/etc/passwd``) and symlinks that escape
    ``base_dir`` are all rejected.

    Args:
        base_dir: Directory that all served files must live under.
        rel_path: Path taken from the request URL.

    Returns:
        The resolved absolute file path, or ``None`` if ``rel_path`` is empty,
        escapes ``base_dir`` or does not name a regular file.
    """
    if not rel_path:
        return None
    base_real = os.path.realpath(base_dir)
    candidate = os.path.realpath(os.path.join(base_real, rel_path))
    try:
        inside = os.path.commonpath([base_real, candidate]) == base_real
    except ValueError:  # e.g. different drives on Windows
        return None
    if not inside or not os.path.isfile(candidate):
        return None
    return candidate


if os.path.isdir(static_dir):
    from fastapi import Request

    # Static assets (JS/CSS/images) are mounted under the /assets sub-path.
    assets_dir = os.path.join(static_dir, "assets")
    if os.path.isdir(assets_dir):
        app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")

    # Root-level SPA files such as favicon and vite.svg.
    @app.get("/favicon.ico")
    @app.get("/vite.svg")
    async def serve_static_root(request: Request):
        """Serve a root-level static file by its basename, or 404."""
        # The annotation matters: an unannotated `request` parameter would be
        # treated by FastAPI as a required query parameter (-> 422).
        file_path = resolve_static_file(static_dir, os.path.basename(request.url.path))
        if file_path:
            return FileResponse(file_path)
        return Response(status_code=404)

    # Every otherwise unmatched GET -> a static file if one exists, else
    # index.html so the client-side (Vue) router can handle the path.
    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve a file from the SPA build or fall back to index.html."""
        # Unknown API routes must 404 instead of returning the SPA shell.
        if full_path == "api" or full_path.startswith("api/"):
            raise HTTPException(status_code=404, detail="Not Found")
        file_path = resolve_static_file(static_dir, full_path)
        if file_path:
            return FileResponse(file_path)
        return FileResponse(os.path.join(static_dir, "index.html"))
2026-06-12 16:04:03 +08:00
if __name__ == "__main__":
    import uvicorn

    # Direct-run entry point: listen on all interfaces on port 7331.
    _bind_host, _bind_port = "0.0.0.0", 7331
    uvicorn.run(app, host=_bind_host, port=_bind_port)