""" 速率限制器模块 使用令牌桶算法控制请求频率 """ import asyncio import time from typing import Dict, Optional from dataclasses import dataclass, field from loguru import logger @dataclass class TokenBucket: """令牌桶""" capacity: int # 桶容量 tokens: float = field(init=False) # 当前令牌数 refill_rate: float # 每秒填充速率 last_refill: float = field(default_factory=time.time) def __post_init__(self): self.tokens = float(self.capacity) def _refill(self) -> None: """填充令牌""" now = time.time() elapsed = now - self.last_refill self.tokens = min( self.capacity, self.tokens + elapsed * self.refill_rate ) self.last_refill = now def consume(self, tokens: int = 1) -> bool: """ 尝试消费令牌 Args: tokens: 要消费的令牌数 Returns: 是否消费成功 """ self._refill() if self.tokens >= tokens: self.tokens -= tokens return True return False def wait_time(self, tokens: int = 1) -> float: """ 计算需要等待的时间 Args: tokens: 需要的令牌数 Returns: 需要等待的秒数 """ self._refill() if self.tokens >= tokens: return 0.0 needed = tokens - self.tokens return needed / self.refill_rate class RateLimiter: """ 速率限制器 管理多个提供商的速率限制 """ def __init__(self): self._buckets: Dict[str, TokenBucket] = {} self._locks: Dict[str, asyncio.Lock] = {} def register( self, provider_id: str, requests_per_minute: int = 60, tokens_per_minute: int = 100000 ) -> None: """ 注册提供商的速率限制 Args: provider_id: 提供商ID requests_per_minute: 每分钟请求数 tokens_per_minute: 每分钟token数 """ # 请求限制桶 self._buckets[f"{provider_id}:requests"] = TokenBucket( capacity=requests_per_minute, refill_rate=requests_per_minute / 60.0 ) # Token限制桶 self._buckets[f"{provider_id}:tokens"] = TokenBucket( capacity=tokens_per_minute, refill_rate=tokens_per_minute / 60.0 ) # 创建锁 self._locks[provider_id] = asyncio.Lock() logger.debug( f"注册速率限制: {provider_id} - " f"{requests_per_minute}请求/分钟, " f"{tokens_per_minute}tokens/分钟" ) def unregister(self, provider_id: str) -> None: """ 取消注册提供商的速率限制 Args: provider_id: 提供商ID """ self._buckets.pop(f"{provider_id}:requests", None) self._buckets.pop(f"{provider_id}:tokens", None) self._locks.pop(provider_id, None) async def acquire( self, provider_id: str, estimated_tokens: int = 1 ) -> bool: """ 获取请求许可(非阻塞) Args: provider_id: 提供商ID estimated_tokens: 预估token数 Returns: 是否获取成功 """ request_bucket = self._buckets.get(f"{provider_id}:requests") token_bucket = self._buckets.get(f"{provider_id}:tokens") if not request_bucket or not token_bucket: # 未注册,默认允许 return True lock = self._locks.get(provider_id) if lock: async with lock: if request_bucket.consume(1) and token_bucket.consume(estimated_tokens): return True return False async def acquire_wait( self, provider_id: str, estimated_tokens: int = 1, max_wait: float = 60.0 ) -> bool: """ 获取请求许可(阻塞等待) Args: provider_id: 提供商ID estimated_tokens: 预估token数 max_wait: 最大等待时间(秒) Returns: 是否获取成功 """ request_bucket = self._buckets.get(f"{provider_id}:requests") token_bucket = self._buckets.get(f"{provider_id}:tokens") if not request_bucket or not token_bucket: return True lock = self._locks.get(provider_id) if not lock: return True start_time = time.time() while True: async with lock: # 计算需要等待的时间 request_wait = request_bucket.wait_time(1) token_wait = token_bucket.wait_time(estimated_tokens) wait_time = max(request_wait, token_wait) if wait_time == 0: request_bucket.consume(1) token_bucket.consume(estimated_tokens) return True # 检查是否超时 elapsed = time.time() - start_time if elapsed + wait_time > max_wait: logger.warning( f"速率限制等待超时: {provider_id}, " f"需要等待{wait_time:.2f}秒" ) return False # 在锁外等待 await asyncio.sleep(min(wait_time, 1.0)) def get_status(self, provider_id: str) -> Optional[Dict[str, any]]: """ 获取提供商的速率限制状态 Args: provider_id: 提供商ID Returns: 状态字典 """ request_bucket = self._buckets.get(f"{provider_id}:requests") token_bucket = self._buckets.get(f"{provider_id}:tokens") if not request_bucket or not token_bucket: return None request_bucket._refill() token_bucket._refill() return { "requests_remaining": int(request_bucket.tokens), "requests_capacity": request_bucket.capacity, "tokens_remaining": int(token_bucket.tokens), "tokens_capacity": token_bucket.capacity } # 全局速率限制器实例 rate_limiter = RateLimiter()