feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
233
backend/utils/rate_limiter.py
Normal file
233
backend/utils/rate_limiter.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
速率限制器模块
|
||||
使用令牌桶算法控制请求频率
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBucket:
|
||||
"""令牌桶"""
|
||||
capacity: int # 桶容量
|
||||
tokens: float = field(init=False) # 当前令牌数
|
||||
refill_rate: float # 每秒填充速率
|
||||
last_refill: float = field(default_factory=time.time)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokens = float(self.capacity)
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""填充令牌"""
|
||||
now = time.time()
|
||||
elapsed = now - self.last_refill
|
||||
self.tokens = min(
|
||||
self.capacity,
|
||||
self.tokens + elapsed * self.refill_rate
|
||||
)
|
||||
self.last_refill = now
|
||||
|
||||
def consume(self, tokens: int = 1) -> bool:
|
||||
"""
|
||||
尝试消费令牌
|
||||
|
||||
Args:
|
||||
tokens: 要消费的令牌数
|
||||
|
||||
Returns:
|
||||
是否消费成功
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
return False
|
||||
|
||||
def wait_time(self, tokens: int = 1) -> float:
|
||||
"""
|
||||
计算需要等待的时间
|
||||
|
||||
Args:
|
||||
tokens: 需要的令牌数
|
||||
|
||||
Returns:
|
||||
需要等待的秒数
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
return 0.0
|
||||
needed = tokens - self.tokens
|
||||
return needed / self.refill_rate
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
速率限制器
|
||||
管理多个提供商的速率限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._buckets: Dict[str, TokenBucket] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
provider_id: str,
|
||||
requests_per_minute: int = 60,
|
||||
tokens_per_minute: int = 100000
|
||||
) -> None:
|
||||
"""
|
||||
注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
requests_per_minute: 每分钟请求数
|
||||
tokens_per_minute: 每分钟token数
|
||||
"""
|
||||
# 请求限制桶
|
||||
self._buckets[f"{provider_id}:requests"] = TokenBucket(
|
||||
capacity=requests_per_minute,
|
||||
refill_rate=requests_per_minute / 60.0
|
||||
)
|
||||
|
||||
# Token限制桶
|
||||
self._buckets[f"{provider_id}:tokens"] = TokenBucket(
|
||||
capacity=tokens_per_minute,
|
||||
refill_rate=tokens_per_minute / 60.0
|
||||
)
|
||||
|
||||
# 创建锁
|
||||
self._locks[provider_id] = asyncio.Lock()
|
||||
|
||||
logger.debug(
|
||||
f"注册速率限制: {provider_id} - "
|
||||
f"{requests_per_minute}请求/分钟, "
|
||||
f"{tokens_per_minute}tokens/分钟"
|
||||
)
|
||||
|
||||
def unregister(self, provider_id: str) -> None:
|
||||
"""
|
||||
取消注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
"""
|
||||
self._buckets.pop(f"{provider_id}:requests", None)
|
||||
self._buckets.pop(f"{provider_id}:tokens", None)
|
||||
self._locks.pop(provider_id, None)
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(非阻塞)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
# 未注册,默认允许
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if lock:
|
||||
async with lock:
|
||||
if request_bucket.consume(1) and token_bucket.consume(estimated_tokens):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def acquire_wait(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1,
|
||||
max_wait: float = 60.0
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(阻塞等待)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
max_wait: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if not lock:
|
||||
return True
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
async with lock:
|
||||
# 计算需要等待的时间
|
||||
request_wait = request_bucket.wait_time(1)
|
||||
token_wait = token_bucket.wait_time(estimated_tokens)
|
||||
wait_time = max(request_wait, token_wait)
|
||||
|
||||
if wait_time == 0:
|
||||
request_bucket.consume(1)
|
||||
token_bucket.consume(estimated_tokens)
|
||||
return True
|
||||
|
||||
# 检查是否超时
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed + wait_time > max_wait:
|
||||
logger.warning(
|
||||
f"速率限制等待超时: {provider_id}, "
|
||||
f"需要等待{wait_time:.2f}秒"
|
||||
)
|
||||
return False
|
||||
|
||||
# 在锁外等待
|
||||
await asyncio.sleep(min(wait_time, 1.0))
|
||||
|
||||
def get_status(self, provider_id: str) -> Optional[Dict[str, any]]:
|
||||
"""
|
||||
获取提供商的速率限制状态
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return None
|
||||
|
||||
request_bucket._refill()
|
||||
token_bucket._refill()
|
||||
|
||||
return {
|
||||
"requests_remaining": int(request_bucket.tokens),
|
||||
"requests_capacity": request_bucket.capacity,
|
||||
"tokens_remaining": int(token_bucket.tokens),
|
||||
"tokens_capacity": token_bucket.capacity
|
||||
}
|
||||
|
||||
|
||||
# 全局速率限制器实例
|
||||
rate_limiter = RateLimiter()
|
||||
Reference in New Issue
Block a user