Files
MineNasAI/tests/test_cache.py

221 lines
5.7 KiB
Python

"""缓存模块测试"""
from __future__ import annotations
import time
import pytest
from minenasai.core.cache import (
MemoryCache,
RateLimiter,
get_rate_limiter,
get_response_cache,
make_cache_key,
)
class TestMemoryCache:
    """Tests for MemoryCache."""

    def test_set_and_get(self):
        """A stored value can be read back under the same key."""
        cache: MemoryCache[str] = MemoryCache()
        cache.set("key1", "value1")
        assert cache.get("key1") == "value1"

    def test_get_nonexistent(self):
        """Reading a key that was never set returns None."""
        cache: MemoryCache[str] = MemoryCache()
        assert cache.get("nonexistent") is None

    def test_ttl_expiration(self):
        """An entry disappears once the cache's default_ttl has elapsed."""
        cache: MemoryCache[str] = MemoryCache(default_ttl=0.1)
        cache.set("key1", "value1")
        assert cache.get("key1") == "value1"
        # Sleep slightly longer than the TTL so the entry is past its deadline.
        time.sleep(0.15)
        assert cache.get("key1") is None

    def test_custom_ttl(self):
        """A per-entry ttl passed to set() takes precedence over default_ttl."""
        cache: MemoryCache[str] = MemoryCache(default_ttl=10.0)
        cache.set("key1", "value1", ttl=0.1)
        # Confirm the entry was actually stored; otherwise the expiry check
        # below would pass even if set() had silently dropped the value.
        assert cache.get("key1") == "value1"
        time.sleep(0.15)
        assert cache.get("key1") is None

    def test_delete(self):
        """delete() reports True for an existing key and False once it is gone."""
        cache: MemoryCache[str] = MemoryCache()
        cache.set("key1", "value1")
        assert cache.delete("key1") is True
        assert cache.get("key1") is None
        assert cache.delete("key1") is False

    def test_clear(self):
        """clear() removes every entry and returns how many were removed."""
        cache: MemoryCache[str] = MemoryCache()
        cache.set("key1", "value1")
        cache.set("key2", "value2")
        count = cache.clear()
        assert count == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_exists(self):
        """exists() is True for a stored key and False for an unknown one."""
        cache: MemoryCache[str] = MemoryCache()
        cache.set("key1", "value1")
        assert cache.exists("key1") is True
        assert cache.exists("key2") is False

    def test_max_size_eviction(self):
        """Inserting more entries than max_size keeps the cache within max_size."""
        cache: MemoryCache[int] = MemoryCache(max_size=5)
        for i in range(10):
            cache.set(f"key{i}", i)
        # Only part of the entries should be retained. The entry count is read
        # through the public stats API rather than the private storage dict.
        assert cache.get_stats()["size"] <= 5

    def test_hit_tracking(self):
        """Successful and failed lookups are counted as hits and misses."""
        cache: MemoryCache[str] = MemoryCache()
        cache.set("key1", "value1")
        cache.get("key1")
        cache.get("key1")
        cache.get("nonexistent")
        stats = cache.get_stats()
        assert stats["hits"] == 2
        assert stats["misses"] == 1

    def test_get_stats(self):
        """get_stats() reports size, max_size, default_ttl and a hit_rate entry."""
        cache: MemoryCache[str] = MemoryCache(max_size=100, default_ttl=60.0)
        cache.set("key1", "value1")
        cache.get("key1")
        stats = cache.get_stats()
        assert stats["size"] == 1
        assert stats["max_size"] == 100
        assert stats["default_ttl"] == 60.0
        assert "hit_rate" in stats
class TestRateLimiter:
    """Tests for RateLimiter."""

    def test_acquire_within_limit(self):
        """Up to `burst` tokens can be acquired back to back."""
        limiter = RateLimiter(rate=10.0, burst=5)
        # A fresh limiter should hand out `burst` tokens without refusing.
        for _ in range(5):
            assert limiter.acquire() is True

    def test_acquire_exceeds_limit(self):
        """Acquiring beyond `burst` is refused."""
        limiter = RateLimiter(rate=10.0, burst=2)
        assert limiter.acquire() is True
        assert limiter.acquire() is True
        assert limiter.acquire() is False

    def test_token_refill(self):
        """Tokens become available again after waiting."""
        limiter = RateLimiter(rate=100.0, burst=2)
        # Drain every token; assert the drain itself so a broken precondition
        # is reported here rather than as a confusing failure further down.
        assert limiter.acquire() is True
        assert limiter.acquire() is True
        assert limiter.acquire() is False
        # Give the limiter time to refill.
        time.sleep(0.05)
        assert limiter.acquire() is True

    def test_available_tokens(self):
        """available_tokens reflects the tokens consumed by acquire()."""
        limiter = RateLimiter(rate=10.0, burst=5)
        # A tolerance is used because the reported count is compared as a float.
        assert limiter.available_tokens == pytest.approx(5.0, abs=0.1)
        limiter.acquire(2)
        assert limiter.available_tokens == pytest.approx(3.0, abs=0.1)

    @pytest.mark.asyncio
    async def test_wait(self):
        """wait() takes a measurable amount of time once the only token is spent."""
        limiter = RateLimiter(rate=100.0, burst=1)
        limiter.acquire()
        # perf_counter() is monotonic and high-resolution. time.time() can tick
        # too coarsely to register a short wait (yielding elapsed == 0) and may
        # even step backwards if the system clock is adjusted.
        start = time.perf_counter()
        await limiter.wait()
        elapsed = time.perf_counter() - start
        # Some time should have passed while waiting.
        assert elapsed > 0
class TestCacheKey:
    """Tests for make_cache_key."""

    def test_same_args_same_key(self):
        """Two calls with identical arguments produce equal keys."""
        first = make_cache_key("a", "b", c=1)
        second = make_cache_key("a", "b", c=1)
        assert first == second

    def test_different_args_different_key(self):
        """Calls whose positional arguments differ produce different keys."""
        with_b = make_cache_key("a", "b")
        with_c = make_cache_key("a", "c")
        assert with_b != with_c

    def test_kwargs_order_independent(self):
        """The order in which keyword arguments are passed does not change the key."""
        forward = make_cache_key(a=1, b=2)
        reversed_order = make_cache_key(b=2, a=1)
        assert forward == reversed_order
class TestGlobalInstances:
    """Tests for the module-level accessor functions."""

    def test_get_response_cache(self):
        """get_response_cache() returns a MemoryCache instance."""
        assert isinstance(get_response_cache(), MemoryCache)

    def test_get_rate_limiter(self):
        """get_rate_limiter() returns a RateLimiter instance."""
        created = get_rate_limiter("test", rate=10.0, burst=20)
        assert isinstance(created, RateLimiter)

    def test_get_rate_limiter_reuse(self):
        """Requesting the same name twice yields the very same limiter object."""
        first = get_rate_limiter("shared")
        # Identity, not equality: the second lookup must return the same object.
        assert get_rate_limiter("shared") is first