feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
13
backend/utils/__init__.py
Normal file
13
backend/utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
工具函数模块
|
||||
"""
|
||||
from .encryption import encrypt_api_key, decrypt_api_key
|
||||
from .proxy_handler import get_http_client
|
||||
from .rate_limiter import RateLimiter
|
||||
|
||||
__all__ = [
|
||||
"encrypt_api_key",
|
||||
"decrypt_api_key",
|
||||
"get_http_client",
|
||||
"RateLimiter",
|
||||
]
|
||||
97
backend/utils/encryption.py
Normal file
97
backend/utils/encryption.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
加密工具模块
|
||||
用于API密钥的加密和解密
|
||||
"""
|
||||
import base64
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from loguru import logger
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""
|
||||
获取Fernet加密器实例
|
||||
使用配置的加密密钥派生加密密钥
|
||||
|
||||
Returns:
|
||||
Fernet加密器
|
||||
"""
|
||||
# 使用PBKDF2从密钥派生32字节密钥
|
||||
salt = b"ai_chatroom_salt" # 固定salt,实际生产环境应使用随机salt
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(
|
||||
kdf.derive(settings.ENCRYPTION_KEY.encode())
|
||||
)
|
||||
return Fernet(key)
|
||||
|
||||
|
||||
def encrypt_api_key(api_key: str) -> str:
|
||||
"""
|
||||
加密API密钥
|
||||
|
||||
Args:
|
||||
api_key: 原始API密钥
|
||||
|
||||
Returns:
|
||||
加密后的密钥字符串
|
||||
"""
|
||||
if not api_key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
encrypted = fernet.encrypt(api_key.encode())
|
||||
return encrypted.decode()
|
||||
except Exception as e:
|
||||
logger.error(f"API密钥加密失败: {e}")
|
||||
raise ValueError("加密失败")
|
||||
|
||||
|
||||
def decrypt_api_key(encrypted_key: str) -> str:
|
||||
"""
|
||||
解密API密钥
|
||||
|
||||
Args:
|
||||
encrypted_key: 加密的密钥字符串
|
||||
|
||||
Returns:
|
||||
解密后的原始API密钥
|
||||
"""
|
||||
if not encrypted_key:
|
||||
return ""
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
decrypted = fernet.decrypt(encrypted_key.encode())
|
||||
return decrypted.decode()
|
||||
except Exception as e:
|
||||
logger.error(f"API密钥解密失败: {e}")
|
||||
raise ValueError("解密失败,密钥可能已损坏或被篡改")
|
||||
|
||||
|
||||
def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
|
||||
"""
|
||||
掩码API密钥,用于安全显示
|
||||
|
||||
Args:
|
||||
api_key: 原始API密钥
|
||||
visible_chars: 末尾可见字符数
|
||||
|
||||
Returns:
|
||||
掩码后的密钥 (如: ****abc1)
|
||||
"""
|
||||
if not api_key:
|
||||
return ""
|
||||
|
||||
if len(api_key) <= visible_chars:
|
||||
return "*" * len(api_key)
|
||||
|
||||
return "*" * (len(api_key) - visible_chars) + api_key[-visible_chars:]
|
||||
135
backend/utils/proxy_handler.py
Normal file
135
backend/utils/proxy_handler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
代理处理模块
|
||||
处理HTTP代理配置
|
||||
"""
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
def get_proxy_dict(
|
||||
use_proxy: bool,
|
||||
proxy_config: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
获取代理配置字典
|
||||
|
||||
Args:
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
|
||||
Returns:
|
||||
代理配置字典或None
|
||||
"""
|
||||
if not use_proxy:
|
||||
return None
|
||||
|
||||
proxies = {}
|
||||
|
||||
if proxy_config:
|
||||
http_proxy = proxy_config.get("http_proxy")
|
||||
https_proxy = proxy_config.get("https_proxy")
|
||||
else:
|
||||
# 使用全局默认代理
|
||||
http_proxy = settings.DEFAULT_HTTP_PROXY
|
||||
https_proxy = settings.DEFAULT_HTTPS_PROXY
|
||||
|
||||
if http_proxy:
|
||||
proxies["http://"] = http_proxy
|
||||
if https_proxy:
|
||||
proxies["https://"] = https_proxy
|
||||
|
||||
return proxies if proxies else None
|
||||
|
||||
|
||||
def get_http_client(
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取配置好的HTTP异步客户端
|
||||
|
||||
Args:
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
timeout: 超时时间(秒)
|
||||
**kwargs: 其他httpx参数
|
||||
|
||||
Returns:
|
||||
配置好的httpx.AsyncClient实例
|
||||
"""
|
||||
proxies = get_proxy_dict(use_proxy, proxy_config)
|
||||
|
||||
client_kwargs = {
|
||||
"timeout": httpx.Timeout(timeout),
|
||||
"follow_redirects": True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
if proxies:
|
||||
client_kwargs["proxies"] = proxies
|
||||
logger.debug(f"HTTP客户端使用代理: {proxies}")
|
||||
|
||||
return httpx.AsyncClient(**client_kwargs)
|
||||
|
||||
|
||||
async def test_proxy_connection(
|
||||
proxy_config: Dict[str, Any],
|
||||
test_url: str = "https://www.google.com"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试代理连接是否可用
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置
|
||||
test_url: 测试URL
|
||||
|
||||
Returns:
|
||||
测试结果字典,包含 success, message, latency_ms
|
||||
"""
|
||||
try:
|
||||
async with get_http_client(
|
||||
use_proxy=True,
|
||||
proxy_config=proxy_config,
|
||||
timeout=10
|
||||
) as client:
|
||||
import time
|
||||
start = time.time()
|
||||
response = await client.get(test_url)
|
||||
latency = (time.time() - start) * 1000
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "代理连接正常",
|
||||
"latency_ms": round(latency, 2)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理返回状态码: {response.status_code}",
|
||||
"latency_ms": round(latency, 2)
|
||||
}
|
||||
|
||||
except httpx.ProxyError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理连接失败: {str(e)}",
|
||||
"latency_ms": None
|
||||
}
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "代理连接超时",
|
||||
"latency_ms": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"连接错误: {str(e)}",
|
||||
"latency_ms": None
|
||||
}
|
||||
233
backend/utils/rate_limiter.py
Normal file
233
backend/utils/rate_limiter.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
速率限制器模块
|
||||
使用令牌桶算法控制请求频率
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBucket:
|
||||
"""令牌桶"""
|
||||
capacity: int # 桶容量
|
||||
tokens: float = field(init=False) # 当前令牌数
|
||||
refill_rate: float # 每秒填充速率
|
||||
last_refill: float = field(default_factory=time.time)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokens = float(self.capacity)
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""填充令牌"""
|
||||
now = time.time()
|
||||
elapsed = now - self.last_refill
|
||||
self.tokens = min(
|
||||
self.capacity,
|
||||
self.tokens + elapsed * self.refill_rate
|
||||
)
|
||||
self.last_refill = now
|
||||
|
||||
def consume(self, tokens: int = 1) -> bool:
|
||||
"""
|
||||
尝试消费令牌
|
||||
|
||||
Args:
|
||||
tokens: 要消费的令牌数
|
||||
|
||||
Returns:
|
||||
是否消费成功
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
return False
|
||||
|
||||
def wait_time(self, tokens: int = 1) -> float:
|
||||
"""
|
||||
计算需要等待的时间
|
||||
|
||||
Args:
|
||||
tokens: 需要的令牌数
|
||||
|
||||
Returns:
|
||||
需要等待的秒数
|
||||
"""
|
||||
self._refill()
|
||||
if self.tokens >= tokens:
|
||||
return 0.0
|
||||
needed = tokens - self.tokens
|
||||
return needed / self.refill_rate
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
速率限制器
|
||||
管理多个提供商的速率限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._buckets: Dict[str, TokenBucket] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
provider_id: str,
|
||||
requests_per_minute: int = 60,
|
||||
tokens_per_minute: int = 100000
|
||||
) -> None:
|
||||
"""
|
||||
注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
requests_per_minute: 每分钟请求数
|
||||
tokens_per_minute: 每分钟token数
|
||||
"""
|
||||
# 请求限制桶
|
||||
self._buckets[f"{provider_id}:requests"] = TokenBucket(
|
||||
capacity=requests_per_minute,
|
||||
refill_rate=requests_per_minute / 60.0
|
||||
)
|
||||
|
||||
# Token限制桶
|
||||
self._buckets[f"{provider_id}:tokens"] = TokenBucket(
|
||||
capacity=tokens_per_minute,
|
||||
refill_rate=tokens_per_minute / 60.0
|
||||
)
|
||||
|
||||
# 创建锁
|
||||
self._locks[provider_id] = asyncio.Lock()
|
||||
|
||||
logger.debug(
|
||||
f"注册速率限制: {provider_id} - "
|
||||
f"{requests_per_minute}请求/分钟, "
|
||||
f"{tokens_per_minute}tokens/分钟"
|
||||
)
|
||||
|
||||
def unregister(self, provider_id: str) -> None:
|
||||
"""
|
||||
取消注册提供商的速率限制
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
"""
|
||||
self._buckets.pop(f"{provider_id}:requests", None)
|
||||
self._buckets.pop(f"{provider_id}:tokens", None)
|
||||
self._locks.pop(provider_id, None)
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(非阻塞)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
# 未注册,默认允许
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if lock:
|
||||
async with lock:
|
||||
if request_bucket.consume(1) and token_bucket.consume(estimated_tokens):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def acquire_wait(
|
||||
self,
|
||||
provider_id: str,
|
||||
estimated_tokens: int = 1,
|
||||
max_wait: float = 60.0
|
||||
) -> bool:
|
||||
"""
|
||||
获取请求许可(阻塞等待)
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
estimated_tokens: 预估token数
|
||||
max_wait: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return True
|
||||
|
||||
lock = self._locks.get(provider_id)
|
||||
if not lock:
|
||||
return True
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
async with lock:
|
||||
# 计算需要等待的时间
|
||||
request_wait = request_bucket.wait_time(1)
|
||||
token_wait = token_bucket.wait_time(estimated_tokens)
|
||||
wait_time = max(request_wait, token_wait)
|
||||
|
||||
if wait_time == 0:
|
||||
request_bucket.consume(1)
|
||||
token_bucket.consume(estimated_tokens)
|
||||
return True
|
||||
|
||||
# 检查是否超时
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed + wait_time > max_wait:
|
||||
logger.warning(
|
||||
f"速率限制等待超时: {provider_id}, "
|
||||
f"需要等待{wait_time:.2f}秒"
|
||||
)
|
||||
return False
|
||||
|
||||
# 在锁外等待
|
||||
await asyncio.sleep(min(wait_time, 1.0))
|
||||
|
||||
def get_status(self, provider_id: str) -> Optional[Dict[str, any]]:
|
||||
"""
|
||||
获取提供商的速率限制状态
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
request_bucket = self._buckets.get(f"{provider_id}:requests")
|
||||
token_bucket = self._buckets.get(f"{provider_id}:tokens")
|
||||
|
||||
if not request_bucket or not token_bucket:
|
||||
return None
|
||||
|
||||
request_bucket._refill()
|
||||
token_bucket._refill()
|
||||
|
||||
return {
|
||||
"requests_remaining": int(request_bucket.tokens),
|
||||
"requests_capacity": request_bucket.capacity,
|
||||
"tokens_remaining": int(token_bucket.tokens),
|
||||
"tokens_capacity": token_bucket.capacity
|
||||
}
|
||||
|
||||
|
||||
# 全局速率限制器实例
|
||||
rate_limiter = RateLimiter()
|
||||
Reference in New Issue
Block a user