Files
AIChatRoom/backend/utils/rate_limiter.py

234 lines
6.5 KiB
Python
Raw Permalink Normal View History

"""
速率限制器模块
使用令牌桶算法控制请求频率
"""
import asyncio
import time
from typing import Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class TokenBucket:
    """Token bucket that refills continuously at a fixed rate.

    The bucket starts full (``tokens == capacity``). Each call to
    ``consume`` or ``wait_time`` first tops the bucket up according to the
    wall-clock time elapsed since the previous refill, capped at ``capacity``.
    """

    capacity: int  # maximum number of tokens the bucket can hold
    tokens: float = field(init=False)  # tokens currently available (set in __post_init__)
    refill_rate: float  # tokens added per second
    last_refill: float = field(default_factory=time.time)  # time.time() of the last refill

    def __post_init__(self):
        # A freshly created bucket is full.
        self.tokens = float(self.capacity)

    def _refill(self) -> None:
        """Add the tokens earned since ``last_refill``, capped at ``capacity``."""
        now = time.time()
        # time.time() can step backwards (NTP sync, manual clock change).
        # Clamp so a backwards jump never removes tokens from the bucket.
        elapsed = max(0.0, now - self.last_refill)
        self.tokens = min(
            self.capacity,
            self.tokens + elapsed * self.refill_rate,
        )
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` tokens from the bucket.

        Args:
            tokens: Number of tokens to consume.

        Returns:
            True if the tokens were available and have been deducted,
            False if the bucket did not hold enough (nothing is deducted).
        """
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def wait_time(self, tokens: int = 1) -> float:
        """Compute how long until ``tokens`` tokens will be available.

        Args:
            tokens: Number of tokens required.

        Returns:
            Seconds to wait; ``0.0`` if the tokens are available now, and
            ``float("inf")`` if the request can never be satisfied (it
            exceeds ``capacity``, or the bucket does not refill).
        """
        self._refill()
        if self.tokens >= tokens:
            return 0.0
        # The bucket never holds more than `capacity`, and a non-positive
        # rate never adds tokens: waiting cannot help in either case.
        # This also avoids dividing by a zero refill rate.
        if tokens > self.capacity or self.refill_rate <= 0:
            return float("inf")
        return (tokens - self.tokens) / self.refill_rate
class RateLimiter:
    """Per-provider rate limiter built on token buckets.

    Each registered provider gets two buckets (one counting requests, one
    counting tokens) plus an ``asyncio.Lock`` that serialises access to
    that pair. Providers that are not registered are never limited.
    """

    def __init__(self):
        # Buckets keyed by "<provider_id>:requests" and "<provider_id>:tokens".
        self._buckets: Dict[str, "TokenBucket"] = {}
        # One lock per provider_id, guarding both of that provider's buckets.
        self._locks: Dict[str, asyncio.Lock] = {}

    def register(
        self,
        provider_id: str,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 100000
    ) -> None:
        """Register (or replace) the rate limits for a provider.

        Args:
            provider_id: Provider identifier.
            requests_per_minute: Request-bucket capacity; it refills at
                ``requests_per_minute / 60`` per second.
            tokens_per_minute: Token-bucket capacity; it refills at
                ``tokens_per_minute / 60`` per second.
        """
        # Request-count bucket
        self._buckets[f"{provider_id}:requests"] = TokenBucket(
            capacity=requests_per_minute,
            refill_rate=requests_per_minute / 60.0
        )
        # Token-count bucket
        self._buckets[f"{provider_id}:tokens"] = TokenBucket(
            capacity=tokens_per_minute,
            refill_rate=tokens_per_minute / 60.0
        )
        # Lock guarding both buckets of this provider
        self._locks[provider_id] = asyncio.Lock()
        logger.debug(
            f"注册速率限制: {provider_id} - "
            f"{requests_per_minute}请求/分钟, "
            f"{tokens_per_minute}tokens/分钟"
        )

    def unregister(self, provider_id: str) -> None:
        """Remove a provider's buckets and lock; unknown ids are ignored.

        Args:
            provider_id: Provider identifier.
        """
        self._buckets.pop(f"{provider_id}:requests", None)
        self._buckets.pop(f"{provider_id}:tokens", None)
        self._locks.pop(provider_id, None)

    def _lookup(self, provider_id: str) -> Optional[tuple]:
        """Return ``(request_bucket, token_bucket, lock)`` or None if unregistered."""
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")
        lock = self._locks.get(provider_id)
        if request_bucket is None or token_bucket is None or lock is None:
            return None
        return request_bucket, token_bucket, lock

    @staticmethod
    def _try_consume(
        request_bucket: "TokenBucket",
        token_bucket: "TokenBucket",
        estimated_tokens: int
    ) -> float:
        """Consume from both buckets only if both can pay.

        Must be called while holding the provider's lock.

        Returns:
            ``0.0`` when one request and ``estimated_tokens`` tokens were
            deducted; otherwise the larger of the two buckets' wait times,
            with nothing deducted from either bucket.
        """
        wait = max(
            request_bucket.wait_time(1),
            token_bucket.wait_time(estimated_tokens),
        )
        if wait == 0:
            # Both buckets were checked first, so a rejection by one bucket
            # can never leave the other one already charged.
            request_bucket.consume(1)
            token_bucket.consume(estimated_tokens)
            return 0.0
        return wait

    async def acquire(
        self,
        provider_id: str,
        estimated_tokens: int = 1
    ) -> bool:
        """Try to obtain permission for one request without waiting.

        Args:
            provider_id: Provider identifier.
            estimated_tokens: Estimated number of tokens the request uses.

        Returns:
            True if the request is allowed (always True for an unregistered
            provider), False if either bucket lacks capacity right now. On
            False nothing is deducted from either bucket.
        """
        entry = self._lookup(provider_id)
        if entry is None:
            # Not registered: allow by default.
            return True
        request_bucket, token_bucket, lock = entry
        async with lock:
            return self._try_consume(
                request_bucket, token_bucket, estimated_tokens
            ) == 0

    async def acquire_wait(
        self,
        provider_id: str,
        estimated_tokens: int = 1,
        max_wait: float = 60.0
    ) -> bool:
        """Obtain permission for one request, sleeping until it is available.

        Args:
            provider_id: Provider identifier.
            estimated_tokens: Estimated number of tokens the request uses.
            max_wait: Maximum total time to wait, in seconds.

        Returns:
            True once the request is allowed (always True for an
            unregistered provider); False as soon as the time already
            waited plus the wait still required would exceed ``max_wait``.
        """
        entry = self._lookup(provider_id)
        if entry is None:
            return True
        request_bucket, token_bucket, lock = entry

        # Monotonic clock: the deadline must not move if the wall clock does.
        start_time = time.monotonic()
        while True:
            async with lock:
                wait_time = self._try_consume(
                    request_bucket, token_bucket, estimated_tokens
                )
            if wait_time == 0:
                return True

            # Give up early if the remaining wait cannot fit in the budget.
            elapsed = time.monotonic() - start_time
            if elapsed + wait_time > max_wait:
                logger.warning(
                    f"速率限制等待超时: {provider_id}, "
                    f"需要等待{wait_time:.2f}"
                )
                return False

            # Sleep outside the lock, in slices of at most one second, then
            # re-check because other tasks may have consumed tokens meanwhile.
            await asyncio.sleep(min(wait_time, 1.0))

    def get_status(self, provider_id: str) -> Optional[Dict[str, int]]:
        """Report the current bucket levels for a provider.

        Args:
            provider_id: Provider identifier.

        Returns:
            None if the provider has no buckets registered; otherwise a dict
            with ``requests_remaining``, ``requests_capacity``,
            ``tokens_remaining`` and ``tokens_capacity`` (remaining values
            are truncated to int after refilling both buckets).
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")
        if request_bucket is None or token_bucket is None:
            return None
        # Bring both buckets up to date so the snapshot reflects "now".
        request_bucket._refill()
        token_bucket._refill()
        return {
            "requests_remaining": int(request_bucket.tokens),
            "requests_capacity": request_bucket.capacity,
            "tokens_remaining": int(token_bucket.tokens),
            "tokens_capacity": token_bucket.capacity
        }
# Global rate limiter instance, created once at module import
rate_limiter: RateLimiter = RateLimiter()