"""缓存模块 提供内存缓存和 TTL 管理 """ from __future__ import annotations import asyncio import hashlib import time from dataclasses import dataclass, field from typing import Any, Generic, TypeVar from minenasai.core.logging import get_logger logger = get_logger(__name__) T = TypeVar("T") @dataclass class CacheEntry(Generic[T]): """缓存条目""" key: str value: T created_at: float expires_at: float hits: int = 0 @property def is_expired(self) -> bool: """是否过期""" return time.time() > self.expires_at @property def ttl_remaining(self) -> float: """剩余 TTL(秒)""" return max(0, self.expires_at - time.time()) class MemoryCache(Generic[T]): """内存缓存 支持 TTL、最大容量、LRU 淘汰 """ def __init__( self, max_size: int = 1000, default_ttl: float = 300.0, # 5分钟 cleanup_interval: float = 60.0, # 1分钟清理一次 ) -> None: self._cache: dict[str, CacheEntry[T]] = {} self._max_size = max_size self._default_ttl = default_ttl self._cleanup_interval = cleanup_interval self._cleanup_task: asyncio.Task[None] | None = None # 统计 self._hits = 0 self._misses = 0 async def start(self) -> None: """启动后台清理任务""" if self._cleanup_task is None: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) logger.info("cache_cleanup_started", interval=self._cleanup_interval) async def stop(self) -> None: """停止后台清理任务""" if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass self._cleanup_task = None async def _cleanup_loop(self) -> None: """清理循环""" while True: await asyncio.sleep(self._cleanup_interval) self._cleanup_expired() def _cleanup_expired(self) -> int: """清理过期条目""" expired_keys = [k for k, v in self._cache.items() if v.is_expired] for key in expired_keys: del self._cache[key] if expired_keys: logger.debug("cache_cleanup", removed=len(expired_keys)) return len(expired_keys) def _evict_lru(self) -> None: """LRU 淘汰""" if len(self._cache) < self._max_size: return # 按命中次数和创建时间排序,淘汰最少使用的 sorted_entries = sorted( self._cache.items(), key=lambda x: (x[1].hits, 
x[1].created_at), ) # 淘汰 10% 的条目 evict_count = max(1, len(sorted_entries) // 10) for key, _ in sorted_entries[:evict_count]: del self._cache[key] logger.debug("cache_eviction", evicted=evict_count) def get(self, key: str) -> T | None: """获取缓存值""" entry = self._cache.get(key) if entry is None: self._misses += 1 return None if entry.is_expired: del self._cache[key] self._misses += 1 return None entry.hits += 1 self._hits += 1 return entry.value def set(self, key: str, value: T, ttl: float | None = None) -> None: """设置缓存值""" if ttl is None: ttl = self._default_ttl # 检查容量 if len(self._cache) >= self._max_size: self._evict_lru() now = time.time() self._cache[key] = CacheEntry( key=key, value=value, created_at=now, expires_at=now + ttl, ) def delete(self, key: str) -> bool: """删除缓存""" if key in self._cache: del self._cache[key] return True return False def clear(self) -> int: """清空缓存""" count = len(self._cache) self._cache.clear() return count def exists(self, key: str) -> bool: """检查 key 是否存在且未过期""" entry = self._cache.get(key) if entry is None: return False if entry.is_expired: del self._cache[key] return False return True def get_stats(self) -> dict[str, Any]: """获取统计信息""" total = self._hits + self._misses hit_rate = self._hits / total if total > 0 else 0.0 return { "size": len(self._cache), "max_size": self._max_size, "hits": self._hits, "misses": self._misses, "hit_rate": round(hit_rate * 100, 2), "default_ttl": self._default_ttl, } def make_cache_key(*args: Any, **kwargs: Any) -> str: """生成缓存 key""" key_parts = [str(arg) for arg in args] key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items())) key_string = ":".join(key_parts) return hashlib.md5(key_string.encode()).hexdigest() # 全局缓存实例 _response_cache: MemoryCache[dict[str, Any]] | None = None def get_response_cache() -> MemoryCache[dict[str, Any]]: """获取响应缓存""" global _response_cache if _response_cache is None: _response_cache = MemoryCache( max_size=500, default_ttl=300.0, # 5分钟 ) return _response_cache 
@dataclass
class RateLimiter:
    """Token-bucket rate limiter.

    The bucket starts full with ``burst`` tokens and refills continuously at
    ``rate`` tokens per second, never exceeding ``burst``. Elapsed time is
    measured with ``time.monotonic()``, so wall-clock adjustments (NTP steps,
    manual clock changes) cannot drain or instantly refill the bucket.
    """

    rate: float  # tokens added per second (sustained requests per second)
    burst: int  # bucket capacity (largest burst allowed)
    _tokens: float = field(init=False)
    _last_update: float = field(init=False)

    def __post_init__(self) -> None:
        self._tokens = float(self.burst)
        self._last_update = time.monotonic()

    def _refill(self) -> None:
        """Add the tokens accrued since the last refill, capped at ``burst``."""
        now = time.monotonic()
        elapsed = now - self._last_update
        self._tokens = min(float(self.burst), self._tokens + elapsed * self.rate)
        self._last_update = now

    def acquire(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` without waiting.

        Returns:
            True if enough tokens were available (they are consumed);
            False otherwise (nothing is consumed).
        """
        self._refill()
        if self._tokens >= tokens:
            self._tokens -= tokens
            return True
        return False

    async def wait(self, tokens: int = 1) -> None:
        """Asynchronously wait until ``tokens`` can be acquired, then take them.

        Polls in sleeps of at most 0.1 seconds.

        Raises:
            ValueError: If ``tokens`` exceeds ``burst``. The bucket can never
                hold that many tokens, so waiting would never end.
        """
        if tokens > self.burst:
            raise ValueError(
                f"cannot wait for {tokens} tokens: exceeds burst capacity {self.burst}"
            )
        while not self.acquire(tokens):
            # Time until enough tokens will have accrued at the current rate.
            needed = tokens - self._tokens
            wait_time = needed / self.rate
            await asyncio.sleep(min(wait_time, 0.1))

    @property
    def available_tokens(self) -> float:
        """Tokens currently available, after applying any pending refill."""
        self._refill()
        return self._tokens


# Named limiters shared across the process.
_rate_limiters: dict[str, RateLimiter] = {}


def get_rate_limiter(name: str, rate: float = 10.0, burst: int = 20) -> RateLimiter:
    """Return the limiter registered under ``name``, creating it if needed.

    ``rate`` and ``burst`` are only used when the limiter is first created.
    Later calls with the same name return the existing instance unchanged.
    """
    if name not in _rate_limiters:
        _rate_limiters[name] = RateLimiter(rate=rate, burst=burst)
    return _rate_limiters[name]