154 lines
5.0 KiB
Python
154 lines
5.0 KiB
Python
|
|
"""Distributed lock service with Redis and DB fallback."""
|
||
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
from sqlalchemy import delete, or_, select, update
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
|
||
|
|
from app.core.database import AsyncSessionLocal
|
||
|
|
from app.core.logging import get_logger
|
||
|
|
from app.core.redis import get_redis
|
||
|
|
from app.models.lock import Lock
|
||
|
|
|
||
|
|
# Module-level logger; LockService reports Redis failures through it before
# falling back to the database path.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class LockService:
|
||
|
|
"""Distributed lock service."""
|
||
|
|
|
||
|
|
def __init__(self, owner_id: str | None = None):
|
||
|
|
self.owner_id = owner_id or str(uuid4())
|
||
|
|
|
||
|
|
async def acquire(self, lock_name: str, ttl: int = 60) -> bool:
|
||
|
|
"""Acquire a lock with given TTL in seconds."""
|
||
|
|
# Try Redis first
|
||
|
|
try:
|
||
|
|
redis = await get_redis()
|
||
|
|
acquired = await redis.set(lock_name, self.owner_id, nx=True, ex=ttl)
|
||
|
|
if acquired:
|
||
|
|
return True
|
||
|
|
except Exception as exc:
|
||
|
|
logger.warning("Redis lock failed, falling back to DB: %s", exc)
|
||
|
|
|
||
|
|
# Fallback to DB
|
||
|
|
return await self._acquire_db(lock_name, ttl)
|
||
|
|
|
||
|
|
async def release(self, lock_name: str) -> bool:
|
||
|
|
"""Release a lock."""
|
||
|
|
# Try Redis first
|
||
|
|
try:
|
||
|
|
redis = await get_redis()
|
||
|
|
# Only release if we own it
|
||
|
|
current_owner = await redis.get(lock_name)
|
||
|
|
if current_owner == self.owner_id:
|
||
|
|
await redis.delete(lock_name)
|
||
|
|
return True
|
||
|
|
except Exception as exc:
|
||
|
|
logger.warning("Redis unlock failed, falling back to DB: %s", exc)
|
||
|
|
|
||
|
|
return await self._release_db(lock_name)
|
||
|
|
|
||
|
|
async def extend(self, lock_name: str, ttl: int = 60) -> bool:
|
||
|
|
"""Extend lock TTL."""
|
||
|
|
try:
|
||
|
|
redis = await get_redis()
|
||
|
|
current_owner = await redis.get(lock_name)
|
||
|
|
if current_owner == self.owner_id:
|
||
|
|
await redis.expire(lock_name, ttl)
|
||
|
|
return True
|
||
|
|
except Exception as exc:
|
||
|
|
logger.warning("Redis extend failed: %s", exc)
|
||
|
|
|
||
|
|
return await self._extend_db(lock_name, ttl)
|
||
|
|
|
||
|
|
async def is_locked(self, lock_name: str) -> bool:
|
||
|
|
"""Check if a lock is held."""
|
||
|
|
try:
|
||
|
|
redis = await get_redis()
|
||
|
|
exists = await redis.exists(lock_name)
|
||
|
|
if exists:
|
||
|
|
return True
|
||
|
|
except Exception:
|
||
|
|
pass
|
||
|
|
|
||
|
|
async with AsyncSessionLocal() as db:
|
||
|
|
result = await db.execute(select(Lock).where(Lock.lock_name == lock_name))
|
||
|
|
lock = result.scalar_one_or_none()
|
||
|
|
if not lock:
|
||
|
|
return False
|
||
|
|
if lock.expires_at and lock.expires_at < datetime.now(timezone.utc):
|
||
|
|
return False
|
||
|
|
return True
|
||
|
|
|
||
|
|
async def _acquire_db(self, lock_name: str, ttl: int) -> bool:
|
||
|
|
async with AsyncSessionLocal() as db:
|
||
|
|
now = datetime.now(timezone.utc)
|
||
|
|
expires_at = now + timedelta(seconds=ttl)
|
||
|
|
|
||
|
|
# Try to update expired lock
|
||
|
|
result = await db.execute(
|
||
|
|
select(Lock).where(
|
||
|
|
Lock.lock_name == lock_name,
|
||
|
|
Lock.expires_at < now,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
lock = result.scalar_one_or_none()
|
||
|
|
if lock:
|
||
|
|
lock.owner_id = self.owner_id
|
||
|
|
lock.acquired_at = now
|
||
|
|
lock.expires_at = expires_at
|
||
|
|
await db.commit()
|
||
|
|
return True
|
||
|
|
|
||
|
|
# Try to insert new lock
|
||
|
|
lock = Lock(
|
||
|
|
lock_name=lock_name,
|
||
|
|
owner_id=self.owner_id,
|
||
|
|
acquired_at=now,
|
||
|
|
expires_at=expires_at,
|
||
|
|
)
|
||
|
|
db.add(lock)
|
||
|
|
try:
|
||
|
|
await db.commit()
|
||
|
|
return True
|
||
|
|
except Exception:
|
||
|
|
await db.rollback()
|
||
|
|
return False
|
||
|
|
|
||
|
|
async def _release_db(self, lock_name: str) -> bool:
|
||
|
|
async with AsyncSessionLocal() as db:
|
||
|
|
result = await db.execute(
|
||
|
|
select(Lock).where(
|
||
|
|
Lock.lock_name == lock_name,
|
||
|
|
Lock.owner_id == self.owner_id,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
lock = result.scalar_one_or_none()
|
||
|
|
if not lock:
|
||
|
|
return False
|
||
|
|
|
||
|
|
await db.delete(lock)
|
||
|
|
await db.commit()
|
||
|
|
return True
|
||
|
|
|
||
|
|
async def _extend_db(self, lock_name: str, ttl: int) -> bool:
|
||
|
|
async with AsyncSessionLocal() as db:
|
||
|
|
result = await db.execute(
|
||
|
|
select(Lock).where(
|
||
|
|
Lock.lock_name == lock_name,
|
||
|
|
Lock.owner_id == self.owner_id,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
lock = result.scalar_one_or_none()
|
||
|
|
if not lock:
|
||
|
|
return False
|
||
|
|
|
||
|
|
lock.expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
||
|
|
await db.commit()
|
||
|
|
return True
|
||
|
|
|
||
|
|
|
||
|
|
async def get_lock_service(owner_id: str | None = None) -> LockService:
    """Build and return a new LockService.

    Args:
        owner_id: Identity the service stamps on the locks it takes; forwarded
            unchanged to the LockService constructor.

    Returns:
        A freshly constructed LockService; no instance is cached or shared
        between calls.
    """
    service = LockService(owner_id=owner_id)
    return service
|