Files

154 lines
5.0 KiB
Python
Raw Permalink Normal View History

"""Distributed lock service with Redis and DB fallback."""
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import AsyncSessionLocal
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.models.lock import Lock
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger(__name__)
class LockService:
"""Distributed lock service."""
def __init__(self, owner_id: str | None = None):
self.owner_id = owner_id or str(uuid4())
async def acquire(self, lock_name: str, ttl: int = 60) -> bool:
"""Acquire a lock with given TTL in seconds."""
# Try Redis first
try:
redis = await get_redis()
acquired = await redis.set(lock_name, self.owner_id, nx=True, ex=ttl)
if acquired:
return True
except Exception as exc:
logger.warning("Redis lock failed, falling back to DB: %s", exc)
# Fallback to DB
return await self._acquire_db(lock_name, ttl)
async def release(self, lock_name: str) -> bool:
"""Release a lock."""
# Try Redis first
try:
redis = await get_redis()
# Only release if we own it
current_owner = await redis.get(lock_name)
if current_owner == self.owner_id:
await redis.delete(lock_name)
return True
except Exception as exc:
logger.warning("Redis unlock failed, falling back to DB: %s", exc)
return await self._release_db(lock_name)
async def extend(self, lock_name: str, ttl: int = 60) -> bool:
"""Extend lock TTL."""
try:
redis = await get_redis()
current_owner = await redis.get(lock_name)
if current_owner == self.owner_id:
await redis.expire(lock_name, ttl)
return True
except Exception as exc:
logger.warning("Redis extend failed: %s", exc)
return await self._extend_db(lock_name, ttl)
async def is_locked(self, lock_name: str) -> bool:
"""Check if a lock is held."""
try:
redis = await get_redis()
exists = await redis.exists(lock_name)
if exists:
return True
except Exception:
pass
async with AsyncSessionLocal() as db:
result = await db.execute(select(Lock).where(Lock.lock_name == lock_name))
lock = result.scalar_one_or_none()
if not lock:
return False
if lock.expires_at and lock.expires_at < datetime.now(timezone.utc):
return False
return True
async def _acquire_db(self, lock_name: str, ttl: int) -> bool:
async with AsyncSessionLocal() as db:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=ttl)
# Try to update expired lock
result = await db.execute(
select(Lock).where(
Lock.lock_name == lock_name,
Lock.expires_at < now,
)
)
lock = result.scalar_one_or_none()
if lock:
lock.owner_id = self.owner_id
lock.acquired_at = now
lock.expires_at = expires_at
await db.commit()
return True
# Try to insert new lock
lock = Lock(
lock_name=lock_name,
owner_id=self.owner_id,
acquired_at=now,
expires_at=expires_at,
)
db.add(lock)
try:
await db.commit()
return True
except Exception:
await db.rollback()
return False
async def _release_db(self, lock_name: str) -> bool:
async with AsyncSessionLocal() as db:
result = await db.execute(
select(Lock).where(
Lock.lock_name == lock_name,
Lock.owner_id == self.owner_id,
)
)
lock = result.scalar_one_or_none()
if not lock:
return False
await db.delete(lock)
await db.commit()
return True
async def _extend_db(self, lock_name: str, ttl: int) -> bool:
async with AsyncSessionLocal() as db:
result = await db.execute(
select(Lock).where(
Lock.lock_name == lock_name,
Lock.owner_id == self.owner_id,
)
)
lock = result.scalar_one_or_none()
if not lock:
return False
lock.expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
await db.commit()
return True
async def get_lock_service(owner_id: str | None = None) -> LockService:
    """Build and return a new ``LockService``.

    Args:
        owner_id: Identifier passed through to ``LockService``; may be
            ``None``.

    Returns:
        A freshly constructed ``LockService``. Nothing is awaited here.
    """
    service = LockService(owner_id=owner_id)
    return service