"""Authentication and authorization utilities."""

import math
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import uuid4

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.core.config import settings
from app.core.redis import get_redis

# Password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

ALGORITHM = "HS256"
TOKEN_TYPE_ACCESS = "access"
TOKEN_TYPE_REFRESH = "refresh"

# Legacy Redis key: a single SET of revoked jtis whose one TTL was shared by
# every member (so revoking a short-lived token could shorten the revocation
# of a long-lived one). No longer written; still read so that revocations
# recorded by earlier versions are honoured until that key expires.
REVOKED_JTIS_KEY = "auth:revoked_jtis"
# Prefix for per-token revocation keys; each key carries its own TTL.
REVOKED_JTI_KEY_PREFIX = "auth:revoked_jti:"


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Return a hash of *password* produced by the module's CryptContext."""
    return pwd_context.hash(password)


def _create_token(
    data: dict[str, Any], expires_delta: timedelta, token_type: str
) -> tuple[str, str]:
    """Encode *data* as a signed JWT with ``exp``/``iat``/``jti``/``type`` claims.

    The reserved claims override any same-named keys in *data*; *data* itself
    is not mutated.

    Returns:
        ``(token, jti)`` where ``jti`` is a freshly generated UUID4 string.
    """
    jti = str(uuid4())
    # Take a single timestamp so "iat" and "exp" are exactly expires_delta apart.
    now = datetime.now(timezone.utc)
    to_encode = {
        **data,
        "exp": now + expires_delta,
        "iat": now,
        "jti": jti,
        "type": token_type,
    }
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt, jti


def create_access_token(sub: str, role: str | None = None) -> tuple[str, str]:
    """Create a short-lived JWT access token.

    Args:
        sub: Subject claim.
        role: Optional ``role`` claim; omitted from the token when None.

    Returns:
        ``(token, jti)``. Lifetime is ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.
    """
    data: dict[str, Any] = {"sub": sub}
    if role is not None:
        data["role"] = role
    return _create_token(
        data,
        timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
        TOKEN_TYPE_ACCESS,
    )


def create_refresh_token(sub: str) -> tuple[str, str]:
    """Create a long-lived JWT refresh token.

    Returns:
        ``(token, jti)``. Lifetime is ``settings.REFRESH_TOKEN_EXPIRE_DAYS``.
    """
    return _create_token(
        {"sub": sub},
        timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
        TOKEN_TYPE_REFRESH,
    )


def decode_token(token: str, expected_type: str = TOKEN_TYPE_ACCESS) -> dict[str, Any]:
    """Decode and validate a JWT, checking signature, expiry and ``type`` claim.

    Args:
        token: The encoded JWT.
        expected_type: Required value of the ``type`` claim.

    Returns:
        The decoded claims.

    Raises:
        ValueError: If decoding/verification fails (bad signature, expired,
            missing ``exp``), the ``type`` claim differs from *expected_type*,
            or ``sub``/``jti`` is missing.
    """
    try:
        # Every token issued by this module carries "exp"; reject any that
        # doesn't instead of treating it as never-expiring.
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[ALGORITHM],
            options={"require_exp": True},
        )
    except JWTError as exc:
        raise ValueError("Invalid token") from exc

    token_type = payload.get("type")
    if token_type != expected_type:
        raise ValueError(f"Invalid token type: expected {expected_type}, got {token_type}")

    if "sub" not in payload or "jti" not in payload:
        raise ValueError("Invalid token payload")

    return payload


def _revoked_jti_key(jti: str) -> str:
    """Return the per-token Redis key that marks *jti* as revoked."""
    return f"{REVOKED_JTI_KEY_PREFIX}{jti}"


async def is_token_revoked(jti: str) -> bool:
    """Return True if the token identified by *jti* has been revoked.

    Fails open: when ``get_redis()`` returns None no revocation list can be
    consulted, so every token is reported as not revoked.
    """
    redis = await get_redis()
    if redis is None:
        # Without Redis we cannot reliably maintain a revocation list.
        return False
    if await redis.exists(_revoked_jti_key(jti)):
        return True
    # Fall back to the legacy shared set written by earlier versions; this
    # check can be dropped once that key has expired everywhere.
    return bool(await redis.sismember(REVOKED_JTIS_KEY, jti))


async def revoke_token(jti: str, expires_at: datetime) -> None:
    """Revoke a token by its jti until the token's own expiry.

    Each revoked jti gets its own key whose TTL equals the token's remaining
    lifetime, so revoking one token never shortens or extends the revocation
    of another. Does nothing when Redis is unavailable or the token has
    already expired.

    Args:
        jti: The token's ``jti`` claim.
        expires_at: The token's expiry time. Naive datetimes are treated as
            UTC, matching the timestamps this module puts in tokens.
    """
    redis = await get_redis()
    if redis is None:
        return
    if expires_at.tzinfo is None:
        # Avoid TypeError from subtracting a naive datetime from an aware one.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    # Round up so a token with a fraction of a second left is still revoked.
    ttl_seconds = math.ceil((expires_at - datetime.now(timezone.utc)).total_seconds())
    if ttl_seconds > 0:
        await redis.set(_revoked_jti_key(jti), "1", ex=ttl_seconds)