Files
AIChatRoom/backend/utils/rate_limiter.py
Claude Code edbddf855d feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词
- 支持多个AI提供商(OpenRouter、智谱、MiniMax等)
- 实现聊天室和讨论引擎
- WebSocket实时消息推送
- 前端使用React + Ant Design
- 后端使用FastAPI + MongoDB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 19:20:02 +08:00

234 lines
6.5 KiB
Python

"""
速率限制器模块
使用令牌桶算法控制请求频率
"""
import asyncio
import time
from typing import Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class TokenBucket:
    """Token bucket used to throttle a stream of requests.

    The bucket starts full. It is topped up lazily, at ``refill_rate`` tokens
    per second, every time it is inspected, and it never holds more than
    ``capacity`` tokens.
    """

    capacity: int  # maximum number of tokens the bucket can hold
    tokens: float = field(init=False)  # tokens currently available
    refill_rate: float  # tokens added per second
    last_refill: float = field(default_factory=time.time)  # time.time() of last refill

    def __post_init__(self):
        """Start with a full bucket."""
        self.tokens = float(self.capacity)

    def _refill(self) -> None:
        """Credit the tokens earned since the previous refill.

        The balance is capped at ``capacity`` and ``last_refill`` is moved to
        the current time.
        """
        now = time.time()
        # time.time() is the wall clock and may step backwards (NTP sync,
        # manual adjustment). A negative interval must not drain the bucket,
        # so it is treated as "no time has passed".
        elapsed = max(0.0, now - self.last_refill)
        self.tokens = min(
            float(self.capacity),
            self.tokens + elapsed * self.refill_rate,
        )
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` tokens out of the bucket.

        Args:
            tokens: Number of tokens to take.

        Returns:
            True if the tokens were available and have been deducted,
            False if the balance was too low (nothing is deducted).
        """
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def wait_time(self, tokens: int = 1) -> float:
        """Compute how long to wait until ``tokens`` tokens are available.

        Args:
            tokens: Number of tokens required.

        Returns:
            Seconds to wait; ``0.0`` if the tokens are available right now,
            ``float("inf")`` if they can never become available (the request
            exceeds ``capacity`` or the bucket does not refill).
        """
        self._refill()
        if self.tokens >= tokens:
            return 0.0
        # The balance is capped at capacity, and a non-positive rate never
        # adds tokens: in both cases no finite wait is enough. This also
        # avoids dividing by a zero refill rate.
        if tokens > self.capacity or self.refill_rate <= 0:
            return float("inf")
        return (tokens - self.tokens) / self.refill_rate
class RateLimiter:
    """Rate limiter that tracks limits for several providers.

    Every registered provider gets two token buckets (one counting requests,
    one counting model tokens) and an ``asyncio.Lock`` that guards them.
    Providers that were never registered are not limited.
    """

    def __init__(self):
        """Create a limiter with no registered providers."""
        # Buckets are keyed "<provider_id>:requests" and "<provider_id>:tokens".
        self._buckets: Dict[str, TokenBucket] = {}
        # One lock per provider, keyed by the bare provider id.
        self._locks: Dict[str, asyncio.Lock] = {}

    def register(
        self,
        provider_id: str,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 100000
    ) -> None:
        """Register (or replace) the rate limits of a provider.

        Args:
            provider_id: Provider ID.
            requests_per_minute: Requests allowed per minute.
            tokens_per_minute: Tokens allowed per minute.
        """
        # Request bucket: capacity is one minute's worth, refilled per second.
        self._buckets[f"{provider_id}:requests"] = TokenBucket(
            capacity=requests_per_minute,
            refill_rate=requests_per_minute / 60.0
        )
        # Token bucket, sized the same way.
        self._buckets[f"{provider_id}:tokens"] = TokenBucket(
            capacity=tokens_per_minute,
            refill_rate=tokens_per_minute / 60.0
        )
        # Lock guarding both buckets of this provider.
        self._locks[provider_id] = asyncio.Lock()
        logger.debug(
            f"注册速率限制: {provider_id} - "
            f"{requests_per_minute}请求/分钟, "
            f"{tokens_per_minute}tokens/分钟"
        )

    def unregister(self, provider_id: str) -> None:
        """Remove the rate limits of a provider; unknown ids are ignored.

        Args:
            provider_id: Provider ID.
        """
        self._buckets.pop(f"{provider_id}:requests", None)
        self._buckets.pop(f"{provider_id}:tokens", None)
        self._locks.pop(provider_id, None)

    async def acquire(
        self,
        provider_id: str,
        estimated_tokens: int = 1
    ) -> bool:
        """Try to obtain permission for one request without waiting.

        Args:
            provider_id: Provider ID.
            estimated_tokens: Estimated number of tokens the request uses.

        Returns:
            True if the request may proceed (always True for unregistered
            providers), False if either limit is currently exhausted. A
            refused call leaves both buckets untouched.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")
        if request_bucket is None or token_bucket is None:
            # Not registered: allow by default.
            return True

        lock = self._locks.get(provider_id)
        if lock is None:
            return False

        async with lock:
            if not request_bucket.consume(1):
                return False
            if token_bucket.consume(estimated_tokens):
                return True
            # The token budget refused the request after a request slot was
            # already taken. Give the slot back so that a denied call does
            # not eat into the request quota.
            request_bucket.tokens = min(
                request_bucket.capacity,
                request_bucket.tokens + 1
            )
            return False

    async def acquire_wait(
        self,
        provider_id: str,
        estimated_tokens: int = 1,
        max_wait: float = 60.0
    ) -> bool:
        """Obtain permission for one request, waiting for capacity if needed.

        Args:
            provider_id: Provider ID.
            estimated_tokens: Estimated number of tokens the request uses.
            max_wait: Maximum time to wait, in seconds.

        Returns:
            True once permission is granted (immediately for unregistered
            providers), False if it cannot be granted within ``max_wait``.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")
        if request_bucket is None or token_bucket is None:
            return True

        lock = self._locks.get(provider_id)
        if lock is None:
            return True

        # Durations are measured on the monotonic clock so that wall-clock
        # adjustments cannot shorten or stretch the timeout.
        started = time.monotonic()
        while True:
            async with lock:
                # Time until both buckets can serve the request.
                request_wait = request_bucket.wait_time(1)
                token_wait = token_bucket.wait_time(estimated_tokens)
                wait_time = max(request_wait, token_wait)
                if wait_time == 0:
                    request_bucket.consume(1)
                    token_bucket.consume(estimated_tokens)
                    return True

                # Give up if the remaining wait would overrun the budget.
                elapsed = time.monotonic() - started
                if elapsed + wait_time > max_wait:
                    logger.warning(
                        f"速率限制等待超时: {provider_id}, "
                        f"需要等待{wait_time:.2f}"
                    )
                    return False

            # Sleep outside the lock, at most one second at a time, then
            # re-evaluate the buckets.
            await asyncio.sleep(min(wait_time, 1.0))

    def get_status(self, provider_id: str) -> Optional[Dict[str, int]]:
        """Report the current rate-limit state of a provider.

        Args:
            provider_id: Provider ID.

        Returns:
            A dict with the remaining and total capacity of both buckets, or
            None if the provider is not registered.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")
        if request_bucket is None or token_bucket is None:
            return None

        # Bring the balances up to date before reading them.
        request_bucket._refill()
        token_bucket._refill()
        return {
            "requests_remaining": int(request_bucket.tokens),
            "requests_capacity": request_bucket.capacity,
            "tokens_remaining": int(token_bucket.tokens),
            "tokens_capacity": token_bucket.capacity
        }
# Global rate limiter instance
rate_limiter = RateLimiter()