"""Distributed lock service with Redis and DB fallback."""

from datetime import datetime, timedelta, timezone
from uuid import uuid4

from sqlalchemy import delete, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import AsyncSessionLocal
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.models.lock import Lock

logger = get_logger(__name__)

# Compare-and-delete: remove the key only while it still holds our owner token.
# Running check + delete as one script makes them atomic inside Redis, so we
# never delete a key that expired and was re-acquired by another owner.
_RELEASE_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""

# Compare-and-expire: refresh the TTL only while we still own the key.
_EXTEND_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("expire", KEYS[1], ARGV[2])
end
return 0
"""


class LockService:
    """Distributed lock service.

    Locks are taken in Redis when it is reachable and in the ``Lock`` table
    when Redis raises. Each instance identifies itself with ``owner_id``;
    release/extend only act on locks whose stored owner equals it.

    NOTE(review): a lock acquired through the DB fallback is not mirrored in
    Redis (and vice versa), so once Redis is reachable again another owner can
    take the same name there. Treat the fallback as best-effort.
    """

    def __init__(self, owner_id: str | None = None):
        # Random owner token unless the caller supplies one.
        self.owner_id = owner_id or str(uuid4())

    async def acquire(self, lock_name: str, ttl: int = 60) -> bool:
        """Try to acquire ``lock_name`` for ``ttl`` seconds without blocking.

        Returns True if this owner now holds the lock, False otherwise. The DB
        is used only when the Redis call itself fails; a successful Redis reply
        saying the key is already taken is final, so we never grant a DB lock
        while another owner holds the Redis one.
        """
        try:
            redis = await get_redis()
            acquired = await redis.set(lock_name, self.owner_id, nx=True, ex=ttl)
        except Exception as exc:
            logger.warning("Redis lock failed, falling back to DB: %s", exc)
        else:
            # SET NX returns a falsy reply when the key already exists.
            return bool(acquired)
        return await self._acquire_db(lock_name, ttl)

    async def release(self, lock_name: str) -> bool:
        """Release ``lock_name`` if this owner holds it.

        Redis is tried first. If the key is not ours there (or absent), or
        Redis fails, the DB row is released instead, since the lock may have
        been acquired through the fallback. Returns True if a lock was released.
        """
        try:
            redis = await get_redis()
            released = await redis.eval(_RELEASE_SCRIPT, 1, lock_name, self.owner_id)
            if released:
                return True
        except Exception as exc:
            logger.warning("Redis unlock failed, falling back to DB: %s", exc)
        return await self._release_db(lock_name)

    async def extend(self, lock_name: str, ttl: int = 60) -> bool:
        """Reset the TTL of ``lock_name`` to ``ttl`` seconds if this owner holds it.

        Tries Redis first, then the DB row. Returns True if a lock was extended.
        """
        try:
            redis = await get_redis()
            extended = await redis.eval(
                _EXTEND_SCRIPT, 1, lock_name, self.owner_id, ttl
            )
            if extended:
                return True
        except Exception as exc:
            logger.warning("Redis extend failed: %s", exc)
        return await self._extend_db(lock_name, ttl)

    async def is_locked(self, lock_name: str) -> bool:
        """Return True if ``lock_name`` is held in Redis or by an unexpired DB row.

        A DB row whose ``expires_at`` is NULL is treated as held indefinitely.
        """
        try:
            redis = await get_redis()
            if await redis.exists(lock_name):
                return True
        except Exception as exc:
            logger.warning("Redis lock check failed, falling back to DB: %s", exc)

        now = datetime.now(timezone.utc)
        async with AsyncSessionLocal() as db:
            # Expiry is compared in SQL (as in _acquire_db) to avoid mixing
            # naive and aware datetimes in Python.
            result = await db.execute(
                select(Lock).where(
                    Lock.lock_name == lock_name,
                    or_(Lock.expires_at.is_(None), Lock.expires_at >= now),
                )
            )
            return result.scalar_one_or_none() is not None

    async def _acquire_db(self, lock_name: str, ttl: int) -> bool:
        """Acquire the lock through the ``Lock`` table.

        First takes over an expired row with a single conditional UPDATE, so
        two contenders cannot both claim the same expired lock. Otherwise it
        inserts a new row and reports failure if the commit fails. This
        presumably relies on a unique constraint on ``lock_name`` — confirm
        in the model.
        """
        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=ttl)
        async with AsyncSessionLocal() as db:
            taken_over = await db.execute(
                update(Lock)
                .where(Lock.lock_name == lock_name, Lock.expires_at < now)
                .values(
                    owner_id=self.owner_id,
                    acquired_at=now,
                    expires_at=expires_at,
                )
                .execution_options(synchronize_session=False)
            )
            if taken_over.rowcount:
                await db.commit()
                return True

            db.add(
                Lock(
                    lock_name=lock_name,
                    owner_id=self.owner_id,
                    acquired_at=now,
                    expires_at=expires_at,
                )
            )
            try:
                await db.commit()
            except Exception as exc:
                # Expected when the lock is already held; kept broad so any
                # DB failure reports "not acquired" as before.
                await db.rollback()
                logger.debug("DB lock %s not acquired: %s", lock_name, exc)
                return False
            return True

    async def _release_db(self, lock_name: str) -> bool:
        """Delete the DB row for ``lock_name`` if this owner holds it.

        The owner filter is part of the DELETE itself, so a row taken over by
        another owner is never removed. Returns True if a row was deleted.
        """
        async with AsyncSessionLocal() as db:
            result = await db.execute(
                delete(Lock)
                .where(Lock.lock_name == lock_name, Lock.owner_id == self.owner_id)
                .execution_options(synchronize_session=False)
            )
            await db.commit()
            return bool(result.rowcount)

    async def _extend_db(self, lock_name: str, ttl: int) -> bool:
        """Push ``expires_at`` to now + ``ttl`` seconds for this owner's DB row.

        Returns True if a row owned by this instance was updated.
        """
        new_expiry = datetime.now(timezone.utc) + timedelta(seconds=ttl)
        async with AsyncSessionLocal() as db:
            result = await db.execute(
                update(Lock)
                .where(Lock.lock_name == lock_name, Lock.owner_id == self.owner_id)
                .values(expires_at=new_expiry)
                .execution_options(synchronize_session=False)
            )
            await db.commit()
            return bool(result.rowcount)


async def get_lock_service(owner_id: str | None = None) -> LockService:
    """Get a lock service instance."""
    return LockService(owner_id=owner_id)