267 lines
7.0 KiB
Python
267 lines
7.0 KiB
Python
"""Cache module.

Provides an in-memory cache with TTL management.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Generic, TypeVar
|
||
|
||
from minenasai.core.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
T = TypeVar("T")
|
||
|
||
|
||
@dataclass
class CacheEntry(Generic[T]):
    """One cached value plus the bookkeeping needed to expire and rank it.

    Both timestamps are compared against ``time.time()``, so they are
    wall-clock seconds. ``hits`` starts at zero.
    """

    key: str
    value: T
    created_at: float
    expires_at: float
    hits: int = 0

    @property
    def is_expired(self) -> bool:
        """Return True once the current time is strictly past ``expires_at``."""
        return self.expires_at < time.time()

    @property
    def ttl_remaining(self) -> float:
        """Return the seconds left before expiry, never less than zero."""
        remaining = self.expires_at - time.time()
        return remaining if remaining > 0 else 0
|
||
|
||
|
||
class MemoryCache(Generic[T]):
    """In-memory key/value cache with per-entry TTL and a size cap.

    Entries expire ``ttl`` seconds after they are stored. When a new key
    would exceed ``max_size``, expired entries are purged first and, if the
    cache is still full, the least-used entries are evicted (see
    ``_evict_lru``). Hit/miss counters are kept for ``get_stats``.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes
        cleanup_interval: float = 60.0,  # sweep expired entries once a minute
    ) -> None:
        """Create an empty cache.

        Args:
            max_size: Entry count at which storing a new key triggers eviction.
            default_ttl: Lifetime in seconds used when ``set`` gets no ``ttl``.
            cleanup_interval: Seconds between background sweeps once
                ``start`` has been called.
        """
        self._cache: dict[str, CacheEntry[T]] = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

        # Statistics
        self._hits = 0
        self._misses = 0

    async def start(self) -> None:
        """Start the background cleanup task if it is not already running."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("cache_cleanup_started", interval=self._cleanup_interval)

    async def stop(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: this is the cancellation requested just above.
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Sweep expired entries every ``cleanup_interval`` seconds, forever."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            self._cleanup_expired()

    def _cleanup_expired(self) -> int:
        """Remove every expired entry and return how many were removed."""
        # Collect keys first; deleting while iterating the dict would raise.
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]

        if expired_keys:
            logger.debug("cache_cleanup", removed=len(expired_keys))

        return len(expired_keys)

    def _evict_lru(self) -> None:
        """Evict roughly 10% of entries (at least one) when the cache is full.

        Despite the name, victims are ranked by fewest hits first and, among
        equal hit counts, oldest ``created_at`` first. Does nothing while the
        cache is below ``max_size`` or is empty.
        """
        if len(self._cache) < self._max_size or not self._cache:
            return

        # Least-used first: order by hit count, then by creation time.
        sorted_entries = sorted(
            self._cache.items(),
            key=lambda x: (x[1].hits, x[1].created_at),
        )

        # Evict 10% of the entries, but always at least one.
        evict_count = max(1, len(sorted_entries) // 10)
        victims = sorted_entries[:evict_count]
        for key, _ in victims:
            del self._cache[key]

        logger.debug("cache_eviction", evicted=len(victims))

    def get(self, key: str) -> T | None:
        """Return the cached value for ``key``, or None on a miss.

        A missing or expired key counts as a miss; an expired entry is
        removed on access. A hit increments both the entry's and the cache's
        hit counters.
        """
        entry = self._cache.get(key)

        if entry is None:
            self._misses += 1
            return None

        if entry.is_expired:
            del self._cache[key]
            self._misses += 1
            return None

        entry.hits += 1
        self._hits += 1
        return entry.value

    def set(self, key: str, value: T, ttl: float | None = None) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Lifetime in seconds; ``default_ttl`` is used when None.
        """
        if ttl is None:
            ttl = self._default_ttl

        # Overwriting an existing key does not grow the cache, so only a
        # genuinely new key can push it past capacity.
        if key not in self._cache and len(self._cache) >= self._max_size:
            # Reclaim dead entries first; live ones are evicted only if the
            # cache is still full afterwards (_evict_lru re-checks the size).
            self._cleanup_expired()
            self._evict_lru()

        now = time.time()
        self._cache[key] = CacheEntry(
            key=key,
            value=value,
            created_at=now,
            expires_at=now + ttl,
        )

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present, False otherwise."""
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> int:
        """Remove all entries and return how many there were."""
        count = len(self._cache)
        self._cache.clear()
        return count

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is present and not expired.

        An expired entry is removed as a side effect. Unlike ``get``, this
        does not touch the hit/miss counters.
        """
        entry = self._cache.get(key)
        if entry is None:
            return False
        if entry.is_expired:
            del self._cache[key]
            return False
        return True

    def get_stats(self) -> dict[str, Any]:
        """Return size, capacity, hit/miss counts, hit rate (%), and default TTL."""
        total = self._hits + self._misses
        # Avoid dividing by zero before the first lookup.
        hit_rate = self._hits / total if total > 0 else 0.0

        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(hit_rate * 100, 2),
            "default_ttl": self._default_ttl,
        }
