144 lines
4.1 KiB
Python
144 lines
4.1 KiB
Python
|
|
"""Authentication router."""
|
||
|
|
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_admin, get_current_user, get_db
from app.core.auth import (
    create_access_token,
    create_refresh_token,
    decode_token,
    get_password_hash,
    revoke_token,
    verify_password,
)
from app.models.user import User
from app.schemas.user import (
    RefreshTokenRequest,
    TokenResponse,
    UserCreate,
    UserLogin,
    UserOut,
)
|
||
|
|
|
||
|
|
# Router for the authentication endpoints; every route below is mounted
# under the "/auth" prefix and tagged "auth".
router = APIRouter(tags=["auth"], prefix="/auth")
|
||
|
|
|
||
|
|
|
||
|
|
async def _username_taken(db: AsyncSession, username: str) -> bool:
    """Return True if a user row with *username* already exists."""
    result = await db.execute(select(User).where(User.username == username))
    return result.scalar_one_or_none() is not None


def _username_conflict_error() -> HTTPException:
    """Build the 409 response used whenever the username is already taken."""
    return HTTPException(
        status_code=status.HTTP_409_CONFLICT,
        detail="Username already exists",
    )


@router.post("/register", response_model=UserOut)
async def register(
    user_in: UserCreate,
    db: AsyncSession = Depends(get_db),
    _: User = Depends(get_current_admin),
):
    """Register a new user (admin only).

    The password is stored only as the hash returned by ``get_password_hash``.

    Returns:
        The persisted ``User``, refreshed from the database after commit.

    Raises:
        HTTPException: 409 if the username is already taken. This also covers
            a concurrent request inserting the same username between the
            pre-check and the commit, provided the database rejects the
            duplicate (presumably via a unique constraint on username —
            confirm in the model/migrations).
    """
    # Fast path: reject obvious duplicates before hashing the password.
    if await _username_taken(db, user_in.username):
        raise _username_conflict_error()

    user = User(
        username=user_in.username,
        password_hash=get_password_hash(user_in.password),
        role=user_in.role,
        is_active=user_in.is_active,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError as exc:
        # The pre-check is not atomic with the INSERT, so a concurrent
        # registration can win the race. Roll back the failed transaction and
        # translate the error into a 409 only if the username is now taken;
        # any other integrity violation propagates unchanged.
        await db.rollback()
        if await _username_taken(db, user_in.username):
            raise _username_conflict_error() from exc
        raise
    await db.refresh(user)
    return user
|
||
|
|
|
||
|
|
|
||
|
|
# Cached hash used to equalize timing when the username is unknown; computed
# on first use so importing this module does not pay for a password hash.
_DUMMY_PASSWORD_HASH = None


def _dummy_password_hash() -> str:
    """Return a throwaway hash to verify against when no user matches."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = get_password_hash("timing-equalization-dummy")
    return _DUMMY_PASSWORD_HASH


@router.post("/login", response_model=TokenResponse)
async def login(
    credentials: UserLogin,
    db: AsyncSession = Depends(get_db),
):
    """Login and get access/refresh tokens.

    On success, records ``last_login_at`` (UTC) and returns a bearer token
    pair whose subject is the user's id.

    Raises:
        HTTPException: 401 for an unknown username or a wrong password (same
            detail for both, so the response does not reveal which), or 403
            if the account is inactive.
    """
    result = await db.execute(select(User).where(User.username == credentials.username))
    user = result.scalar_one_or_none()

    # Always run a password verification, even for unknown usernames, so the
    # response time does not reveal whether the account exists.
    stored_hash = user.password_hash if user is not None else _dummy_password_hash()
    password_ok = verify_password(credentials.password, stored_hash)

    if user is None or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Inactive user",
        )

    # Capture what the tokens need *before* committing. If the session expires
    # attributes on commit (SQLAlchemy's default expire_on_commit=True), reading
    # user.id / user.role afterwards would require an implicit lazy load, which
    # an AsyncSession cannot perform.
    user_id = str(user.id)
    role = user.role

    user.last_login_at = datetime.now(timezone.utc)
    await db.commit()

    access_token, _ = create_access_token(sub=user_id, role=role)
    refresh_token, _ = create_refresh_token(sub=user_id)
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    req: RefreshTokenRequest,
    db: AsyncSession = Depends(get_db),
):
    """Trade a still-valid refresh token for a brand-new access/refresh pair.

    Raises:
        HTTPException: 401 when ``decode_token`` rejects the token as a
            refresh token (the ``ValueError`` text is echoed in the detail),
            or when its subject does not resolve to an active user.
    """
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    try:
        claims = decode_token(req.refresh_token, expected_type="refresh")
    except ValueError as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid refresh token: {err}",
            headers=bearer_challenge,
        ) from err

    # NOTE(review): tokens are issued with sub=str(user.id); confirm that
    # db.get handles a str key for the User primary-key type.
    account = await db.get(User, claims["sub"])
    if account is None or not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user",
            headers=bearer_challenge,
        )

    # NOTE(review): the presented refresh token is not passed to revoke_token
    # here (unlike /logout), so rotation does not invalidate it — confirm
    # this is intended.
    subject = str(account.id)
    fresh_access, _ = create_access_token(sub=subject, role=account.role)
    fresh_refresh, _ = create_refresh_token(sub=subject)

    response = {"access_token": fresh_access}
    response["refresh_token"] = fresh_refresh
    response["token_type"] = "bearer"
    return response
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
    req: RefreshTokenRequest,
):
    """Revoke the supplied refresh token.

    Best-effort: a token that ``decode_token`` rejects with ``ValueError`` is
    silently ignored, and the endpoint answers 204 either way. A revoked
    token is registered together with its own expiry timestamp.
    """
    try:
        claims = decode_token(req.refresh_token, expected_type="refresh")
    except ValueError:
        return None

    expiry = claims.get("exp")
    if not expiry:
        # NOTE(review): a token without an "exp" claim is never revoked —
        # confirm refresh tokens are always issued with one.
        return None

    expires_at = datetime.fromtimestamp(expiry, tz=timezone.utc)
    await revoke_token(claims["jti"], expires_at)
    return None
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the caller.

    The caller is resolved by the ``get_current_user`` dependency; this
    handler simply hands that user back to be serialized as ``UserOut``.
    """
    me = current_user
    return me
|