Files
MineNasAI/src/minenasai/core/cache.py

267 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Cache module.

Provides an in-memory cache with TTL management and a token-bucket rate
limiter.
"""
from __future__ import annotations
import asyncio
import hashlib
import time
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
from minenasai.core.logging import get_logger
logger = get_logger(__name__)
T = TypeVar("T")


@dataclass
class CacheEntry(Generic[T]):
    """A single cached value together with its bookkeeping metadata.

    ``expires_at`` is compared against ``time.time()``, so it is expected to
    be an absolute timestamp in seconds on that clock.
    """

    key: str  # key under which the value is stored
    value: T  # cached payload
    created_at: float  # timestamp recorded when the entry was created
    expires_at: float  # timestamp after which the entry counts as expired
    hits: int = 0  # hit counter, starts at zero

    @property
    def is_expired(self) -> bool:
        """Return True once the current time has passed ``expires_at``."""
        return time.time() > self.expires_at

    @property
    def ttl_remaining(self) -> float:
        """Return the seconds left before expiry, never below ``0.0``."""
        # A float literal keeps the result a float for expired entries too;
        # ``max(0, ...)`` would hand back the int ``0`` in that case.
        return max(0.0, self.expires_at - time.time())
class MemoryCache(Generic[T]):
    """In-memory key/value cache with per-entry TTL and a size cap.

    Expired entries are dropped lazily by ``get``/``exists`` and, once
    ``start`` has been awaited, periodically by a background task. When the
    cache is full, ``set`` makes room by discarding the least-used entries.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes
        cleanup_interval: float = 60.0,  # sweep once a minute
    ) -> None:
        """Create an empty cache.

        Args:
            max_size: Entry count at which ``set`` starts evicting.
            default_ttl: Lifetime in seconds used when ``set`` gets no ``ttl``.
            cleanup_interval: Seconds between background expiry sweeps.
        """
        self._cache: dict[str, CacheEntry[T]] = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

        # Statistics
        self._hits = 0
        self._misses = 0

    async def start(self) -> None:
        """Start the background cleanup task if it is not already running."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("cache_cleanup_started", interval=self._cleanup_interval)

    async def stop(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: cancellation is how the endless loop is ended.
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Sweep expired entries every ``cleanup_interval`` seconds, forever."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            self._cleanup_expired()

    def _cleanup_expired(self) -> int:
        """Delete every expired entry and return how many were removed."""
        # Collect first: the dict must not change size while being iterated.
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]
        if expired_keys:
            logger.debug("cache_cleanup", removed=len(expired_keys))
        return len(expired_keys)

    def _evict_lru(self) -> None:
        """Make room once the cache has reached ``max_size``.

        Expired entries are removed first. If the cache is still full, 10% of
        the entries (at least one) are dropped, picking those with the fewest
        hits and, among ties, the oldest ``created_at``.

        Note: despite the name, ordering is by hit count rather than by last
        access time.
        """
        if len(self._cache) < self._max_size:
            return

        # Dead entries are free to reclaim; do that before touching live ones.
        self._cleanup_expired()
        if len(self._cache) < self._max_size:
            return

        ranked = sorted(
            self._cache.items(),
            key=lambda item: (item[1].hits, item[1].created_at),
        )
        evict_count = max(1, len(ranked) // 10)
        victims = [key for key, _ in ranked[:evict_count]]
        for key in victims:
            del self._cache[key]
        # Report what was actually removed (nothing, if the cache was empty).
        if victims:
            logger.debug("cache_eviction", evicted=len(victims))

    def get(self, key: str) -> T | None:
        """Return the cached value for ``key``, or None if missing or expired.

        Updates the hit/miss statistics and removes the entry when it has
        expired. A stored ``None`` cannot be told apart from a miss.
        """
        entry = self._cache.get(key)
        if entry is None:
            self._misses += 1
            return None
        if entry.is_expired:
            del self._cache[key]
            self._misses += 1
            return None
        entry.hits += 1
        self._hits += 1
        return entry.value

    def set(self, key: str, value: T, ttl: float | None = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key; an existing entry is replaced and its hit count
                starts again from zero.
            value: Value to cache.
            ttl: Lifetime in seconds; ``default_ttl`` is used when None.
        """
        if ttl is None:
            ttl = self._default_ttl

        # Only a brand-new key grows the cache; replacing an existing entry
        # must not push other entries out.
        if key not in self._cache and len(self._cache) >= self._max_size:
            self._evict_lru()

        now = time.time()
        self._cache[key] = CacheEntry(
            key=key,
            value=value,
            created_at=now,
            expires_at=now + ttl,
        )

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present."""
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> int:
        """Remove every entry and return how many were removed."""
        count = len(self._cache)
        self._cache.clear()
        return count

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is present and not expired.

        An expired entry is removed as a side effect. Statistics are not
        touched.
        """
        entry = self._cache.get(key)
        if entry is None:
            return False
        if entry.is_expired:
            del self._cache[key]
            return False
        return True

    def get_stats(self) -> dict[str, Any]:
        """Return size, hit/miss counters and the hit rate as a percentage."""
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(hit_rate * 100, 2),
            "default_ttl": self._default_ttl,
        }
