Initial commit: RSS platform phase 1 skeleton with code review fixes
Features: - FastAPI + SQLAlchemy 2.0 async + PostgreSQL/pgvector + Redis backend - Vue 3 + TypeScript + Element Plus frontend - JWT auth with access/refresh tokens and revocation - Admin/member RBAC - RSS feed CRUD and article listing - Settings management with Fernet encryption for sensitive values - Redis distributed lock service - Alembic initial migration - Docker Compose development environment Fixes from code review: - Fix DB session leak in dependency injection - Restrict registration to admin only - Add default admin password warning - Implement JWT refresh tokens and jti blacklist - Strengthen password policy - Use func.count for pagination totals - Replace NullPool with AsyncAdaptedQueuePool - Remove init_db from lifespan to enforce alembic migrations - Add request_id middleware and logging filter - Fix vite.config.ts env loading - Add frontend token refresh interceptor - Add Vue error handler Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
"""Authentication and authorization utilities."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.redis import get_redis
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
TOKEN_TYPE_ACCESS = "access"
|
||||
TOKEN_TYPE_REFRESH = "refresh"
|
||||
|
||||
# Redis key for revoked JWT jti set
|
||||
REVOKED_JTIS_KEY = "auth:revoked_jtis"
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plain password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def _create_token(data: dict[str, Any], expires_delta: timedelta, token_type: str) -> tuple[str, str]:
|
||||
"""Create a JWT with jti/type claims. Returns (token, jti)."""
|
||||
jti = str(uuid4())
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"jti": jti,
|
||||
"type": token_type,
|
||||
})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt, jti
|
||||
|
||||
|
||||
def create_access_token(sub: str, role: str | None = None) -> tuple[str, str]:
|
||||
"""Create a short-lived JWT access token. Returns (token, jti)."""
|
||||
data: dict[str, Any] = {"sub": sub}
|
||||
if role is not None:
|
||||
data["role"] = role
|
||||
return _create_token(
|
||||
data,
|
||||
timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
TOKEN_TYPE_ACCESS,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(sub: str) -> tuple[str, str]:
|
||||
"""Create a long-lived JWT refresh token. Returns (token, jti)."""
|
||||
return _create_token(
|
||||
{"sub": sub},
|
||||
timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
TOKEN_TYPE_REFRESH,
|
||||
)
|
||||
|
||||
|
||||
def decode_token(token: str, expected_type: str = TOKEN_TYPE_ACCESS) -> dict[str, Any]:
|
||||
"""Decode and validate a JWT token, checking type claim."""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError as exc:
|
||||
raise ValueError("Invalid token") from exc
|
||||
|
||||
token_type = payload.get("type")
|
||||
if token_type != expected_type:
|
||||
raise ValueError(f"Invalid token type: expected {expected_type}, got {token_type}")
|
||||
|
||||
if "sub" not in payload or "jti" not in payload:
|
||||
raise ValueError("Invalid token payload")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def is_token_revoked(jti: str) -> bool:
|
||||
"""Check whether a token jti has been revoked."""
|
||||
redis = await get_redis()
|
||||
if redis is None:
|
||||
# Without Redis we cannot reliably maintain a revocation list.
|
||||
return False
|
||||
return await redis.sismember(REVOKED_JTIS_KEY, jti)
|
||||
|
||||
|
||||
async def revoke_token(jti: str, expires_at: datetime) -> None:
|
||||
"""Revoke a token by its jti with TTL matching token expiry."""
|
||||
redis = await get_redis()
|
||||
if redis is None:
|
||||
return
|
||||
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||
if ttl_seconds > 0:
|
||||
await redis.sadd(REVOKED_JTIS_KEY, jti)
|
||||
await redis.expire(REVOKED_JTIS_KEY, ttl_seconds)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Application configuration."""
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Database
|
||||
DATABASE_URL: str = "postgresql+asyncpg://rss:rss@postgres:5432/rss_platform"
|
||||
|
||||
# Redis
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
|
||||
# JWT
|
||||
SECRET_KEY: str = Field(..., min_length=32)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# AI
|
||||
AI_DEFAULT_PROVIDER: str = "openai"
|
||||
AI_DEFAULT_MODEL: str = "gpt-4o-mini"
|
||||
|
||||
# Storage
|
||||
STORAGE_TYPE: str = "minio"
|
||||
MINIO_ENDPOINT: str = "minio:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "rss-platform"
|
||||
|
||||
# CORS
|
||||
CORS_ALLOWED_ORIGINS: str = ""
|
||||
|
||||
# Default admin
|
||||
DEFAULT_ADMIN_USERNAME: str = "admin"
|
||||
DEFAULT_ADMIN_PASSWORD: str = "admin"
|
||||
|
||||
# Sensitive settings encryption
|
||||
SETTINGS_ENCRYPTION_KEY: str = ""
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL: str = "INFO"
|
||||
|
||||
# RSS Fetching
|
||||
FETCH_CONCURRENCY: int = 10
|
||||
FETCH_TIMEOUT: int = 30
|
||||
DEFAULT_FETCH_INTERVAL: int = 60
|
||||
MIN_FETCH_INTERVAL: int = 15
|
||||
|
||||
# Ports (for reference)
|
||||
BACKEND_PORT: int = 8000
|
||||
FRONTEND_PORT: int = 5173
|
||||
|
||||
@property
|
||||
def cors_origins(self) -> list[str]:
|
||||
"""Parse CORS_ALLOWED_ORIGINS into list."""
|
||||
if not self.CORS_ALLOWED_ORIGINS:
|
||||
return []
|
||||
return [origin.strip() for origin in self.CORS_ALLOWED_ORIGINS.split(",") if origin.strip()]
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Database configuration and session management."""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.LOG_LEVEL == "DEBUG",
|
||||
future=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""Dependency for FastAPI to get async DB session."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close database connections."""
|
||||
await engine.dispose()
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Custom exceptions and error handlers."""
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlatformException(Exception):
|
||||
"""Base exception for the platform."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 400):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class AuthenticationError(PlatformException):
|
||||
"""Authentication failed."""
|
||||
|
||||
def __init__(self, message: str = "Authentication failed"):
|
||||
super().__init__(message, status_code=401)
|
||||
|
||||
|
||||
class AuthorizationError(PlatformException):
|
||||
"""Authorization failed."""
|
||||
|
||||
def __init__(self, message: str = "Forbidden"):
|
||||
super().__init__(message, status_code=403)
|
||||
|
||||
|
||||
class NotFoundError(PlatformException):
|
||||
"""Resource not found."""
|
||||
|
||||
def __init__(self, message: str = "Resource not found"):
|
||||
super().__init__(message, status_code=404)
|
||||
|
||||
|
||||
class ConflictError(PlatformException):
|
||||
"""Resource conflict."""
|
||||
|
||||
def __init__(self, message: str = "Conflict"):
|
||||
super().__init__(message, status_code=409)
|
||||
|
||||
|
||||
def add_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers."""
|
||||
|
||||
@app.exception_handler(PlatformException)
|
||||
async def platform_exception_handler(request: Request, exc: PlatformException):
|
||||
logger.warning("Platform exception: %s", exc.message)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": exc.message},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception("Unhandled exception: %s", exc)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"},
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Logging configuration."""
|
||||
import logging
|
||||
import sys
|
||||
from contextvars import ContextVar
|
||||
|
||||
request_id_var: ContextVar[str] = ContextVar("request_id", default="")
|
||||
|
||||
|
||||
def configure_logging(log_level: str = "INFO") -> None:
|
||||
"""Configure structured logging."""
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - [%(request_id)s] %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(RequestIdFilter())
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))
|
||||
root_logger.handlers = []
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RequestIdFilter(logging.Filter):
|
||||
"""Inject request_id into log records."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.request_id = request_id_var.get() # type: ignore[attr-defined]
|
||||
return True
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger with request_id filter."""
|
||||
logger = logging.getLogger(name)
|
||||
logger.addFilter(RequestIdFilter())
|
||||
return logger
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Role-based access control."""
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""User roles."""
|
||||
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
|
||||
|
||||
def require_admin(current_user: User) -> User:
|
||||
"""Dependency that requires admin role."""
|
||||
if current_user.role != Role.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def has_permission(user: User, required_role: Role) -> bool:
|
||||
"""Check if user has required role."""
|
||||
if user.role == Role.ADMIN:
|
||||
return True
|
||||
return user.role == required_role
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Redis connection management."""
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
_redis: Redis | None = None
|
||||
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
"""Get or create Redis connection."""
|
||||
global _redis
|
||||
if _redis is None:
|
||||
_redis = Redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return _redis
|
||||
|
||||
|
||||
async def close_redis() -> None:
|
||||
"""Close Redis connection."""
|
||||
global _redis
|
||||
if _redis:
|
||||
await _redis.close()
|
||||
_redis = None
|
||||
|
||||
|
||||
async def check_redis_health() -> bool:
|
||||
"""Check if Redis is reachable."""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
await redis.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
Reference in New Issue
Block a user