|
||
|
||
|
||
def make_cache_key(*args: Any, **kwargs: Any) -> str:
    """Build a cache key: the MD5 hex digest of the stringified arguments.

    Positional arguments are rendered with ``str``; keyword arguments are
    rendered as ``name=value`` in sorted name order, so keyword order does
    not affect the key. All parts are joined with ``:``.

    NOTE: parts are joined without escaping, so different argument lists can
    produce the same key (e.g. ``("a:b",)`` and ``("a", "b")``).
    """
    positional = map(str, args)
    named = (f"{name}={kwargs[name]}" for name in sorted(kwargs))
    joined = ":".join([*positional, *named])
    return hashlib.md5(joined.encode()).hexdigest()
|
||
|
||
|
||
# Module-wide response cache; built lazily on the first call below.
_response_cache: MemoryCache[dict[str, Any]] | None = None


def get_response_cache() -> MemoryCache[dict[str, Any]]:
    """Return the shared response cache, creating it on first use.

    The cache is created with ``max_size=500`` and a default TTL of
    300 seconds (5 minutes); later calls return the same instance.
    """
    global _response_cache
    if _response_cache is not None:
        return _response_cache

    _response_cache = MemoryCache(max_size=500, default_ttl=300.0)
    return _response_cache
|
||
|
||
|
||
@dataclass
class RateLimiter:
    """Token-bucket rate limiter.

    The bucket starts full with ``burst`` tokens and is refilled at ``rate``
    tokens per second, never exceeding ``burst``. Elapsed time is measured
    with ``time.monotonic()`` so wall-clock adjustments cannot drain or
    over-fill the bucket.
    """

    rate: float  # tokens added per second
    burst: int  # bucket capacity (maximum burst size)

    _tokens: float = field(init=False)
    _last_update: float = field(init=False)

    def __post_init__(self) -> None:
        """Start with a full bucket and record the refill reference time."""
        self._tokens = float(self.burst)
        self._last_update = time.monotonic()

    def _refill(self) -> None:
        """Add the tokens earned since the last refill, capped at ``burst``."""
        now = time.monotonic()
        elapsed = now - self._last_update
        self._tokens = min(float(self.burst), self._tokens + elapsed * self.rate)
        self._last_update = now

    def acquire(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from the bucket without waiting.

        Returns:
            True if the tokens were available and have been consumed,
            False otherwise (the bucket is left unchanged).
        """
        self._refill()

        if self._tokens >= tokens:
            self._tokens -= tokens
            return True
        return False

    async def wait(self, tokens: int = 1) -> None:
        """Wait until ``tokens`` can be taken from the bucket, then take them.

        Raises:
            ValueError: If the request can never be satisfied, i.e.
                ``tokens`` exceeds ``burst``, or tokens are unavailable and
                ``rate`` is not positive.
        """
        # The bucket never holds more than ``burst`` tokens, so a larger
        # request would otherwise spin here forever.
        if tokens > self.burst:
            raise ValueError(
                f"cannot acquire {tokens} tokens: exceeds burst capacity {self.burst}"
            )

        while not self.acquire(tokens):
            if self.rate <= 0:
                raise ValueError("cannot wait for tokens: rate must be positive")
            # Estimate the time until enough tokens exist, but sleep at most
            # 0.1 s per iteration before re-checking.
            needed = tokens - self._tokens
            wait_time = needed / self.rate
            await asyncio.sleep(min(wait_time, 0.1))

    @property
    def available_tokens(self) -> float:
        """Return the number of tokens currently available, after a refill."""
        self._refill()
        return self._tokens
|
||
|
||
|
||
# Module-wide registry of rate limiters, keyed by name.
_rate_limiters: dict[str, RateLimiter] = {}


def get_rate_limiter(name: str, rate: float = 10.0, burst: int = 20) -> RateLimiter:
    """Return the rate limiter registered under ``name``, creating it if needed.

    ``rate`` and ``burst`` are used only when the limiter is first created;
    for an already-registered name they are ignored and the existing
    instance is returned.
    """
    limiter = _rate_limiters.get(name)
    if limiter is None:
        limiter = RateLimiter(rate=rate, burst=burst)
        _rate_limiters[name] = limiter
    return limiter
|