- 实现Agent管理,支持AI辅助生成系统提示词
- 支持多个AI提供商(OpenRouter、智谱、MiniMax等)
- 实现聊天室和讨论引擎
- WebSocket实时消息推送
- 前端使用React + Ant Design
- 后端使用FastAPI + MongoDB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
234 lines
6.5 KiB
Python
234 lines
6.5 KiB
Python
"""
Rate limiter module.

Uses the token bucket algorithm to control request frequency.
"""
|
|
import asyncio
|
|
import time
|
|
from typing import Dict, Optional
|
|
from dataclasses import dataclass, field
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
class TokenBucket:
    """Token bucket that refills continuously over time.

    The bucket starts full (``tokens == capacity``). Before every read or
    consume it is topped up by ``elapsed_seconds * refill_rate`` and capped
    at ``capacity``.

    Constructor arguments are ``capacity``, ``refill_rate`` and (optionally)
    ``last_refill``; ``tokens`` is not an init parameter.
    """

    capacity: int  # maximum number of tokens the bucket can hold
    tokens: float = field(init=False)  # current token count, set in __post_init__
    refill_rate: float  # tokens added per second
    # time.time() timestamp of the most recent refill
    last_refill: float = field(default_factory=time.time)

    def __post_init__(self):
        """Start with a full bucket."""
        self.tokens = float(self.capacity)

    def _refill(self) -> None:
        """Add the tokens earned since ``last_refill`` and reset the timestamp."""
        now = time.time()
        # time.time() is wall-clock time and may step backwards (e.g. a clock
        # adjustment). Clamp so a negative interval never removes tokens.
        elapsed = max(0.0, now - self.last_refill)
        self.tokens = min(
            float(self.capacity),
            self.tokens + elapsed * self.refill_rate,
        )
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` tokens from the bucket.

        Args:
            tokens: Number of tokens to consume.

        Returns:
            True if the bucket held enough tokens and they were deducted;
            False otherwise, in which case nothing is deducted.
        """
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def wait_time(self, tokens: int = 1) -> float:
        """Compute how long until ``tokens`` tokens are available.

        Args:
            tokens: Number of tokens required.

        Returns:
            Seconds to wait: ``0.0`` if the tokens are available now, and
            ``float("inf")`` if they can never become available (the request
            exceeds ``capacity``, or ``refill_rate`` is not positive).
        """
        self._refill()
        if self.tokens >= tokens:
            return 0.0
        # The bucket never holds more than `capacity` and never grows without
        # a positive refill rate; report "never" rather than dividing by zero
        # or returning a finite wait that cannot succeed.
        if tokens > self.capacity or self.refill_rate <= 0:
            return float("inf")
        return (tokens - self.tokens) / self.refill_rate
|
|
|
|
|
|
class RateLimiter:
    """Rate limiter that manages limits for multiple providers.

    Each registered provider gets two token buckets, stored under the keys
    ``"<provider_id>:requests"`` and ``"<provider_id>:tokens"``, plus one
    ``asyncio.Lock`` stored under ``provider_id``. Providers that were never
    registered are not limited.
    """

    def __init__(self) -> None:
        """Create a limiter with no providers registered."""
        self._buckets: Dict[str, TokenBucket] = {}
        self._locks: Dict[str, asyncio.Lock] = {}

    def register(
        self,
        provider_id: str,
        requests_per_minute: int = 60,
        tokens_per_minute: int = 100000
    ) -> None:
        """Register (or replace) the rate limits of a provider.

        Args:
            provider_id: Provider ID.
            requests_per_minute: Allowed requests per minute.
            tokens_per_minute: Allowed tokens per minute.
        """
        # Request bucket: holds one minute of requests, refilled per second.
        self._buckets[f"{provider_id}:requests"] = TokenBucket(
            capacity=requests_per_minute,
            refill_rate=requests_per_minute / 60.0
        )

        # Token bucket: holds one minute of tokens, refilled per second.
        self._buckets[f"{provider_id}:tokens"] = TokenBucket(
            capacity=tokens_per_minute,
            refill_rate=tokens_per_minute / 60.0
        )

        # One lock per provider guards both of its buckets.
        self._locks[provider_id] = asyncio.Lock()

        logger.debug(
            f"注册速率限制: {provider_id} - "
            f"{requests_per_minute}请求/分钟, "
            f"{tokens_per_minute}tokens/分钟"
        )

    def unregister(self, provider_id: str) -> None:
        """Remove the rate limits of a provider; unknown IDs are ignored.

        Args:
            provider_id: Provider ID.
        """
        self._buckets.pop(f"{provider_id}:requests", None)
        self._buckets.pop(f"{provider_id}:tokens", None)
        self._locks.pop(provider_id, None)

    async def acquire(
        self,
        provider_id: str,
        estimated_tokens: int = 1
    ) -> bool:
        """Try to obtain permission for one request without waiting.

        Args:
            provider_id: Provider ID.
            estimated_tokens: Estimated number of tokens the request uses.

        Returns:
            True if one request and ``estimated_tokens`` tokens were both
            deducted, or if the provider is not registered. False otherwise;
            on failure neither bucket is left charged.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")

        if request_bucket is None or token_bucket is None:
            # Not registered: allowed by default.
            return True

        lock = self._locks.get(provider_id)
        if lock is None:
            # register() always stores the lock together with the buckets, so
            # this is only a defensive fallback.
            return False

        async with lock:
            if not request_bucket.consume(1):
                return False
            if token_bucket.consume(estimated_tokens):
                return True
            # The token bucket refused after the request bucket was already
            # charged: hand the request back so a rejected call costs nothing.
            request_bucket.tokens = min(
                float(request_bucket.capacity),
                request_bucket.tokens + 1,
            )
            return False

    async def acquire_wait(
        self,
        provider_id: str,
        estimated_tokens: int = 1,
        max_wait: float = 60.0
    ) -> bool:
        """Obtain permission for one request, waiting until it is available.

        Args:
            provider_id: Provider ID.
            estimated_tokens: Estimated number of tokens the request uses.
            max_wait: Maximum total time to wait, in seconds.

        Returns:
            True once one request and ``estimated_tokens`` tokens were
            deducted, or if the provider is not registered. False if the
            required wait would exceed ``max_wait``.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")

        if request_bucket is None or token_bucket is None:
            return True

        lock = self._locks.get(provider_id)
        if lock is None:
            return True

        # Monotonic clock: the timeout must not be affected by wall-clock
        # adjustments.
        start_time = time.monotonic()

        while True:
            async with lock:
                # How long until both buckets can serve the request.
                request_wait = request_bucket.wait_time(1)
                token_wait = token_bucket.wait_time(estimated_tokens)
                wait_time = max(request_wait, token_wait)

                if estimated_tokens > token_bucket.capacity:
                    # The bucket never holds more than its capacity, so this
                    # request can never succeed; fail now instead of sleeping
                    # until max_wait runs out.
                    wait_time = float("inf")

                if wait_time == 0:
                    request_bucket.consume(1)
                    token_bucket.consume(estimated_tokens)
                    return True

                # Give up if the remaining wait would exceed the budget.
                elapsed = time.monotonic() - start_time
                if elapsed + wait_time > max_wait:
                    logger.warning(
                        f"速率限制等待超时: {provider_id}, "
                        f"需要等待{wait_time:.2f}秒"
                    )
                    return False

            # Sleep outside the lock, at most one second at a time, then
            # re-check.
            await asyncio.sleep(min(wait_time, 1.0))

    def get_status(self, provider_id: str) -> Optional[Dict[str, int]]:
        """Return the current rate-limit status of a provider.

        Args:
            provider_id: Provider ID.

        Returns:
            A dict with the keys ``requests_remaining``,
            ``requests_capacity``, ``tokens_remaining`` and
            ``tokens_capacity``, or None if the provider is not registered.
        """
        request_bucket = self._buckets.get(f"{provider_id}:requests")
        token_bucket = self._buckets.get(f"{provider_id}:tokens")

        if request_bucket is None or token_bucket is None:
            return None

        # Bring both buckets up to date before reading their counts.
        request_bucket._refill()
        token_bucket._refill()

        return {
            "requests_remaining": int(request_bucket.tokens),
            "requests_capacity": request_bucket.capacity,
            "tokens_remaining": int(token_bucket.tokens),
            "tokens_capacity": token_bucket.capacity
        }
|
|
|
|
|
|
# Global rate limiter instance (created at import time with no providers registered).
rate_limiter: RateLimiter = RateLimiter()
|