237 lines
6.0 KiB
Python
237 lines
6.0 KiB
Python
|
|
"""User authentication module.

Provides token-based authentication and session management.
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import hashlib
|
|||
|
|
import hmac
|
|||
|
|
import secrets
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from minenasai.core import get_logger, get_settings
|
|||
|
|
|
|||
|
|
# Module-level logger. Calls in this module pass structured keyword fields
# (e.g. ``logger.info("...", user_id=...)``), so get_logger presumably returns
# a structlog-style logger rather than a stdlib logging.Logger — TODO confirm.
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AuthToken:
    """An issued access token together with its lifetime bookkeeping.

    Expiry is evaluated against ``time.time()`` each time one of the
    properties below is read.
    """

    # Opaque bearer token string handed to the client. Excluded from repr so
    # that logging or printing an AuthToken does not leak the credential.
    token: str = field(repr=False)
    # Identifier of the user the token was issued to.
    user_id: str
    # Issue time in seconds since the epoch.
    created_at: float
    # Expiry time in seconds since the epoch.
    expires_at: float
    # Arbitrary extra data attached by the issuer.
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """Return True once the current time is strictly past ``expires_at``."""
        return time.time() > self.expires_at

    @property
    def remaining_time(self) -> float:
        """Return seconds left before expiry, clamped at ``0.0`` once expired."""
        # Use a float floor so the result matches the ``float`` annotation even
        # for expired tokens (``max(0, ...)`` would return the int ``0``).
        return max(0.0, self.expires_at - time.time())
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AuthManager:
    """In-memory issuer and registry of access tokens.

    ``_tokens`` maps each token string to its AuthToken, and
    ``_user_sessions`` maps each user_id to the set of token strings currently
    issued to that user. Both are plain dicts held by this instance; nothing
    is persisted.
    """

    def __init__(self, secret_key: str | None = None) -> None:
        """Create a manager.

        Args:
            secret_key: Key used to HMAC-sign newly issued tokens. When falsy,
                ``settings.webtui_secret_key`` is used instead.
        """
        self.settings = get_settings()
        # NOTE(review): the key is read from ``settings.webtui_secret_key`` while
        # generate_token reads ``settings.webtui.session_timeout`` — confirm both
        # attribute paths exist on the settings object. If the resolved key is
        # None, _sign will fail with AttributeError on ``.encode()``.
        self.secret_key = secret_key or self.settings.webtui_secret_key
        self._tokens: dict[str, AuthToken] = {}
        self._user_sessions: dict[str, set[str]] = {}  # user_id -> set of tokens

    def generate_token(
        self,
        user_id: str,
        expires_in: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Issue and register a new access token for ``user_id``.

        The token has the form ``<random>.<unix-timestamp>.<sig16>``, where
        ``sig16`` is the first 16 hex characters of an HMAC-SHA256 over
        ``random:timestamp:user_id``. verify_token decides validity purely by
        registry lookup; the signature is not re-checked there.

        Args:
            user_id: ID of the user the token is issued to.
            expires_in: Lifetime in seconds. Defaults to
                ``settings.webtui.session_timeout``.
            metadata: Extra data to attach. It is shallow-copied, so later
                changes to the caller's dict do not alter the stored token.

        Returns:
            The new token string.
        """
        random_part = secrets.token_urlsafe(24)
        timestamp = str(int(time.time()))
        signature = self._sign(f"{random_part}:{timestamp}:{user_id}")
        token = f"{random_part}.{timestamp}.{signature[:16]}"

        if expires_in is None:
            expires_in = self.settings.webtui.session_timeout

        now = time.time()
        self._tokens[token] = AuthToken(
            token=token,
            user_id=user_id,
            created_at=now,
            expires_at=now + expires_in,
            # Copy rather than alias the caller's dict (or, via refresh_token,
            # the previous token's dict).
            metadata=dict(metadata) if metadata else {},
        )
        self._user_sessions.setdefault(user_id, set()).add(token)

        logger.info("生成令牌", user_id=user_id, expires_in=expires_in)
        return token

    def verify_token(self, token: str) -> AuthToken | None:
        """Return the registered, unexpired AuthToken for ``token``.

        An expired token is revoked as a side effect.

        Args:
            token: Access token to check.

        Returns:
            The AuthToken, or None if the token is unknown or expired.
        """
        auth_token = self._tokens.get(token)

        if auth_token is None:
            logger.debug("令牌不存在", token=self._mask(token))
            return None

        if auth_token.is_expired:
            logger.debug("令牌已过期", token=self._mask(token))
            self.revoke_token(token)
            return None

        return auth_token

    def revoke_token(self, token: str) -> bool:
        """Revoke a single token.

        Args:
            token: Access token to revoke.

        Returns:
            True if the token was registered and has been removed, else False.
        """
        auth_token = self._tokens.pop(token, None)
        if auth_token is None:
            return False

        self._discard_session(auth_token.user_id, token)
        logger.info("撤销令牌", user_id=auth_token.user_id)
        return True

    def revoke_user_tokens(self, user_id: str) -> int:
        """Revoke every token issued to ``user_id``.

        Args:
            user_id: User whose tokens are revoked.

        Returns:
            Number of tokens revoked (0 if the user had none).
        """
        tokens = self._user_sessions.pop(user_id, set())
        for token in tokens:
            self._tokens.pop(token, None)

        if tokens:
            logger.info("撤销用户所有令牌", user_id=user_id, count=len(tokens))

        return len(tokens)

    def refresh_token(self, token: str, extends_by: int | None = None) -> str | None:
        """Exchange a valid token for a freshly issued one.

        Args:
            token: Current token. It is revoked after the new one is issued.
            extends_by: Lifetime in seconds of the NEW token, measured from now
                (it is not added to the old token's remaining time). None uses
                the default session timeout.

        Returns:
            The new token, or None if ``token`` is unknown or expired.
        """
        auth_token = self.verify_token(token)
        if auth_token is None:
            return None

        new_token = self.generate_token(
            user_id=auth_token.user_id,
            expires_in=extends_by,
            metadata=auth_token.metadata,
        )

        self.revoke_token(token)
        return new_token

    def cleanup_expired(self) -> int:
        """Revoke every token whose expiry time has passed.

        Returns:
            Number of tokens removed.
        """
        now = time.time()
        # Snapshot first: revoke_token mutates self._tokens.
        expired = [
            token
            for token, auth_token in self._tokens.items()
            if auth_token.expires_at < now
        ]

        for token in expired:
            self.revoke_token(token)

        if expired:
            logger.info("清理过期令牌", count=len(expired))

        return len(expired)

    def _sign(self, data: str) -> str:
        """Return the hex HMAC-SHA256 of ``data`` keyed with ``secret_key``."""
        return hmac.new(
            self.secret_key.encode(),
            data.encode(),
            hashlib.sha256,
        ).hexdigest()

    def _discard_session(self, user_id: str, token: str) -> None:
        """Remove ``token`` from ``user_id``'s session set, dropping the set when empty."""
        sessions = self._user_sessions.get(user_id)
        if sessions is None:
            return
        sessions.discard(token)
        if not sessions:
            # Drop empty entries so user keys don't accumulate forever and
            # get_stats() counts only users that still hold tokens.
            del self._user_sessions[user_id]

    @staticmethod
    def _mask(token: str) -> str:
        """Return a short prefix of ``token`` that is safe to write to logs."""
        # The 32-character random part is the secret. Logging 20 of those
        # characters would disclose most of its entropy, so keep only 8.
        return token[:8] + "..."

    def get_stats(self) -> dict[str, Any]:
        """Return token and user counts for monitoring."""
        return {
            "total_tokens": len(self._tokens),
            "total_users": len(self._user_sessions),
            "tokens_per_user": {
                user_id: len(tokens)
                for user_id, tokens in self._user_sessions.items()
            },
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Process-wide AuthManager instance; created lazily by get_auth_manager().
_auth_manager: AuthManager | None = None


def get_auth_manager() -> AuthManager:
    """Return the shared AuthManager, constructing it on first use."""
    global _auth_manager
    if _auth_manager is not None:
        return _auth_manager
    _auth_manager = AuthManager()
    return _auth_manager
|