Files
MineNasAI/src/minenasai/core/cache.py

267 lines
7.0 KiB
Python
Raw Normal View History

"""缓存模块
提供内存缓存和 TTL 管理
"""
from __future__ import annotations
import asyncio
import hashlib
import time
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
from minenasai.core.logging import get_logger
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Type variable for the value type held by cache entries and caches.
T = TypeVar("T")
@dataclass
class CacheEntry(Generic[T]):
    """A single cached value together with its bookkeeping data.

    ``created_at`` and ``expires_at`` are compared against ``time.time()``,
    i.e. they are wall-clock timestamps in seconds.
    """

    key: str  # key under which the entry is stored
    value: T  # cached payload
    created_at: float  # timestamp at which the entry was written
    expires_at: float  # timestamp after which the entry counts as expired
    hits: int = 0  # read counter for this entry

    @property
    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``."""
        return time.time() > self.expires_at

    @property
    def ttl_remaining(self) -> float:
        """Return the seconds left before expiry, never below ``0.0``."""
        # Use a float floor so the property always returns a float, as annotated.
        return max(0.0, self.expires_at - time.time())
class MemoryCache(Generic[T]):
    """In-memory key/value cache with per-entry TTL and a size cap.

    Every entry expires ``ttl`` seconds after it is written. When a new key
    is inserted into a full cache, expired entries are dropped first; if that
    frees no room, roughly 10% of the entries are evicted, fewest hits first
    and oldest first on ties. An optional background task (see ``start`` /
    ``stop``) periodically removes expired entries.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes
        cleanup_interval: float = 60.0,  # sweep once per minute
    ) -> None:
        """Create an empty cache.

        Args:
            max_size: Number of entries at which inserts trigger eviction.
            default_ttl: TTL in seconds used when ``set`` is given no ``ttl``.
            cleanup_interval: Seconds between background expiry sweeps.
        """
        self._cache: dict[str, CacheEntry[T]] = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None
        # Statistics counters reported by get_stats().
        self._hits = 0
        self._misses = 0

    async def start(self) -> None:
        """Start the background cleanup task unless one is already running."""
        # A task that has already finished (e.g. it crashed) is replaced so
        # that start() can bring the sweeper back.
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("cache_cleanup_started", interval=self._cleanup_interval)

    async def stop(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        task = self._cleanup_task
        if task is None:
            return
        # Clear the reference first so the cache is restartable even if
        # awaiting the task raises something other than CancelledError.
        self._cleanup_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _cleanup_loop(self) -> None:
        """Remove expired entries every ``cleanup_interval`` seconds, forever."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            self._cleanup_expired()

    def _cleanup_expired(self) -> int:
        """Delete all expired entries and return how many were removed."""
        # Collect keys first: the dict must not change size while iterated.
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]
        if expired_keys:
            logger.debug("cache_cleanup", removed=len(expired_keys))
        return len(expired_keys)

    def _evict_lru(self) -> None:
        """Make room in a full cache.

        Expired entries are removed first. If the cache is still full, about
        10% of the entries (at least one) are evicted, ordered by hit count
        and then by creation time.
        """
        if len(self._cache) < self._max_size:
            return
        # Expired entries are dead weight: drop them before live ones.
        self._cleanup_expired()
        if not self._cache or len(self._cache) < self._max_size:
            return
        ranked = sorted(
            self._cache.items(),
            key=lambda item: (item[1].hits, item[1].created_at),
        )
        evict_count = max(1, len(ranked) // 10)
        for key, _ in ranked[:evict_count]:
            del self._cache[key]
        logger.debug("cache_eviction", evicted=evict_count)

    def get(self, key: str) -> T | None:
        """Return the cached value for ``key``, or None if missing or expired.

        Updates the hit/miss counters; an expired entry is deleted and
        counted as a miss.
        """
        entry = self._cache.get(key)
        if entry is None:
            self._misses += 1
            return None
        if entry.is_expired:
            del self._cache[key]
            self._misses += 1
            return None
        entry.hits += 1
        self._hits += 1
        return entry.value

    def set(self, key: str, value: T, ttl: float | None = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store; replaces any existing entry for ``key``.
            ttl: Lifetime in seconds; ``None`` uses the cache's default TTL.
        """
        if ttl is None:
            ttl = self._default_ttl
        # Only a new key can grow the cache; overwriting an existing key
        # must not push other entries out.
        if key not in self._cache and len(self._cache) >= self._max_size:
            self._evict_lru()
        now = time.time()
        self._cache[key] = CacheEntry(
            key=key,
            value=value,
            created_at=now,
            expires_at=now + ttl,
        )

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if an entry was present."""
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> int:
        """Remove every entry and return how many were removed."""
        count = len(self._cache)
        self._cache.clear()
        return count

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is present and not expired.

        An expired entry is deleted as a side effect. The hit/miss counters
        are not touched.
        """
        entry = self._cache.get(key)
        if entry is None:
            return False
        if entry.is_expired:
            del self._cache[key]
            return False
        return True

    def get_stats(self) -> dict[str, Any]:
        """Return size, capacity, hit/miss counters and hit rate.

        ``hit_rate`` is a percentage rounded to two decimals (0.0 when no
        lookups have happened yet).
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(hit_rate * 100, 2),
            "default_ttl": self._default_ttl,
        }
def make_cache_key(*args: Any, **kwargs: Any) -> str:
    """Build a deterministic cache key from arbitrary arguments.

    Positional arguments are stringified in order; keyword arguments are
    rendered as ``name=value`` sorted by name, so keyword order does not
    matter. The parts are joined with ``:`` and hashed with MD5, which is
    used here only as a compact fingerprint.

    The delimiter characters ``:`` and ``=`` (and the escape character
    ``\\``) are backslash-escaped inside each part, so that for example
    ``make_cache_key("a:b")`` and ``make_cache_key("a", "b")`` no longer
    produce the same key. Inputs containing none of these characters yield
    the same digest as before.

    Returns:
        The hexadecimal MD5 digest of the joined key string.
    """

    def _escape(text: str) -> str:
        # Escape the backslash first so the escapes added next stay intact.
        return text.replace("\\", "\\\\").replace(":", "\\:").replace("=", "\\=")

    key_parts = [_escape(str(arg)) for arg in args]
    key_parts.extend(
        f"{_escape(str(name))}={_escape(str(value))}"
        for name, value in sorted(kwargs.items())
    )
    key_string = ":".join(key_parts)
    return hashlib.md5(key_string.encode()).hexdigest()
# Process-wide response cache, created lazily on first use.
_response_cache: MemoryCache[dict[str, Any]] | None = None


def get_response_cache() -> MemoryCache[dict[str, Any]]:
    """Return the shared response cache, creating it on the first call.

    The cache is built with a capacity of 500 entries and a default TTL of
    300 seconds (5 minutes).
    """
    global _response_cache
    cache = _response_cache
    if cache is not None:
        return cache
    cache = MemoryCache(max_size=500, default_ttl=300.0)
    _response_cache = cache
    return cache
@dataclass
class RateLimiter:
    """Token-bucket rate limiter.

    The bucket starts full with ``burst`` tokens and is refilled at ``rate``
    tokens per second, never exceeding ``burst``. Elapsed time is measured
    with a monotonic clock so that system clock adjustments cannot drain or
    over-fill the bucket.
    """

    rate: float  # tokens added per second
    burst: int  # bucket capacity, i.e. the largest possible burst
    _tokens: float = field(init=False)
    _last_update: float = field(init=False)

    def __post_init__(self) -> None:
        """Fill the bucket and record the starting time."""
        self._tokens = float(self.burst)
        self._last_update = time.monotonic()

    def _refill(self) -> None:
        """Add the tokens earned since the last update, capped at ``burst``."""
        now = time.monotonic()
        # Guard against a negative interval so a refill can never remove tokens.
        elapsed = max(0.0, now - self._last_update)
        self._tokens = min(float(self.burst), self._tokens + elapsed * self.rate)
        self._last_update = now

    def acquire(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from the bucket without waiting.

        Returns:
            True if the tokens were taken, False if too few are available.
        """
        self._refill()
        if self._tokens >= tokens:
            self._tokens -= tokens
            return True
        return False

    async def wait(self, tokens: int = 1) -> None:
        """Wait until ``tokens`` can be taken, then take them.

        Raises:
            ValueError: If the request can never be satisfied, i.e.
                ``tokens`` exceeds ``burst``, or waiting is required but
                ``rate`` is not positive.
        """
        if tokens > self.burst:
            # The bucket never holds more than ``burst`` tokens, so this
            # request would otherwise loop forever.
            raise ValueError(
                f"cannot acquire {tokens} tokens: exceeds burst capacity {self.burst}"
            )
        while not self.acquire(tokens):
            if self.rate <= 0:
                raise ValueError("cannot wait for tokens: rate must be positive")
            needed = tokens - self._tokens
            wait_time = needed / self.rate
            # Sleep in slices of at most 0.1s and re-check after each one.
            await asyncio.sleep(min(wait_time, 0.1))

    @property
    def available_tokens(self) -> float:
        """Return the number of tokens currently in the bucket (after refill)."""
        self._refill()
        return self._tokens
# Registry of named rate limiters shared across the process.
_rate_limiters: dict[str, RateLimiter] = {}


def get_rate_limiter(name: str, rate: float = 10.0, burst: int = 20) -> RateLimiter:
    """Return the rate limiter registered under ``name``, creating it if needed.

    ``rate`` and ``burst`` are only used when the limiter is created; an
    already registered limiter is returned unchanged.
    """
    try:
        return _rate_limiters[name]
    except KeyError:
        limiter = RateLimiter(rate=rate, burst=burst)
        _rate_limiters[name] = limiter
        return _rate_limiters[name]