103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
|
|
"""Authentication and authorization utilities."""
|
||
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
from typing import Any
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
from jose import JWTError, jwt
|
||
|
|
from passlib.context import CryptContext
|
||
|
|
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.core.redis import get_redis
|
||
|
|
|
||
|
|
# JWT signing algorithm shared by token creation and decoding.
ALGORITHM = "HS256"

# Values written to (and checked against) the custom "type" claim.
TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH = "access", "refresh"

# Redis key namespace used to track revoked JWT jti values.
REVOKED_JTIS_KEY = "auth:revoked_jtis"

# Password hashing context (bcrypt) used by verify_password / get_password_hash.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
|
||
|
|
|
||
|
|
|
||
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored hash.

    Delegates to the module-level passlib ``pwd_context``.

    Returns:
        Whatever ``pwd_context.verify`` reports for the pair.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
||
|
|
|
||
|
|
|
||
|
|
def get_password_hash(password: str) -> str:
    """Produce a storable hash of ``password`` via the module ``pwd_context``."""
    digest = pwd_context.hash(password)
    return digest
|
||
|
|
|
||
|
|
|
||
|
|
def _create_token(data: dict[str, Any], expires_delta: timedelta, token_type: str) -> tuple[str, str]:
    """Encode a signed JWT with ``exp``/``iat``/``jti``/``type`` claims.

    Args:
        data: Caller-supplied claims (e.g. ``sub``). Shallow-copied and never
            mutated; the claims added here override same-named keys in it.
        expires_delta: Token lifetime, measured from the issue time.
        token_type: Value stored in the ``type`` claim.

    Returns:
        ``(token, jti)`` where ``jti`` is the random UUID4 string embedded in
        the token.
    """
    jti = str(uuid4())
    # Read the clock once so that exp - iat == expires_delta exactly; two
    # separate now() calls would put iat slightly after exp - expires_delta.
    issued_at = datetime.now(timezone.utc)
    claims: dict[str, Any] = {
        **data,
        "exp": issued_at + expires_delta,
        "iat": issued_at,
        "jti": jti,
        "type": token_type,
    }
    token = jwt.encode(claims, settings.SECRET_KEY, algorithm=ALGORITHM)
    return token, jti
|
||
|
|
|
||
|
|
|
||
|
|
def create_access_token(sub: str, role: str | None = None) -> tuple[str, str]:
    """Issue a short-lived access token for subject ``sub``.

    A ``role`` claim is embedded only when ``role`` is provided. The lifetime
    is ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        ``(token, jti)`` as produced by ``_create_token``.
    """
    claims: dict[str, Any] = {"sub": sub} if role is None else {"sub": sub, "role": role}
    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_token(claims, lifetime, TOKEN_TYPE_ACCESS)
|
||
|
|
|
||
|
|
|
||
|
|
def create_refresh_token(sub: str) -> tuple[str, str]:
    """Issue a long-lived refresh token for subject ``sub``.

    The lifetime is ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days; only the
    ``sub`` claim is supplied (no role).

    Returns:
        ``(token, jti)`` as produced by ``_create_token``.
    """
    claims: dict[str, Any] = {"sub": sub}
    lifetime = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_token(claims, lifetime, TOKEN_TYPE_REFRESH)
|
||
|
|
|
||
|
|
|
||
|
|
def decode_token(token: str, expected_type: str = TOKEN_TYPE_ACCESS) -> dict[str, Any]:
    """Verify ``token`` and return its claims.

    The token is decoded with ``settings.SECRET_KEY`` and ``ALGORITHM``. Its
    ``type`` claim must equal ``expected_type``, and both ``sub`` and ``jti``
    must be present.

    Raises:
        ValueError: if decoding fails (the ``JWTError`` is chained as the
            cause), the type does not match, or a required claim is missing.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise ValueError("Invalid token") from exc

    actual_type = claims.get("type")
    if actual_type != expected_type:
        raise ValueError(f"Invalid token type: expected {expected_type}, got {actual_type}")

    has_required_claims = all(name in claims for name in ("sub", "jti"))
    if not has_required_claims:
        raise ValueError("Invalid token payload")

    return claims
|
||
|
|
|
||
|
|
|
||
|
|
def _revoked_jti_key(jti: str) -> str:
    """Return the per-token Redis key that marks ``jti`` as revoked."""
    return f"{REVOKED_JTIS_KEY}:{jti}"


async def is_token_revoked(jti: str) -> bool:
    """Return True if the token identified by ``jti`` has been revoked.

    Revocations are stored as one Redis key per jti, each with its own TTL
    (see ``revoke_token``). Membership in the legacy shared set
    ``REVOKED_JTIS_KEY`` is still honoured, so tokens revoked under the old
    layout keep being rejected until that set expires.

    Without a Redis connection no revocation list is available and every token
    is reported as not revoked.
    """
    redis = await get_redis()
    if redis is None:
        # Without Redis we cannot reliably maintain a revocation list.
        return False
    if await redis.exists(_revoked_jti_key(jti)):
        return True
    # Legacy layout (single set with one shared TTL). This check can be
    # removed once that key has expired in every deployment.
    return bool(await redis.sismember(REVOKED_JTIS_KEY, jti))


async def revoke_token(jti: str, expires_at: datetime) -> None:
    """Revoke a token by its jti until the token's own expiry time.

    Each revocation is written as a separate key with its own TTL, via a single
    atomic ``SET ... EX``. Previously, all jtis shared one set whose TTL was
    reset on every revocation. Revoking a short-lived token therefore shortened
    the lifetime of earlier, longer-lived revocations (e.g. refresh tokens) and
    silently made those tokens valid again.

    Args:
        jti: The token's ``jti`` claim.
        expires_at: When the token expires. Only the time remaining until this
            instant is used as the revocation marker's TTL.

    Does nothing when Redis is unavailable or the token has already expired.
    """
    redis = await get_redis()
    if redis is None:
        return
    if expires_at.tzinfo is None:
        # NOTE(review): naive datetimes are assumed to be UTC (e.g. built from
        # utcfromtimestamp) — confirm with callers.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
    if remaining <= 0:
        # Token already expired; there is nothing left to revoke.
        return
    # Round up so the marker never disappears before the token itself expires.
    # Plain int() truncation would skip tokens with less than 1s of validity left.
    ttl_seconds = int(remaining) + 1
    await redis.set(_revoked_jti_key(jti), 1, ex=ttl_seconds)
|