Initial commit: RSS platform phase 1 skeleton with code review fixes
Features: - FastAPI + SQLAlchemy 2.0 async + PostgreSQL/pgvector + Redis backend - Vue 3 + TypeScript + Element Plus frontend - JWT auth with access/refresh tokens and revocation - Admin/member RBAC - RSS feed CRUD and article listing - Settings management with Fernet encryption for sensitive values - Redis distributed lock service - Alembic initial migration - Docker Compose development environment Fixes from code review: - Fix DB session leak in dependency injection - Restrict registration to admin only - Add default admin password warning - Implement JWT refresh tokens and jti blacklist - Strengthen password policy - Use func.count for pagination totals - Replace NullPool with AsyncAdaptedQueuePool - Remove init_db from lifespan to enforce alembic migrations - Add request_id middleware and logging filter - Fix vite.config.ts env loading - Add frontend token refresh interceptor - Add Vue error handler Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
"""FastAPI dependencies."""
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import decode_token, is_token_revoked
|
||||
from app.core.database import get_db as _get_db
|
||||
from app.core.rbac import require_admin
|
||||
from app.core.redis import get_redis
|
||||
from app.models.user import User
|
||||
from app.schemas.user import TokenPayload
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield async database session managed by FastAPI."""
|
||||
async for session in _get_db():
|
||||
yield session
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""Get current authenticated user from JWT access token."""
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
|
||||
try:
|
||||
payload = decode_token(token, expected_type="access")
|
||||
token_data = TokenPayload(**payload)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Invalid authentication credentials: {exc}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
if not token_data.sub or not token_data.jti:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
revoked = await is_token_revoked(token_data.jti)
|
||||
if revoked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has been revoked",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user = await db.get(User, token_data.sub)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Get current user and require admin role."""
|
||||
return require_admin(current_user)
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Admin locks router."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_admin, get_db
|
||||
from app.models.lock import Lock
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
|
||||
router = APIRouter(prefix="/locks", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_locks(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""List active locks."""
|
||||
result = await db.execute(select(Lock))
|
||||
locks = result.scalars().all()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
active_locks = [
|
||||
{
|
||||
"id": str(lock.id),
|
||||
"lock_name": lock.lock_name,
|
||||
"owner_id": lock.owner_id,
|
||||
"acquired_at": lock.acquired_at.isoformat() if lock.acquired_at else None,
|
||||
"expires_at": lock.expires_at.isoformat() if lock.expires_at else None,
|
||||
"is_expired": lock.expires_at is not None and lock.expires_at < now,
|
||||
}
|
||||
for lock in locks
|
||||
]
|
||||
|
||||
return {"total": len(active_locks), "items": active_locks}
|
||||
|
||||
|
||||
@router.delete("/{lock_name}", response_model=MessageResponse)
|
||||
async def force_release_lock(
|
||||
lock_name: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Force release a lock."""
|
||||
result = await db.execute(select(Lock).where(Lock.lock_name == lock_name))
|
||||
lock = result.scalar_one_or_none()
|
||||
if not lock:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Lock not found")
|
||||
|
||||
await db.delete(lock)
|
||||
await db.commit()
|
||||
return {"message": f"Lock {lock_name} released"}
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Articles router."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.models.article import CleanedArticle
|
||||
from app.models.user import User
|
||||
from app.schemas.article import ArticleListParams, ArticleOut
|
||||
from app.schemas.common import MessageResponse, PaginatedResponse
|
||||
|
||||
router = APIRouter(prefix="/articles", tags=["articles"])
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse)
|
||||
async def list_articles(
|
||||
params: ArticleListParams = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List cleaned articles with filters."""
|
||||
query = select(CleanedArticle)
|
||||
|
||||
if params.feed_id:
|
||||
query = query.where(CleanedArticle.feed_id == params.feed_id)
|
||||
if params.category:
|
||||
query = query.where(CleanedArticle.category == params.category)
|
||||
if params.tag:
|
||||
query = query.where(CleanedArticle.tags.contains([params.tag]))
|
||||
if params.search:
|
||||
query = query.where(
|
||||
CleanedArticle.title.ilike(f"%{params.search}%")
|
||||
| CleanedArticle.ai_summary.ilike(f"%{params.search}%")
|
||||
)
|
||||
if params.is_read is not None:
|
||||
# CleanedArticle doesn't have is_read in current schema; placeholder
|
||||
pass
|
||||
|
||||
# Count
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
# Paginate
|
||||
query = (
|
||||
query.offset(params.skip)
|
||||
.limit(params.limit)
|
||||
.order_by(CleanedArticle.published_at.desc().nulls_last())
|
||||
)
|
||||
result = await db.execute(query)
|
||||
items = result.scalars().all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": [ArticleOut.model_validate(item) for item in items],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{article_id}", response_model=ArticleOut)
|
||||
async def get_article(
|
||||
article_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single cleaned article."""
|
||||
article = await db.get(CleanedArticle, article_id)
|
||||
if not article:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Article not found")
|
||||
return ArticleOut.model_validate(article)
|
||||
|
||||
|
||||
@router.put("/{article_id}/read", response_model=MessageResponse)
|
||||
async def mark_article_read(
|
||||
article_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Mark an article as read (placeholder)."""
|
||||
# In Phase 1, cleaned_articles doesn't have is_read field yet
|
||||
return {"message": "Article marked as read"}
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Authentication router."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_admin, get_current_user, get_db
|
||||
from app.core.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
revoke_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.user import (
|
||||
RefreshTokenRequest,
|
||||
TokenResponse,
|
||||
UserCreate,
|
||||
UserLogin,
|
||||
UserOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserOut)
|
||||
async def register(
|
||||
user_in: UserCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Register a new user (admin only)."""
|
||||
# Check if username exists
|
||||
result = await db.execute(select(User).where(User.username == user_in.username))
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Username already exists",
|
||||
)
|
||||
|
||||
user = User(
|
||||
username=user_in.username,
|
||||
password_hash=get_password_hash(user_in.password),
|
||||
role=user_in.role,
|
||||
is_active=user_in.is_active,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
credentials: UserLogin,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Login and get access/refresh tokens."""
|
||||
result = await db.execute(select(User).where(User.username == credentials.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(credentials.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user",
|
||||
)
|
||||
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
access_token, _ = create_access_token(sub=str(user.id), role=user.role)
|
||||
refresh_token, _ = create_refresh_token(sub=str(user.id))
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(
|
||||
req: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Exchange a valid refresh token for a new token pair."""
|
||||
try:
|
||||
payload = decode_token(req.refresh_token, expected_type="refresh")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Invalid refresh token: {exc}",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = await db.get(User, payload["sub"])
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token, _ = create_access_token(sub=str(user.id), role=user.role)
|
||||
refresh_token, _ = create_refresh_token(sub=str(user.id))
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
req: RefreshTokenRequest,
|
||||
):
|
||||
"""Revoke the provided refresh token."""
|
||||
try:
|
||||
payload = decode_token(req.refresh_token, expected_type="refresh")
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
exp = payload.get("exp")
|
||||
if exp:
|
||||
expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
|
||||
await revoke_token(payload["jti"], expires_at)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
"""Get current user info."""
|
||||
return current_user
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Feeds router."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.models.feed import Feed
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse, PaginatedResponse, PaginationParams
|
||||
from app.schemas.feed import FeedCreate, FeedOut, FeedUpdate
|
||||
|
||||
router = APIRouter(prefix="/feeds", tags=["feeds"])
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse)
|
||||
async def list_feeds(
|
||||
pagination: PaginationParams = Depends(),
|
||||
category: str | None = Query(None),
|
||||
search: str | None = Query(None),
|
||||
is_active: bool | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List RSS feeds with pagination and filters."""
|
||||
query = select(Feed)
|
||||
|
||||
if category:
|
||||
query = query.where(Feed.category == category)
|
||||
if search:
|
||||
query = query.where(
|
||||
Feed.title.ilike(f"%{search}%")
|
||||
| Feed.url.ilike(f"%{search}%")
|
||||
| Feed.description.ilike(f"%{search}%")
|
||||
)
|
||||
if is_active is not None:
|
||||
query = query.where(Feed.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
# Get paginated items
|
||||
query = query.offset(pagination.skip).limit(pagination.limit).order_by(Feed.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
items = result.scalars().all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": [FeedOut.model_validate(item) for item in items],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{feed_id}", response_model=FeedOut)
|
||||
async def get_feed(
|
||||
feed_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single feed by ID."""
|
||||
feed = await db.get(Feed, feed_id)
|
||||
if not feed:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Feed not found")
|
||||
return FeedOut.model_validate(feed)
|
||||
|
||||
|
||||
@router.post("", response_model=FeedOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_feed(
|
||||
feed_in: FeedCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new RSS feed."""
|
||||
# Check URL uniqueness
|
||||
result = await db.execute(select(Feed).where(Feed.url == str(feed_in.url)))
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Feed with this URL already exists",
|
||||
)
|
||||
|
||||
feed = Feed(
|
||||
url=str(feed_in.url),
|
||||
title=feed_in.title or "",
|
||||
description=feed_in.description or "",
|
||||
category=feed_in.category or "",
|
||||
is_active=feed_in.is_active,
|
||||
fetch_interval_minutes=feed_in.fetch_interval_minutes,
|
||||
priority=feed_in.priority,
|
||||
parser_config=feed_in.parser_config,
|
||||
proxy_policy=feed_in.proxy_policy,
|
||||
)
|
||||
db.add(feed)
|
||||
await db.commit()
|
||||
await db.refresh(feed)
|
||||
return FeedOut.model_validate(feed)
|
||||
|
||||
|
||||
@router.put("/{feed_id}", response_model=FeedOut)
|
||||
async def update_feed(
|
||||
feed_id: str,
|
||||
feed_in: FeedUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update an existing feed."""
|
||||
feed = await db.get(Feed, feed_id)
|
||||
if not feed:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Feed not found")
|
||||
|
||||
update_data = feed_in.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "url" and value is not None:
|
||||
value = str(value)
|
||||
setattr(feed, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(feed)
|
||||
return FeedOut.model_validate(feed)
|
||||
|
||||
|
||||
@router.delete("/{feed_id}", response_model=MessageResponse)
|
||||
async def delete_feed(
|
||||
feed_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a feed."""
|
||||
feed = await db.get(Feed, feed_id)
|
||||
if not feed:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Feed not found")
|
||||
|
||||
await db.delete(feed)
|
||||
await db.commit()
|
||||
return {"message": "Feed deleted successfully"}
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Health check router."""
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_admin, get_db
|
||||
from app.core.redis import check_redis_health
|
||||
|
||||
router = APIRouter(prefix="/health", tags=["health"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def health_check(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""Basic health check."""
|
||||
db_ok = False
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
db_ok = True
|
||||
except Exception:
|
||||
db_ok = False
|
||||
|
||||
redis_ok = await check_redis_health()
|
||||
|
||||
status_code = "ok" if db_ok and redis_ok else "degraded"
|
||||
|
||||
response = {
|
||||
"status": status_code,
|
||||
"service": "rss-platform",
|
||||
"db": "ok" if db_ok else "error",
|
||||
"redis": "ok" if redis_ok else "error",
|
||||
}
|
||||
warnings = getattr(request.app.state, "startup_warnings", None)
|
||||
if warnings:
|
||||
response["warnings"] = warnings
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/db", dependencies=[Depends(get_current_admin)])
|
||||
async def db_health(db: AsyncSession = Depends(get_db)):
|
||||
"""Database health check."""
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return {"status": "ok", "component": "database"}
|
||||
except Exception as exc:
|
||||
return {"status": "error", "component": "database", "detail": str(exc)}
|
||||
|
||||
|
||||
@router.get("/redis", dependencies=[Depends(get_current_admin)])
|
||||
async def redis_health():
|
||||
"""Redis health check."""
|
||||
ok = await check_redis_health()
|
||||
return {"status": "ok" if ok else "error", "component": "redis"}
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Settings router."""
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_admin, get_current_user, get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.services.settings_service import (
|
||||
apply_db_settings_to_config,
|
||||
list_settings,
|
||||
reset_settings,
|
||||
set_setting,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["settings"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_settings(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all settings."""
|
||||
return await list_settings(db, mask_sensitive=current_user.role != "admin")
|
||||
|
||||
|
||||
@router.put("/{key}")
|
||||
async def update_setting(
|
||||
key: str,
|
||||
value: dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Update a single setting."""
|
||||
if "value" not in value:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Request body must contain 'value' field",
|
||||
)
|
||||
|
||||
success = await set_setting(db, key, value["value"])
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid setting key: {key}",
|
||||
)
|
||||
|
||||
await apply_db_settings_to_config(db)
|
||||
return {"message": "Setting updated", "key": key}
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def batch_update_settings(
|
||||
data: dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Update multiple settings."""
|
||||
settings_data = data.get("settings", {})
|
||||
if not settings_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Request body must contain 'settings' object",
|
||||
)
|
||||
|
||||
errors = []
|
||||
for key, value in settings_data.items():
|
||||
success = await set_setting(db, key, value)
|
||||
if not success:
|
||||
errors.append(key)
|
||||
|
||||
if errors:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid setting keys: {', '.join(errors)}",
|
||||
)
|
||||
|
||||
await apply_db_settings_to_config(db)
|
||||
return {"message": "Settings updated", "count": len(settings_data)}
|
||||
|
||||
|
||||
@router.post("/reset", response_model=MessageResponse)
|
||||
async def reset_all_settings(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Reset all settings to environment defaults."""
|
||||
await reset_settings(db)
|
||||
await apply_db_settings_to_config(db)
|
||||
return {"message": "Settings reset to defaults"}
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Authentication and authorization utilities."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.redis import get_redis
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
TOKEN_TYPE_ACCESS = "access"
|
||||
TOKEN_TYPE_REFRESH = "refresh"
|
||||
|
||||
# Redis key for revoked JWT jti set
|
||||
REVOKED_JTIS_KEY = "auth:revoked_jtis"
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def _create_token(data: dict[str, Any], expires_delta: timedelta, token_type: str) -> tuple[str, str]:
|
||||
"""Create a JWT with jti/type claims. Returns (token, jti)."""
|
||||
jti = str(uuid4())
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": jti,
|
||||
"type": token_type,
|
||||
})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt, jti
|
||||
|
||||
|
||||
def create_access_token(sub: str, role: str | None = None) -> tuple[str, str]:
|
||||
"""Create a short-lived JWT access token. Returns (token, jti)."""
|
||||
data: dict[str, Any] = {"sub": sub}
|
||||
if role is not None:
|
||||
data["role"] = role
|
||||
return _create_token(
|
||||
data,
|
||||
timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
TOKEN_TYPE_ACCESS,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(sub: str) -> tuple[str, str]:
|
||||
"""Create a long-lived JWT refresh token. Returns (token, jti)."""
|
||||
return _create_token(
|
||||
{"sub": sub},
|
||||
timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
TOKEN_TYPE_REFRESH,
|
||||
)
|
||||
|
||||
|
||||
def decode_token(token: str, expected_type: str = TOKEN_TYPE_ACCESS) -> dict[str, Any]:
|
||||
"""Decode and validate a JWT token, checking type claim."""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError as exc:
|
||||
raise ValueError("Invalid token") from exc
|
||||
|
||||
token_type = payload.get("type")
|
||||
if token_type != expected_type:
|
||||
raise ValueError(f"Invalid token type: expected {expected_type}, got {token_type}")
|
||||
|
||||
if "sub" not in payload or "jti" not in payload:
|
||||
raise ValueError("Invalid token payload")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str) -> bool:
|
||||
"""Check whether a token jti has been revoked."""
|
||||
redis = await get_redis()
|
||||
if redis is None:
|
||||
# Without Redis we cannot reliably maintain a revocation list.
|
||||
return False
|
||||
return await redis.sismember(REVOKED_JTIS_KEY, jti)
|
||||
|
||||
|
||||
async def revoke_token(jti: str, expires_at: datetime) -> None:
|
||||
"""Revoke a token by its jti with TTL matching token expiry."""
|
||||
redis = await get_redis()
|
||||
if redis is None:
|
||||
return
|
||||
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||
if ttl_seconds > 0:
|
||||
await redis.sadd(REVOKED_JTIS_KEY, jti)
|
||||
await redis.expire(REVOKED_JTIS_KEY, ttl_seconds)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Application configuration."""
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "postgresql+asyncpg://rss:rss@postgres:5432/rss_platform"
|
||||
|
||||
# Redis
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
|
||||
# JWT
|
||||
SECRET_KEY: str = Field(..., min_length=32)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# AI
|
||||
AI_DEFAULT_PROVIDER: str = "openai"
|
||||
AI_DEFAULT_MODEL: str = "gpt-4o-mini"
|
||||
|
||||
# Storage
|
||||
STORAGE_TYPE: str = "minio"
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "rss-platform"
|
||||
|
||||
# CORS
|
||||
CORS_ALLOWED_ORIGINS: str = ""
|
||||
|
||||
# Default admin
|
||||
DEFAULT_ADMIN_USERNAME: str = "admin"
|
||||
DEFAULT_ADMIN_PASSWORD: str = "admin"
|
||||
|
||||
# Sensitive settings encryption
|
||||
SETTINGS_ENCRYPTION_KEY: str = ""
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL: str = "INFO"
|
||||
|
||||
# RSS Fetching
|
||||
FETCH_CONCURRENCY: int = 10
|
||||
FETCH_TIMEOUT: int = 30
|
||||
DEFAULT_FETCH_INTERVAL: int = 60
|
||||
MIN_FETCH_INTERVAL: int = 15
|
||||
|
||||
# Ports (for reference)
|
||||
BACKEND_PORT: int = 8000
|
||||
FRONTEND_PORT: int = 5173
|
||||
|
||||
@property
|
||||
def cors_origins(self) -> list[str]:
|
||||
"""Parse CORS_ALLOWED_ORIGINS into list."""
|
||||
if not self.CORS_ALLOWED_ORIGINS:
|
||||
return []
|
||||
return [origin.strip() for origin in self.CORS_ALLOWED_ORIGINS.split(",") if origin.strip()]
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Database configuration and session management."""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.LOG_LEVEL == "DEBUG",
|
||||
future=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""Dependency for FastAPI to get async DB session."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close database connections."""
|
||||
await engine.dispose()
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Custom exceptions and error handlers."""
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlatformException(Exception):
|
||||
"""Base exception for the platform."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 400):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class AuthenticationError(PlatformException):
|
||||
"""Authentication failed."""
|
||||
|
||||
def __init__(self, message: str = "Authentication failed"):
|
||||
super().__init__(message, status_code=401)
|
||||
|
||||
|
||||
class AuthorizationError(PlatformException):
|
||||
"""Authorization failed."""
|
||||
|
||||
def __init__(self, message: str = "Forbidden"):
|
||||
super().__init__(message, status_code=403)
|
||||
|
||||
|
||||
class NotFoundError(PlatformException):
|
||||
"""Resource not found."""
|
||||
|
||||
def __init__(self, message: str = "Resource not found"):
|
||||
super().__init__(message, status_code=404)
|
||||
|
||||
|
||||
class ConflictError(PlatformException):
|
||||
"""Resource conflict."""
|
||||
|
||||
def __init__(self, message: str = "Conflict"):
|
||||
super().__init__(message, status_code=409)
|
||||
|
||||
|
||||
def add_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers."""
|
||||
|
||||
@app.exception_handler(PlatformException)
|
||||
async def platform_exception_handler(request: Request, exc: PlatformException):
|
||||
logger.warning("Platform exception: %s", exc.message)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": exc.message},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception("Unhandled exception: %s", exc)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Logging configuration."""
|
||||
import logging
|
||||
import sys
|
||||
from contextvars import ContextVar
|
||||
|
||||
request_id_var: ContextVar[str] = ContextVar("request_id", default="")
|
||||
|
||||
|
||||
def configure_logging(log_level: str = "INFO") -> None:
|
||||
"""Configure structured logging."""
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(request_id)s] %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(RequestIdFilter())
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))
|
||||
root_logger.handlers = []
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RequestIdFilter(logging.Filter):
|
||||
"""Inject request_id into log records."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.request_id = request_id_var.get() # type: ignore[attr-defined]
|
||||
return True
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger with request_id filter."""
|
||||
logger = logging.getLogger(name)
|
||||
logger.addFilter(RequestIdFilter())
|
||||
return logger
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Role-based access control."""
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""User roles."""
|
||||
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
|
||||
|
||||
def require_admin(current_user: User) -> User:
|
||||
"""Dependency that requires admin role."""
|
||||
if current_user.role != Role.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def has_permission(user: User, required_role: Role) -> bool:
|
||||
"""Check if user has required role."""
|
||||
if user.role == Role.ADMIN:
|
||||
return True
|
||||
return user.role == required_role
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Redis connection management."""
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
_redis: Redis | None = None
|
||||
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
"""Get or create Redis connection."""
|
||||
global _redis
|
||||
if _redis is None:
|
||||
_redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return _redis
|
||||
|
||||
|
||||
async def close_redis() -> None:
|
||||
"""Close Redis connection."""
|
||||
global _redis
|
||||
if _redis:
|
||||
await _redis.close()
|
||||
_redis = None
|
||||
|
||||
|
||||
async def check_redis_health() -> bool:
|
||||
"""Check if Redis is reachable."""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
await redis.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Models package."""
|
||||
from app.models.ai_config import AIProviderConfig, AITaskConfig
|
||||
from app.models.article import CleanedArticle, RawArticle
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin, utc_now
|
||||
from app.models.chat import ChatMessage, ChatSession
|
||||
from app.models.feed import Feed
|
||||
from app.models.lock import Lock
|
||||
from app.models.output import Output, OutputTask
|
||||
from app.models.reference import ArticleReference, DuplicateGroup
|
||||
from app.models.setting import AppSetting
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"utc_now",
|
||||
"User",
|
||||
"Feed",
|
||||
"RawArticle",
|
||||
"CleanedArticle",
|
||||
"ArticleReference",
|
||||
"DuplicateGroup",
|
||||
"Skill",
|
||||
"AIProviderConfig",
|
||||
"AITaskConfig",
|
||||
"OutputTask",
|
||||
"Output",
|
||||
"ChatSession",
|
||||
"ChatMessage",
|
||||
"Lock",
|
||||
"AppSetting",
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
"""AI configuration models."""
|
||||
from sqlalchemy import Boolean, Float, ForeignKey, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class AIProviderConfig(Base, UUIDMixin, TimestampMixin):
|
||||
"""AI provider configuration (OpenAI, Anthropic, etc.)."""
|
||||
|
||||
__tablename__ = "ai_provider_configs"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
base_url: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
api_key_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
default_model: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
timeout: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||
max_retries: Mapped[int] = mapped_column(Integer, default=3, nullable=False)
|
||||
rate_limit_rpm: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
class AITaskConfig(Base, UUIDMixin, TimestampMixin):
|
||||
"""AI task configuration (which model/skill for which task)."""
|
||||
|
||||
__tablename__ = "ai_task_configs"
|
||||
|
||||
task_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
provider_config_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("ai_provider_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
model: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
skill_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("skills.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
temperature: Mapped[float] = mapped_column(Float, default=0.3, nullable=False)
|
||||
max_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
|
||||
system_prompt_override: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
fallback_config_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("ai_task_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Article models: raw and cleaned."""
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class RawArticle(Base, UUIDMixin, TimestampMixin):
|
||||
"""Raw article fetched from RSS feed."""
|
||||
|
||||
__tablename__ = "raw_articles"
|
||||
|
||||
feed_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("feeds.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
external_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
title: Mapped[str | None] = mapped_column(String(1024), default="", index=True)
|
||||
link: Mapped[str] = mapped_column(String(2048), nullable=False, index=True)
|
||||
author: Mapped[str | None] = mapped_column(String(256), default="")
|
||||
published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
|
||||
fetched_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True
|
||||
)
|
||||
content: Mapped[str | None] = mapped_column(Text, default="")
|
||||
summary: Mapped[str | None] = mapped_column(Text, default="")
|
||||
raw_html: Mapped[str | None] = mapped_column(Text, default="")
|
||||
content_hash: Mapped[str | None] = mapped_column(String(64), default="")
|
||||
language: Mapped[str | None] = mapped_column(String(16), default="")
|
||||
status: Mapped[str] = mapped_column(String(32), default="pending", nullable=False, index=True)
|
||||
|
||||
feed: Mapped["Feed"] = relationship("Feed", back_populates="raw_articles")
|
||||
cleaned_article: Mapped["CleanedArticle | None"] = relationship(
|
||||
"CleanedArticle", back_populates="raw_article", uselist=False
|
||||
)
|
||||
|
||||
|
||||
class CleanedArticle(Base, UUIDMixin, TimestampMixin):
|
||||
"""Cleaned and AI-enriched article."""
|
||||
|
||||
__tablename__ = "cleaned_articles"
|
||||
|
||||
raw_article_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("raw_articles.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
feed_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("feeds.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
|
||||
title: Mapped[str | None] = mapped_column(String(1024), default="", index=True)
|
||||
link: Mapped[str] = mapped_column(String(2048), default="", index=True)
|
||||
author: Mapped[str | None] = mapped_column(String(256), default="")
|
||||
feed_title: Mapped[str | None] = mapped_column(String(512), default="")
|
||||
feed_category: Mapped[str | None] = mapped_column(String(128), default="")
|
||||
|
||||
published_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
|
||||
fetched_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
content: Mapped[str | None] = mapped_column(Text, default="")
|
||||
content_length: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
original_summary: Mapped[str | None] = mapped_column(Text, default="")
|
||||
ai_summary: Mapped[str | None] = mapped_column(Text, default="")
|
||||
|
||||
category: Mapped[str | None] = mapped_column(String(128), default="", index=True)
|
||||
tags: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
|
||||
heat_score: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
importance_score: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
duplication_score: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
composite_score: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
|
||||
duplicate_group_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("duplicate_groups.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
is_representative: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, index=True)
|
||||
reference_links: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
processing_status: Mapped[str] = mapped_column(String(32), default="pending", nullable=False, index=True)
|
||||
|
||||
raw_article: Mapped["RawArticle | None"] = relationship("RawArticle", back_populates="cleaned_article")
|
||||
duplicate_group: Mapped["DuplicateGroup | None"] = relationship("DuplicateGroup", back_populates="articles")
|
||||
@@ -0,0 +1,45 @@
|
||||
"""SQLAlchemy 2.0 async base and session factory."""
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Adds created_at and updated_at columns."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Adds UUID primary key."""
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid4,
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Return timezone-aware UTC now."""
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Chat models."""
|
||||
from sqlalchemy import ForeignKey, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class ChatSession(Base, UUIDMixin, TimestampMixin):
|
||||
"""Chat session."""
|
||||
|
||||
__tablename__ = "chat_sessions"
|
||||
|
||||
user_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("users.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
title: Mapped[str | None] = mapped_column(String(256), default="")
|
||||
skill_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("skills.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
context_window: Mapped[int] = mapped_column(default=10, nullable=False)
|
||||
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="session", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(Base, UUIDMixin, TimestampMixin):
|
||||
"""Chat message."""
|
||||
|
||||
__tablename__ = "chat_messages"
|
||||
|
||||
session_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(32), nullable=False, index=True) # user / assistant / tool
|
||||
content: Mapped[str | None] = mapped_column(Text, default="")
|
||||
tool_calls: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
tool_results: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
references: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
token_usage: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
|
||||
session: Mapped["ChatSession"] = relationship("ChatSession", back_populates="messages")
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Feed model."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class Feed(Base, UUIDMixin, TimestampMixin):
|
||||
"""RSS feed source."""
|
||||
|
||||
__tablename__ = "feeds"
|
||||
|
||||
url: Mapped[str] = mapped_column(String(2048), unique=True, nullable=False, index=True)
|
||||
title: Mapped[str | None] = mapped_column(String(512), default="")
|
||||
description: Mapped[str | None] = mapped_column(Text, default="")
|
||||
category: Mapped[str | None] = mapped_column(String(128), default="")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, index=True)
|
||||
fetch_interval_minutes: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||
priority: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
parser_config: Mapped[dict] = mapped_column(JSON, default=dict, nullable=False)
|
||||
proxy_policy: Mapped[str] = mapped_column(String(32), default="auto", nullable=False)
|
||||
|
||||
# Fetch statistics
|
||||
last_fetch_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_fetch_status: Mapped[str | None] = mapped_column(String(32), default="")
|
||||
last_error: Mapped[str | None] = mapped_column(Text, default="")
|
||||
error_type: Mapped[str | None] = mapped_column(String(64), default="")
|
||||
success_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
fail_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
article_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
|
||||
raw_articles: Mapped[list["RawArticle"]] = relationship(
|
||||
"RawArticle", back_populates="feed", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def health_status(self, now: datetime | None = None) -> str:
|
||||
"""Compute feed health status."""
|
||||
if now is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
total = self.success_count + self.fail_count
|
||||
if total == 0:
|
||||
return "unknown"
|
||||
|
||||
success_rate = self.success_count / total
|
||||
days_since = None
|
||||
if self.last_fetch_at:
|
||||
days_since = (now - self.last_fetch_at).days
|
||||
|
||||
if success_rate >= 0.9 and (days_since is None or days_since <= 7):
|
||||
return "healthy"
|
||||
if success_rate >= 0.5 and (days_since is None or days_since <= 7):
|
||||
return "warning"
|
||||
return "unhealthy"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Feed {self.title or self.url}>"
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Lock model."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, UUIDMixin
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class Lock(Base, UUIDMixin):
|
||||
"""Distributed lock record (fallback when Redis is unavailable)."""
|
||||
|
||||
__tablename__ = "locks"
|
||||
|
||||
lock_name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
acquired_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, default=_utc_now
|
||||
)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Output task and output record models."""
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OutputTask(Base, UUIDMixin, TimestampMixin):
|
||||
"""Configurable output task (e.g. daily brief)."""
|
||||
|
||||
__tablename__ = "output_tasks"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
task_type: Mapped[str] = mapped_column(String(64), default="daily_brief", nullable=False, index=True)
|
||||
skill_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("skills.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
schedule: Mapped[str | None] = mapped_column(String(128), nullable=True) # cron expression
|
||||
filter_config: Mapped[dict] = mapped_column(JSON, default=dict, nullable=False)
|
||||
output_config: Mapped[dict] = mapped_column(JSON, default=dict, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_output_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("outputs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Output(Base, UUIDMixin, TimestampMixin):
|
||||
"""Generated output record."""
|
||||
|
||||
__tablename__ = "outputs"
|
||||
|
||||
output_task_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("output_tasks.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
content: Mapped[str | None] = mapped_column(Text, default="")
|
||||
content_html: Mapped[str | None] = mapped_column(Text, default="")
|
||||
references: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
metadata: Mapped[dict] = mapped_column(JSON, default=dict, nullable=False)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Reference and duplicate group models."""
|
||||
from sqlalchemy import Float, ForeignKey, JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class ArticleReference(Base, UUIDMixin, TimestampMixin):
|
||||
"""Reference from a cleaned article to another related article."""
|
||||
|
||||
__tablename__ = "article_references"
|
||||
|
||||
source_article_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("cleaned_articles.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
referenced_article_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("cleaned_articles.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
reference_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
reference_link: Mapped[str | None] = mapped_column(String(2048), default="")
|
||||
reference_title: Mapped[str | None] = mapped_column(String(1024), default="")
|
||||
similarity: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
|
||||
class DuplicateGroup(Base, UUIDMixin, TimestampMixin):
|
||||
"""Group of duplicate articles."""
|
||||
|
||||
__tablename__ = "duplicate_groups"
|
||||
|
||||
representative_article_id: Mapped[str | None] = mapped_column(
|
||||
ForeignKey("cleaned_articles.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
member_article_ids: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
similarity_matrix: Mapped[dict] = mapped_column(JSON, default=dict, nullable=False)
|
||||
brief_date: Mapped[str | None] = mapped_column(String(10), default="", index=True)
|
||||
|
||||
articles: Mapped[list["CleanedArticle"]] = relationship(
|
||||
"CleanedArticle", back_populates="duplicate_group"
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""App setting model."""
|
||||
from sqlalchemy import Boolean, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class AppSetting(Base, UUIDMixin, TimestampMixin):
|
||||
"""Runtime application setting."""
|
||||
|
||||
__tablename__ = "app_settings"
|
||||
|
||||
key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
|
||||
value: Mapped[str] = mapped_column(Text, default="", nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, default="")
|
||||
is_sensitive: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Skill model."""
|
||||
from sqlalchemy import Boolean, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class Skill(Base, UUIDMixin, TimestampMixin):
|
||||
"""Reusable skill configuration for AI outputs."""
|
||||
|
||||
__tablename__ = "skills"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
slug: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, default="")
|
||||
type: Mapped[str] = mapped_column(String(32), nullable=False, index=True) # output / tool / agent
|
||||
version: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
system_prompt: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
output_schema: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
tools: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
input_schema: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
example_inputs: Mapped[list] = mapped_column(JSON, default=list, nullable=False)
|
||||
|
||||
created_by: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
@@ -0,0 +1,22 @@
|
||||
"""User model."""
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin, utc_now
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
"""Platform user."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), default="member", nullable=False, index=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User {self.username} ({self.role})>"
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Article Pydantic schemas."""
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ArticleListParams(BaseModel):
|
||||
"""Article list query parameters."""
|
||||
|
||||
feed_id: str | None = None
|
||||
category: str | None = None
|
||||
tag: str | None = None
|
||||
search: str | None = None
|
||||
is_read: bool | None = None
|
||||
skip: int = 0
|
||||
limit: int = Field(default=50, le=200)
|
||||
|
||||
|
||||
class ArticleOut(BaseModel):
|
||||
"""Cleaned article output schema."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
raw_article_id: str | None = None
|
||||
feed_id: str
|
||||
title: str | None = None
|
||||
link: str
|
||||
author: str | None = None
|
||||
feed_title: str | None = None
|
||||
feed_category: str | None = None
|
||||
published_at: str | None = None
|
||||
fetched_at: str
|
||||
content: str | None = None
|
||||
original_summary: str | None = None
|
||||
ai_summary: str | None = None
|
||||
category: str | None = None
|
||||
tags: list[str] = []
|
||||
heat_score: float = 0.0
|
||||
importance_score: float = 0.0
|
||||
duplication_score: float = 0.0
|
||||
composite_score: float = 0.0
|
||||
is_representative: bool = True
|
||||
reference_links: list[dict] = []
|
||||
processing_status: str = "pending"
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, obj):
|
||||
"""Format datetime fields."""
|
||||
data = {}
|
||||
for key in obj.__dict__:
|
||||
value = getattr(obj, key)
|
||||
if key in ("created_at", "updated_at", "published_at", "fetched_at") and value is not None:
|
||||
data[key] = value.isoformat()
|
||||
else:
|
||||
data[key] = value
|
||||
return cls.model_construct(**data)
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Common Pydantic schemas."""
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Pagination query parameters."""
|
||||
|
||||
skip: int = Field(default=0, ge=0)
|
||||
limit: int = Field(default=50, ge=1, le=200)
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""Paginated response wrapper."""
|
||||
|
||||
total: int
|
||||
items: list
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Simple message response."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class BaseSchema(BaseModel):
|
||||
"""Base schema with ORM mode."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Feed Pydantic schemas."""
|
||||
from pydantic import BaseModel, ConfigDict, Field, HttpUrl
|
||||
|
||||
|
||||
class FeedBase(BaseModel):
|
||||
"""Base feed schema."""
|
||||
|
||||
url: HttpUrl
|
||||
title: str | None = Field(default="", max_length=512)
|
||||
description: str | None = ""
|
||||
category: str | None = Field(default="", max_length=128)
|
||||
is_active: bool = True
|
||||
fetch_interval_minutes: int = Field(default=60, ge=15)
|
||||
priority: int = Field(default=5, ge=1, le=10)
|
||||
parser_config: dict = {}
|
||||
proxy_policy: str = "auto"
|
||||
|
||||
|
||||
class FeedCreate(FeedBase):
|
||||
"""Feed creation schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FeedUpdate(BaseModel):
|
||||
"""Feed update schema."""
|
||||
|
||||
title: str | None = Field(default=None, max_length=512)
|
||||
description: str | None = None
|
||||
category: str | None = Field(default=None, max_length=128)
|
||||
is_active: bool | None = None
|
||||
fetch_interval_minutes: int | None = Field(default=None, ge=15)
|
||||
priority: int | None = Field(default=None, ge=1, le=10)
|
||||
parser_config: dict | None = None
|
||||
proxy_policy: str | None = None
|
||||
|
||||
|
||||
class FeedOut(FeedBase):
|
||||
"""Feed output schema."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
last_fetch_at: str | None = None
|
||||
last_fetch_status: str | None = None
|
||||
last_error: str | None = None
|
||||
error_type: str | None = None
|
||||
success_count: int = 0
|
||||
fail_count: int = 0
|
||||
article_count: int = 0
|
||||
health_status: str = "unknown"
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, obj):
|
||||
"""Override to compute health_status and format datetimes."""
|
||||
data = {}
|
||||
for key in obj.__dict__:
|
||||
value = getattr(obj, key)
|
||||
if key in ("created_at", "updated_at", "last_fetch_at") and value is not None:
|
||||
data[key] = value.isoformat()
|
||||
else:
|
||||
data[key] = value
|
||||
data["health_status"] = obj.health_status()
|
||||
return cls.model_construct(**data)
|
||||
@@ -0,0 +1,76 @@
|
||||
"""User Pydantic schemas."""
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
_PASSWORD_RE = re.compile(r"^(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d@$!%*?&_.-]{8,128}$")
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
"""Base user schema."""
|
||||
|
||||
username: str = Field(..., min_length=3, max_length=64)
|
||||
role: str = "member"
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def _validate_role(cls, value: str) -> str:
|
||||
allowed = {"admin", "member"}
|
||||
if value not in allowed:
|
||||
raise ValueError(f"role must be one of {allowed}")
|
||||
return value
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""User creation schema."""
|
||||
|
||||
password: str = Field(..., min_length=8, max_length=128)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def _validate_password_strength(cls, value: str) -> str:
|
||||
if not _PASSWORD_RE.match(value):
|
||||
raise ValueError(
|
||||
"password must be 8-128 characters and contain at least one letter and one number"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class UserOut(UserBase):
|
||||
"""User output schema."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""User login schema."""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Token response schema."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT token payload."""
|
||||
|
||||
sub: str | None = None
|
||||
role: str | None = None
|
||||
jti: str | None = None
|
||||
type: str | None = None
|
||||
exp: int | None = None
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Refresh token request schema."""
|
||||
|
||||
refresh_token: str
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Distributed lock service with Redis and DB fallback."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import AsyncSessionLocal
|
||||
from app.core.logging import get_logger
|
||||
from app.core.redis import get_redis
|
||||
from app.models.lock import Lock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LockService:
|
||||
"""Distributed lock service."""
|
||||
|
||||
def __init__(self, owner_id: str | None = None):
|
||||
self.owner_id = owner_id or str(uuid4())
|
||||
|
||||
async def acquire(self, lock_name: str, ttl: int = 60) -> bool:
|
||||
"""Acquire a lock with given TTL in seconds."""
|
||||
# Try Redis first
|
||||
try:
|
||||
redis = await get_redis()
|
||||
acquired = await redis.set(lock_name, self.owner_id, nx=True, ex=ttl)
|
||||
if acquired:
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Redis lock failed, falling back to DB: %s", exc)
|
||||
|
||||
# Fallback to DB
|
||||
return await self._acquire_db(lock_name, ttl)
|
||||
|
||||
async def release(self, lock_name: str) -> bool:
|
||||
"""Release a lock."""
|
||||
# Try Redis first
|
||||
try:
|
||||
redis = await get_redis()
|
||||
# Only release if we own it
|
||||
current_owner = await redis.get(lock_name)
|
||||
if current_owner == self.owner_id:
|
||||
await redis.delete(lock_name)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Redis unlock failed, falling back to DB: %s", exc)
|
||||
|
||||
return await self._release_db(lock_name)
|
||||
|
||||
async def extend(self, lock_name: str, ttl: int = 60) -> bool:
|
||||
"""Extend lock TTL."""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
current_owner = await redis.get(lock_name)
|
||||
if current_owner == self.owner_id:
|
||||
await redis.expire(lock_name, ttl)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Redis extend failed: %s", exc)
|
||||
|
||||
return await self._extend_db(lock_name, ttl)
|
||||
|
||||
async def is_locked(self, lock_name: str) -> bool:
|
||||
"""Check if a lock is held."""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
exists = await redis.exists(lock_name)
|
||||
if exists:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(select(Lock).where(Lock.lock_name == lock_name))
|
||||
lock = result.scalar_one_or_none()
|
||||
if not lock:
|
||||
return False
|
||||
if lock.expires_at and lock.expires_at < datetime.now(timezone.utc):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _acquire_db(self, lock_name: str, ttl: int) -> bool:
|
||||
async with AsyncSessionLocal() as db:
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=ttl)
|
||||
|
||||
# Try to update expired lock
|
||||
result = await db.execute(
|
||||
select(Lock).where(
|
||||
Lock.lock_name == lock_name,
|
||||
Lock.expires_at < now,
|
||||
)
|
||||
)
|
||||
lock = result.scalar_one_or_none()
|
||||
if lock:
|
||||
lock.owner_id = self.owner_id
|
||||
lock.acquired_at = now
|
||||
lock.expires_at = expires_at
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
# Try to insert new lock
|
||||
lock = Lock(
|
||||
lock_name=lock_name,
|
||||
owner_id=self.owner_id,
|
||||
acquired_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(lock)
|
||||
try:
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
return False
|
||||
|
||||
async def _release_db(self, lock_name: str) -> bool:
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(Lock).where(
|
||||
Lock.lock_name == lock_name,
|
||||
Lock.owner_id == self.owner_id,
|
||||
)
|
||||
)
|
||||
lock = result.scalar_one_or_none()
|
||||
if not lock:
|
||||
return False
|
||||
|
||||
await db.delete(lock)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
async def _extend_db(self, lock_name: str, ttl: int) -> bool:
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(Lock).where(
|
||||
Lock.lock_name == lock_name,
|
||||
Lock.owner_id == self.owner_id,
|
||||
)
|
||||
)
|
||||
lock = result.scalar_one_or_none()
|
||||
if not lock:
|
||||
return False
|
||||
|
||||
lock.expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def get_lock_service(owner_id: str | None = None) -> LockService:
|
||||
"""Get a lock service instance."""
|
||||
return LockService(owner_id=owner_id)
|
||||
@@ -0,0 +1,227 @@
|
||||
"""Application settings management service."""
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.models.setting import AppSetting
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
EDITABLE_SETTINGS = {
|
||||
"RSSKEEPER_BASE_URL": {"description": "rssKeeper 服务地址", "sensitive": False},
|
||||
"OPENAI_API_KEY": {"description": "LLM API Key", "sensitive": True},
|
||||
"OPENAI_BASE_URL": {"description": "LLM API 基础地址", "sensitive": False},
|
||||
"OPENAI_MODEL": {"description": "LLM 模型名", "sensitive": False},
|
||||
"OPENAI_TIMEOUT": {"description": "LLM 调用超时(秒)", "sensitive": False},
|
||||
"OPENAI_MAX_RETRIES": {"description": "LLM 最大重试次数", "sensitive": False},
|
||||
"SUMMARIZE_INTERVAL_MINUTES": {"description": "摘要任务间隔(分钟)", "sensitive": False},
|
||||
"TAG_SCORE_INTERVAL_MINUTES": {"description": "分类/打分/去重任务间隔(分钟)", "sensitive": False},
|
||||
"DAILY_BRIEF_HOUR": {"description": "每日简报生成小时", "sensitive": False},
|
||||
"DAILY_BRIEF_MINUTE": {"description": "每日简报生成分钟", "sensitive": False},
|
||||
"TITLE_SIMILARITY_THRESHOLD": {"description": "标题相似度阈值", "sensitive": False},
|
||||
"CONTENT_SIMILARITY_THRESHOLD": {"description": "内容相似度阈值", "sensitive": False},
|
||||
"MAX_AI_SUMMARY_LENGTH": {"description": "AI 摘要最大长度", "sensitive": False},
|
||||
"MIN_ORIGINAL_SUMMARY_LENGTH": {"description": "原始摘要最小长度", "sensitive": False},
|
||||
"BRIEF_TOP_N_PER_CATEGORY": {"description": "简报每分类显示文章数", "sensitive": False},
|
||||
"LOG_LEVEL": {"description": "日志级别", "sensitive": False},
|
||||
"API_TOKEN": {"description": "API 鉴权 Token(为空时不启用)", "sensitive": True},
|
||||
"CORS_ALLOWED_ORIGINS": {"description": "CORS 允许来源(逗号分隔)", "sensitive": False},
|
||||
}
|
||||
|
||||
# Prefix to detect encrypted values
|
||||
_ENC_PREFIX = "enc:"
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet | None:
|
||||
"""Get Fernet instance if encryption key is configured."""
|
||||
key = settings.SETTINGS_ENCRYPTION_KEY
|
||||
if not key:
|
||||
return None
|
||||
try:
|
||||
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||
except Exception as exc:
|
||||
logger.error("SETTINGS_ENCRYPTION_KEY 无效: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _encrypt(value: str) -> str:
|
||||
"""Encrypt a sensitive value if encryption is enabled."""
|
||||
if not value:
|
||||
return value
|
||||
fernet = _get_fernet()
|
||||
if fernet is None:
|
||||
return value
|
||||
return _ENC_PREFIX + fernet.encrypt(value.encode()).decode()
|
||||
|
||||
|
||||
def _decrypt(value: str) -> str:
|
||||
"""Decrypt a sensitive value if it was encrypted."""
|
||||
if not value or not value.startswith(_ENC_PREFIX):
|
||||
return value
|
||||
fernet = _get_fernet()
|
||||
if fernet is None:
|
||||
logger.warning("发现加密配置值但 SETTINGS_ENCRYPTION_KEY 未配置,无法解密")
|
||||
return value
|
||||
try:
|
||||
ciphertext = value[len(_ENC_PREFIX):].encode()
|
||||
return fernet.decrypt(ciphertext).decode()
|
||||
except InvalidToken:
|
||||
logger.warning("配置值解密失败(token 无效)")
|
||||
return value
|
||||
except Exception as exc:
|
||||
logger.error("配置值解密失败: %s", exc)
|
||||
return value
|
||||
|
||||
|
||||
def _get_env_default(key: str) -> str:
|
||||
"""Get default value from environment/settings."""
|
||||
value = getattr(settings, key, "")
|
||||
return str(value) if value is not None else ""
|
||||
|
||||
|
||||
def _mask_sensitive(value: str) -> str:
|
||||
"""Mask sensitive value for display."""
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) <= 8:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}...{value[-4:]}"
|
||||
|
||||
|
||||
async def init_default_settings(db: AsyncSession) -> None:
|
||||
"""Initialize default settings from environment if table is empty."""
|
||||
result = await db.execute(select(AppSetting))
|
||||
existing = result.scalars().first()
|
||||
if existing:
|
||||
return
|
||||
|
||||
for key, meta in EDITABLE_SETTINGS.items():
|
||||
default_value = _get_env_default(key)
|
||||
stored_value = _encrypt(default_value) if meta["sensitive"] else default_value
|
||||
db.add(
|
||||
AppSetting(
|
||||
key=key,
|
||||
value=stored_value,
|
||||
description=meta["description"],
|
||||
is_sensitive=meta["sensitive"],
|
||||
)
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
logger.info("已初始化默认配置项: %d 条", len(EDITABLE_SETTINGS))
|
||||
|
||||
|
||||
async def _get_raw_setting(db: AsyncSession, key: str) -> AppSetting | None:
|
||||
"""Get setting row from DB."""
|
||||
result = await db.execute(select(AppSetting).where(AppSetting.key == key))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_setting(db: AsyncSession, key: str, default: Any = None) -> Any:
|
||||
"""Get decrypted setting value from DB or env default."""
|
||||
setting = await _get_raw_setting(db, key)
|
||||
if setting:
|
||||
return _decrypt(setting.value) if setting.is_sensitive else setting.value
|
||||
return _get_env_default(key) if default is None else default
|
||||
|
||||
|
||||
async def set_setting(db: AsyncSession, key: str, value: str) -> bool:
|
||||
"""Update a setting (encrypt sensitive values)."""
|
||||
if key not in EDITABLE_SETTINGS:
|
||||
return False
|
||||
|
||||
meta = EDITABLE_SETTINGS[key]
|
||||
stored_value = _encrypt(str(value)) if meta["sensitive"] else str(value)
|
||||
|
||||
setting = await _get_raw_setting(db, key)
|
||||
if setting:
|
||||
setting.value = stored_value
|
||||
else:
|
||||
setting = AppSetting(
|
||||
key=key,
|
||||
value=stored_value,
|
||||
description=meta["description"],
|
||||
is_sensitive=meta["sensitive"],
|
||||
)
|
||||
db.add(setting)
|
||||
|
||||
await db.commit()
|
||||
logger.info("配置已更新: %s", key)
|
||||
return True
|
||||
|
||||
|
||||
async def list_settings(db: AsyncSession, mask_sensitive: bool = True) -> list[dict[str, Any]]:
|
||||
"""List all settings."""
|
||||
result = await db.execute(select(AppSetting))
|
||||
db_settings = {s.key: s for s in result.scalars().all()}
|
||||
|
||||
output = []
|
||||
for key, meta in EDITABLE_SETTINGS.items():
|
||||
setting = db_settings.get(key)
|
||||
is_sensitive = meta["sensitive"]
|
||||
|
||||
if setting:
|
||||
raw_value = setting.value
|
||||
updated_at = setting.updated_at.isoformat() if setting.updated_at else None
|
||||
else:
|
||||
raw_value = _get_env_default(key)
|
||||
updated_at = None
|
||||
|
||||
decrypted_value = _decrypt(raw_value) if is_sensitive else raw_value
|
||||
|
||||
if is_sensitive and mask_sensitive:
|
||||
display_value = _mask_sensitive(decrypted_value)
|
||||
is_masked = True
|
||||
else:
|
||||
display_value = decrypted_value
|
||||
is_masked = False
|
||||
|
||||
output.append({
|
||||
"key": key,
|
||||
"value": display_value,
|
||||
"real_value": decrypted_value if not mask_sensitive else None,
|
||||
"description": meta["description"],
|
||||
"is_sensitive": is_sensitive,
|
||||
"is_masked": is_masked,
|
||||
"updated_at": updated_at,
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
|
||||
async def apply_db_settings_to_config(db: AsyncSession) -> None:
|
||||
"""Apply DB settings to runtime config."""
|
||||
for key in EDITABLE_SETTINGS:
|
||||
db_value = await get_setting(db, key)
|
||||
if db_value is None or db_value == "":
|
||||
continue
|
||||
|
||||
field_info = settings.model_fields.get(key)
|
||||
if field_info is None:
|
||||
continue
|
||||
|
||||
target_type = field_info.annotation
|
||||
try:
|
||||
if target_type is int:
|
||||
converted = int(db_value)
|
||||
elif target_type is float:
|
||||
converted = float(db_value)
|
||||
elif target_type is bool:
|
||||
converted = db_value.lower() in ("true", "1", "yes")
|
||||
else:
|
||||
converted = db_value
|
||||
setattr(settings, key, converted)
|
||||
except Exception as exc:
|
||||
logger.error("应用配置 %s=%s 失败: %s", key, db_value, exc)
|
||||
raise ValueError(f"配置项 {key} 的值无效: {db_value}") from exc
|
||||
|
||||
|
||||
async def reset_settings(db: AsyncSession) -> None:
|
||||
"""Reset all settings to env defaults."""
|
||||
for key in EDITABLE_SETTINGS:
|
||||
await set_setting(db, key, _get_env_default(key))
|
||||
logger.info("配置已重置为环境变量默认值")
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Task runtime progress tracking service."""
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.redis import get_redis
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
TASK_STATUS_IDLE = "idle"
|
||||
TASK_STATUS_RUNNING = "running"
|
||||
TASK_STATUS_SUCCESS = "success"
|
||||
TASK_STATUS_ERROR = "error"
|
||||
|
||||
|
||||
class TaskRuntime:
|
||||
"""Runtime task progress tracker using Redis."""
|
||||
|
||||
def __init__(self):
|
||||
self._redis = None
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
self._redis = await get_redis()
|
||||
return self._redis
|
||||
|
||||
def _key(self, task_key: str) -> str:
|
||||
return f"task_progress:{task_key}"
|
||||
|
||||
async def update_progress(
|
||||
self,
|
||||
task_key: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
message: str | None = None,
|
||||
trigger: str | None = None,
|
||||
) -> None:
|
||||
"""Update task progress."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
key = self._key(task_key)
|
||||
|
||||
existing = await redis.hgetall(key)
|
||||
data = dict(existing) if existing else {}
|
||||
|
||||
if status:
|
||||
data["status"] = status
|
||||
if stage:
|
||||
data["stage"] = stage
|
||||
if current is not None:
|
||||
data["current"] = str(current)
|
||||
if total is not None:
|
||||
data["total"] = str(total)
|
||||
if message is not None:
|
||||
data["message"] = message
|
||||
if trigger:
|
||||
data["trigger"] = trigger
|
||||
|
||||
data["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
if status == TASK_STATUS_RUNNING and "started_at" not in data:
|
||||
data["started_at"] = data["updated_at"]
|
||||
if status in (TASK_STATUS_SUCCESS, TASK_STATUS_ERROR):
|
||||
data["finished_at"] = data["updated_at"]
|
||||
|
||||
await redis.hset(key, mapping=data)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to update task progress: %s", exc)
|
||||
|
||||
async def get_progress(self, task_key: str) -> dict[str, Any]:
|
||||
"""Get task progress."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
data = await redis.hgetall(self._key(task_key))
|
||||
if not data:
|
||||
return self._empty_progress(task_key)
|
||||
return {
|
||||
"task_key": task_key,
|
||||
"status": data.get("status", TASK_STATUS_IDLE),
|
||||
"stage": data.get("stage", ""),
|
||||
"current": int(data.get("current", 0)),
|
||||
"total": int(data.get("total", 0)),
|
||||
"message": data.get("message"),
|
||||
"trigger": data.get("trigger"),
|
||||
"started_at": data.get("started_at"),
|
||||
"updated_at": data.get("updated_at"),
|
||||
"finished_at": data.get("finished_at"),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get task progress: %s", exc)
|
||||
return self._empty_progress(task_key)
|
||||
|
||||
async def reset_progress(self, task_key: str) -> None:
|
||||
"""Reset task progress to idle."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
await redis.delete(self._key(task_key))
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to reset task progress: %s", exc)
|
||||
|
||||
def _empty_progress(self, task_key: str) -> dict[str, Any]:
|
||||
return {
|
||||
"task_key": task_key,
|
||||
"status": TASK_STATUS_IDLE,
|
||||
"stage": "",
|
||||
"current": 0,
|
||||
"total": 0,
|
||||
"message": None,
|
||||
"trigger": None,
|
||||
"started_at": None,
|
||||
"updated_at": None,
|
||||
"finished_at": None,
|
||||
}
|
||||
|
||||
|
||||
task_runtime = TaskRuntime()
|
||||
Reference in New Issue
Block a user