98 lines
2.3 KiB
Python
98 lines
2.3 KiB
Python
|
|
"""
|
|||
|
|
加密工具模块
|
|||
|
|
用于API密钥的加密和解密
|
|||
|
|
"""
|
|||
|
|
import base64
|
|||
|
|
from cryptography.fernet import Fernet
|
|||
|
|
from cryptography.hazmat.primitives import hashes
|
|||
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from config import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_fernet() -> Fernet:
|
|||
|
|
"""
|
|||
|
|
获取Fernet加密器实例
|
|||
|
|
使用配置的加密密钥派生加密密钥
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Fernet加密器
|
|||
|
|
"""
|
|||
|
|
# 使用PBKDF2从密钥派生32字节密钥
|
|||
|
|
salt = b"ai_chatroom_salt" # 固定salt,实际生产环境应使用随机salt
|
|||
|
|
kdf = PBKDF2HMAC(
|
|||
|
|
algorithm=hashes.SHA256(),
|
|||
|
|
length=32,
|
|||
|
|
salt=salt,
|
|||
|
|
iterations=100000,
|
|||
|
|
)
|
|||
|
|
key = base64.urlsafe_b64encode(
|
|||
|
|
kdf.derive(settings.ENCRYPTION_KEY.encode())
|
|||
|
|
)
|
|||
|
|
return Fernet(key)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def encrypt_api_key(api_key: str) -> str:
|
|||
|
|
"""
|
|||
|
|
加密API密钥
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_key: 原始API密钥
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
加密后的密钥字符串
|
|||
|
|
"""
|
|||
|
|
if not api_key:
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
fernet = _get_fernet()
|
|||
|
|
encrypted = fernet.encrypt(api_key.encode())
|
|||
|
|
return encrypted.decode()
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"API密钥加密失败: {e}")
|
|||
|
|
raise ValueError("加密失败")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def decrypt_api_key(encrypted_key: str) -> str:
|
|||
|
|
"""
|
|||
|
|
解密API密钥
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
encrypted_key: 加密的密钥字符串
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
解密后的原始API密钥
|
|||
|
|
"""
|
|||
|
|
if not encrypted_key:
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
fernet = _get_fernet()
|
|||
|
|
decrypted = fernet.decrypt(encrypted_key.encode())
|
|||
|
|
return decrypted.decode()
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"API密钥解密失败: {e}")
|
|||
|
|
raise ValueError("解密失败,密钥可能已损坏或被篡改")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
|
|||
|
|
"""
|
|||
|
|
掩码API密钥,用于安全显示
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_key: 原始API密钥
|
|||
|
|
visible_chars: 末尾可见字符数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
掩码后的密钥 (如: ****abc1)
|
|||
|
|
"""
|
|||
|
|
if not api_key:
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
if len(api_key) <= visible_chars:
|
|||
|
|
return "*" * len(api_key)
|
|||
|
|
|
|||
|
|
return "*" * (len(api_key) - visible_chars) + api_key[-visible_chars:]
|