def make_cache_key(*args: Any, **kwargs: Any) -> str:
    """Build a deterministic cache key from arbitrary call arguments.

    Every argument is rendered with ``str()``; keyword arguments are ordered
    by name, so their order at the call site does not matter. The rendered
    parts are serialised as a quoted, structured literal before hashing, which
    keeps argument boundaries unambiguous: ``make_cache_key("a:b")`` and
    ``make_cache_key("a", "b")`` give different keys, as do
    ``make_cache_key("k=v")`` and ``make_cache_key(k="v")``.

    Returns:
        A 32-character hexadecimal MD5 digest.
    """
    positional = [str(arg) for arg in args]
    keyword = [(name, str(kwargs[name])) for name in sorted(kwargs)]
    # repr() quotes and escapes each string, so separators that occur inside
    # an argument can no longer be mistaken for argument boundaries.
    key_string = repr((positional, keyword))
    # MD5 serves only as a compact fingerprint here, not as a security measure.
    return hashlib.md5(key_string.encode()).hexdigest()
# Process-wide response cache instance, created on first use.
_response_cache: MemoryCache[dict[str, Any]] | None = None


def get_response_cache() -> MemoryCache[dict[str, Any]]:
    """Return the shared response cache, building it on the first call."""
    global _response_cache
    if _response_cache is not None:
        return _response_cache
    # 500 entries, 5-minute default TTL.
    _response_cache = MemoryCache(max_size=500, default_ttl=300.0)
    return _response_cache
@dataclass
class RateLimiter:
    """Token-bucket rate limiter.

    The bucket starts full with ``burst`` tokens and is topped up at ``rate``
    tokens per second, never exceeding ``burst``. Elapsed time is measured on
    the monotonic clock so system clock adjustments cannot add or remove
    tokens.
    """

    rate: float  # tokens added per second
    burst: int  # bucket capacity
    _tokens: float = field(init=False)
    _last_update: float = field(init=False)

    def __post_init__(self) -> None:
        """Start with a full bucket."""
        self._tokens = float(self.burst)
        self._last_update = time.monotonic()

    def _refill(self) -> None:
        """Credit the tokens earned since the previous refill."""
        now = time.monotonic()
        elapsed = now - self._last_update
        # float(...) keeps the balance a float even when capped at ``burst``.
        self._tokens = min(float(self.burst), self._tokens + elapsed * self.rate)
        self._last_update = now

    def acquire(self, tokens: int = 1) -> bool:
        """Take ``tokens`` from the bucket without blocking.

        Returns:
            True if the tokens were available and have been consumed,
            False otherwise (the bucket is left unchanged).
        """
        self._refill()
        if self._tokens >= tokens:
            self._tokens -= tokens
            return True
        return False

    async def wait(self, tokens: int = 1) -> None:
        """Sleep until ``tokens`` can be acquired, then consume them.

        Raises:
            ValueError: If the request can never be satisfied, i.e. ``tokens``
                exceeds ``burst``, or the bucket is short and ``rate`` is not
                positive.
        """
        if tokens > self.burst:
            # The bucket never holds more than ``burst``; waiting would hang.
            raise ValueError(
                f"cannot wait for {tokens} tokens: burst capacity is {self.burst}"
            )
        while not self.acquire(tokens):
            if self.rate <= 0:
                raise ValueError("cannot wait for tokens: rate must be positive")
            # Estimate the time until enough tokens have accumulated, but poll
            # at least every 100 ms.
            needed = tokens - self._tokens
            wait_time = needed / self.rate
            await asyncio.sleep(min(wait_time, 0.1))

    @property
    def available_tokens(self) -> float:
        """Return the current token balance after refilling."""
        self._refill()
        return self._tokens
# Registry of named rate limiters shared across the process.
_rate_limiters: dict[str, RateLimiter] = {}


def get_rate_limiter(name: str, rate: float = 10.0, burst: int = 20) -> RateLimiter:
    """Return the limiter registered under ``name``, creating it if absent.

    ``rate`` and ``burst`` are used only when a new limiter is created; an
    already registered limiter is returned unchanged.
    """
    try:
        return _rate_limiters[name]
    except KeyError:
        limiter = RateLimiter(rate=rate, burst=burst)
        _rate_limiters[name] = limiter
        return limiter