267 lines
7.0 KiB
Python
267 lines
7.0 KiB
Python
|
|
"""缓存模块 (Cache module)

提供内存缓存和 TTL 管理
(In-memory caching with TTL management, plus a token-bucket rate limiter.)
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import hashlib
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any, Generic, TypeVar
|
|||
|
|
|
|||
|
|
from minenasai.core.logging import get_logger
|
|||
|
|
|
|||
|
|
# Type variable for the value type stored in the generic cache containers.
T = TypeVar("T")

# Module-level logger; callers below pass structured fields as keyword args.
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class CacheEntry(Generic[T]):
    """A single cached value plus its bookkeeping metadata.

    ``created_at`` and ``expires_at`` are absolute timestamps in seconds; expiry
    is checked against :func:`time.time`, so they are expected to come from the
    same clock.
    """

    key: str
    value: T
    created_at: float  # timestamp at which the entry was stored
    expires_at: float  # timestamp after which the entry counts as expired
    hits: int = 0  # number of successful reads of this entry

    @property
    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``."""
        return time.time() > self.expires_at

    @property
    def ttl_remaining(self) -> float:
        """Seconds left before expiry, clamped at ``0.0`` for expired entries."""
        # A float floor keeps the return type consistently float; ``max(0, x)``
        # would return the int ``0`` for expired entries.
        return max(0.0, self.expires_at - time.time())
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MemoryCache(Generic[T]):
    """In-memory key/value cache with per-entry TTL and a maximum size.

    Expired entries are dropped lazily on access and, optionally, by a
    background task started with :meth:`start`. When the cache is full and a
    new key is inserted, expired entries are purged first. If the cache is
    still full, about 10% of the entries (at least one) are evicted: the
    entries with the fewest hits, oldest ``created_at`` first on ties.

    Note: despite the ``_evict_lru`` name, eviction is usage-count based, not
    strictly least-recently-used.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: float = 300.0,  # 5 minutes
        cleanup_interval: float = 60.0,  # background cleanup once per minute
    ) -> None:
        """Create an empty cache.

        Args:
            max_size: Number of entries at which eviction kicks in.
            default_ttl: TTL in seconds used when ``set`` gets no ``ttl``.
            cleanup_interval: Seconds between background cleanup passes.
        """
        self._cache: dict[str, CacheEntry[T]] = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

        # Statistics
        self._hits = 0
        self._misses = 0

    async def start(self) -> None:
        """Start the background cleanup task; no-op if one is already set."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("cache_cleanup_started", interval=self._cleanup_interval)

    async def stop(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                # Expected: we just cancelled the task ourselves.
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Purge expired entries every ``cleanup_interval`` seconds until cancelled."""
        while True:
            await asyncio.sleep(self._cleanup_interval)
            self._cleanup_expired()

    def _cleanup_expired(self) -> int:
        """Remove all expired entries and return how many were removed."""
        # Collect keys first: deleting while iterating the dict would raise.
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]

        if expired_keys:
            logger.debug("cache_cleanup", removed=len(expired_keys))

        return len(expired_keys)

    def _evict_lru(self) -> None:
        """Free space when the cache has reached ``max_size``.

        Expired entries are purged first. Only if the cache is still at
        capacity are live entries evicted: about 10% of them (at least one),
        choosing the fewest hits and, on ties, the oldest ``created_at``.
        """
        if len(self._cache) < self._max_size:
            return

        # Expired entries are dead weight; dropping them may free enough room
        # without discarding any live data.
        self._cleanup_expired()
        if len(self._cache) < self._max_size:
            return

        # Least-used entries first; creation time breaks ties (older first).
        sorted_entries = sorted(
            self._cache.items(),
            key=lambda item: (item[1].hits, item[1].created_at),
        )

        # Evict ~10% of the entries, but always at least one.
        evict_count = max(1, len(sorted_entries) // 10)
        victims = sorted_entries[:evict_count]
        for key, _ in victims:
            del self._cache[key]

        # Report what was actually removed (may be fewer than evict_count).
        logger.debug("cache_eviction", evicted=len(victims))

    def get(self, key: str) -> T | None:
        """Return the value for ``key``, or ``None`` if it is missing or expired.

        An expired entry is deleted on access. Every call counts as either a
        hit or a miss in :meth:`get_stats`.
        """
        entry = self._cache.get(key)

        if entry is None:
            self._misses += 1
            return None

        if entry.is_expired:
            del self._cache[key]
            self._misses += 1
            return None

        entry.hits += 1
        self._hits += 1
        return entry.value

    def set(self, key: str, value: T, ttl: float | None = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to cache.
            ttl: Lifetime in seconds; ``None`` means use the default TTL.

        Overwriting an existing key replaces its entry (its hit count restarts
        at 0). It never triggers eviction, because the cache does not grow.
        """
        if ttl is None:
            ttl = self._default_ttl

        # Only inserting a *new* key can push the cache past its capacity.
        if key not in self._cache and len(self._cache) >= self._max_size:
            self._evict_lru()

        now = time.time()
        self._cache[key] = CacheEntry(
            key=key,
            value=value,
            created_at=now,
            expires_at=now + ttl,
        )

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present (even if expired)."""
        if key in self._cache:
            del self._cache[key]
            return True
        return False

    def clear(self) -> int:
        """Remove every entry and return how many there were.

        Hit/miss statistics are left untouched.
        """
        count = len(self._cache)
        self._cache.clear()
        return count

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is present and not expired.

        An expired entry is deleted. Unlike :meth:`get`, this does not update
        the hit/miss statistics.
        """
        entry = self._cache.get(key)
        if entry is None:
            return False
        if entry.is_expired:
            del self._cache[key]
            return False
        return True

    def get_stats(self) -> dict[str, Any]:
        """Return size, capacity, hit/miss counts and the default TTL.

        ``hit_rate`` is a percentage rounded to two decimals (0.0 when there
        have been no lookups yet).
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0

        return {
            "size": len(self._cache),
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(hit_rate * 100, 2),
            "default_ttl": self._default_ttl,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def make_cache_key(*args: Any, **kwargs: Any) -> str:
    """Build a deterministic cache key from call arguments.

    Each argument is converted with ``str()``. Keyword arguments are ordered by
    name, so ``f(a=1, b=2)`` and ``f(b=2, a=1)`` produce the same key.

    The parts are serialized as JSON rather than joined with ``":"``/``"="``,
    so argument boundaries are preserved. For example, ``("a:b",)`` vs
    ``("a", "b")``, or a positional ``"x=1"`` vs the keyword ``x=1``, no longer
    collide.

    Returns:
        The 32-character hex MD5 digest of the serialized arguments. MD5 serves
        only as a compact fingerprint here, not as a security measure.
    """
    import json  # local import: only this helper needs it

    payload = json.dumps(
        [
            [str(arg) for arg in args],
            [[name, str(value)] for name, value in sorted(kwargs.items())],
        ]
    )
    return hashlib.md5(payload.encode()).hexdigest()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Process-wide response cache, created lazily by get_response_cache().
_response_cache: MemoryCache[dict[str, Any]] | None = None


def get_response_cache() -> MemoryCache[dict[str, Any]]:
    """Return the shared response cache, creating it on the first call.

    The cache holds up to 500 entries with a default TTL of 300 seconds
    (5 minutes). Its background cleanup task is not started here.
    """
    global _response_cache
    cache = _response_cache
    if cache is None:
        cache = MemoryCache(max_size=500, default_ttl=300.0)
        _response_cache = cache
    return cache
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RateLimiter:
    """Token-bucket rate limiter.

    The bucket holds at most ``burst`` tokens and refills continuously at
    ``rate`` tokens per second. Each request consumes one or more tokens.

    Elapsed time is measured with :func:`time.monotonic`. Wall-clock changes
    (NTP corrections, manual clock changes) therefore cannot drain or overfill
    the bucket.
    """

    rate: float  # tokens added per second (sustained requests per second)
    burst: int  # bucket capacity (maximum burst size)

    _tokens: float = field(init=False)
    _last_update: float = field(init=False)

    def __post_init__(self) -> None:
        # Start full so an initial burst is allowed immediately.
        self._tokens = float(self.burst)
        self._last_update = time.monotonic()

    def _refill(self) -> None:
        """Add the tokens accrued since the last update, capped at ``burst``."""
        now = time.monotonic()
        elapsed = now - self._last_update
        # float() keeps _tokens a float even when the cap is hit.
        self._tokens = min(float(self.burst), self._tokens + elapsed * self.rate)
        self._last_update = now

    def acquire(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` tokens without waiting.

        Returns:
            True if enough tokens were available (and they were consumed);
            False otherwise, in which case nothing is consumed.
        """
        self._refill()

        if self._tokens >= tokens:
            self._tokens -= tokens
            return True
        return False

    async def wait(self, tokens: int = 1) -> None:
        """Asynchronously wait until ``tokens`` tokens can be taken, then take them.

        Raises:
            ValueError: If ``tokens`` exceeds ``burst``. The bucket can never
                hold that many tokens, so the wait would never finish.
        """
        if tokens > self.burst:
            raise ValueError(
                f"cannot acquire {tokens} tokens: exceeds burst capacity {self.burst}"
            )
        while not self.acquire(tokens):
            # Sleep for the estimated refill time, but re-check at least
            # every 100 ms.
            needed = tokens - self._tokens
            wait_time = needed / self.rate
            await asyncio.sleep(min(wait_time, 0.1))

    @property
    def available_tokens(self) -> float:
        """Number of tokens currently available (refilled before reading)."""
        self._refill()
        return self._tokens
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Registry of named, process-wide rate limiters.
_rate_limiters: dict[str, RateLimiter] = {}


def get_rate_limiter(name: str, rate: float = 10.0, burst: int = 20) -> RateLimiter:
    """Return the limiter registered under ``name``, creating it if needed.

    ``rate`` and ``burst`` are used only when the limiter is first created.
    Later calls with the same ``name`` return the existing instance unchanged.
    """
    limiter = _rate_limiters.get(name)
    if limiter is None:
        limiter = RateLimiter(rate=rate, burst=burst)
        _rate_limiters[name] = limiter
    return limiter
|