Files
MineNasAI/src/minenasai/webtui/auth.py

237 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""User authentication module.

Provides token authentication and session management.
"""
from __future__ import annotations
import hashlib
import hmac
import secrets
import time
from dataclasses import dataclass, field
from typing import Any
from minenasai.core import get_logger, get_settings
# Module-level logger. Calls in this module pass structured keyword fields
# (e.g. ``user_id=...``), so presumably get_logger returns a structlog-style
# logger rather than a stdlib logging.Logger — confirm in minenasai.core.
logger = get_logger(__name__)
@dataclass
class AuthToken:
    """An issued access token together with its bookkeeping data.

    ``created_at`` and ``expires_at`` are compared against ``time.time()``,
    i.e. they are Unix timestamps in seconds.
    """

    token: str  # opaque token string handed to the client
    user_id: str  # owner of the token
    created_at: float  # issue time
    expires_at: float  # the token counts as expired strictly after this instant
    metadata: dict[str, Any] = field(default_factory=dict)  # caller-supplied extras

    @property
    def is_expired(self) -> bool:
        """Whether the current time is strictly past ``expires_at``."""
        return time.time() > self.expires_at

    @property
    def remaining_time(self) -> float:
        """Seconds of validity left, clamped to ``0.0`` once expired."""
        # 0.0 (not 0) so the result is always a float, matching the annotation.
        return max(0.0, self.expires_at - time.time())
class AuthManager:
    """Authentication manager.

    Issues access tokens and keeps them in an in-process store
    (``self._tokens``), plus a per-user index (``self._user_sessions``) so
    that all tokens of one user can be revoked at once. Both stores are plain
    dicts with no locking.

    A token is accepted by :meth:`verify_token` only if it is present in the
    store and not expired. The HMAC fragment embedded in the token string is
    never re-checked.
    """

    def __init__(self, secret_key: str | None = None) -> None:
        """Initialize the authentication manager.

        Args:
            secret_key: HMAC signing key. When falsy, falls back to
                ``settings.webtui_secret_key``.
        """
        self.settings = get_settings()
        # NOTE(review): the key is read from a flat ``webtui_secret_key``
        # attribute while generate_token reads ``settings.webtui.session_timeout``
        # from a nested section — confirm both attributes exist on the settings.
        self.secret_key = secret_key or self.settings.webtui_secret_key
        # token string -> AuthToken
        self._tokens: dict[str, AuthToken] = {}
        # user_id -> set of that user's live token strings (never left empty)
        self._user_sessions: dict[str, set[str]] = {}

    def generate_token(
        self,
        user_id: str,
        expires_in: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Generate and store a new access token for ``user_id``.

        The token has the form ``<random>.<unix-timestamp>.<hmac-prefix>``.
        The random part comes from :func:`secrets.token_urlsafe`.

        Args:
            user_id: User ID the token belongs to.
            expires_in: Lifetime in seconds. Defaults to
                ``settings.webtui.session_timeout``.
            metadata: Extra metadata stored on the token. It is copied, so
                later mutation of the caller's dict does not affect the token.

        Returns:
            The access token string.
        """
        # Read the clock once so the embedded timestamp and created_at agree.
        now = time.time()
        random_part = secrets.token_urlsafe(24)
        timestamp = str(int(now))
        signature = self._sign(f"{random_part}:{timestamp}:{user_id}")
        token = f"{random_part}.{timestamp}.{signature[:16]}"

        if expires_in is None:
            expires_in = self.settings.webtui.session_timeout

        self._tokens[token] = AuthToken(
            token=token,
            user_id=user_id,
            created_at=now,
            expires_at=now + expires_in,
            # Copy so tokens never share a mutable metadata dict with the
            # caller (or, via refresh_token, with their predecessor).
            metadata=dict(metadata) if metadata else {},
        )
        self._user_sessions.setdefault(user_id, set()).add(token)

        logger.info("生成令牌", user_id=user_id, expires_in=expires_in)
        return token

    def verify_token(self, token: str) -> AuthToken | None:
        """Look up a token and check that it has not expired.

        An expired token is revoked as a side effect.

        Args:
            token: Access token.

        Returns:
            The stored AuthToken, or None if it is unknown or expired.
        """
        auth_token = self._tokens.get(token)
        if auth_token is None:
            logger.debug("令牌不存在", token=self._token_fingerprint(token))
            return None

        if auth_token.is_expired:
            logger.debug("令牌已过期", token=self._token_fingerprint(token))
            self.revoke_token(token)
            return None

        return auth_token

    def revoke_token(self, token: str) -> bool:
        """Revoke a single token.

        Args:
            token: Access token.

        Returns:
            True if the token was present and has been removed, else False.
        """
        auth_token = self._tokens.pop(token, None)
        if auth_token is None:
            return False

        self._untrack_session(auth_token.user_id, token)
        logger.info("撤销令牌", user_id=auth_token.user_id)
        return True

    def revoke_user_tokens(self, user_id: str) -> int:
        """Revoke every token belonging to ``user_id``.

        Args:
            user_id: User ID.

        Returns:
            Number of tokens revoked.
        """
        tokens = self._user_sessions.pop(user_id, set())
        for token in tokens:
            self._tokens.pop(token, None)

        if tokens:
            logger.info("撤销用户所有令牌", user_id=user_id, count=len(tokens))

        return len(tokens)

    def refresh_token(self, token: str, extends_by: int | None = None) -> str | None:
        """Exchange a valid token for a new one and revoke the old one.

        Args:
            token: Current token.
            extends_by: Lifetime in seconds of the NEW token, measured from
                now (not added to the old expiry). Defaults to the configured
                session timeout.

        Returns:
            The new token, or None if ``token`` is unknown or expired.
        """
        auth_token = self.verify_token(token)
        if auth_token is None:
            return None

        new_token = self.generate_token(
            user_id=auth_token.user_id,
            expires_in=extends_by,
            metadata=auth_token.metadata,
        )
        self.revoke_token(token)
        return new_token

    def cleanup_expired(self) -> int:
        """Remove every token whose expiry lies before the current time.

        Returns:
            Number of tokens removed.
        """
        now = time.time()
        # Snapshot the keys first: revoke_token mutates self._tokens.
        expired = [
            token
            for token, auth_token in self._tokens.items()
            if auth_token.expires_at < now
        ]
        for token in expired:
            self.revoke_token(token)

        if expired:
            logger.info("清理过期令牌", count=len(expired))

        return len(expired)

    def _untrack_session(self, user_id: str, token: str) -> None:
        """Remove ``token`` from ``user_id``'s session index.

        Drops the user's entry once it becomes empty. Without this,
        ``_user_sessions`` grows without bound and ``get_stats`` counts users
        who have no live tokens.
        """
        sessions = self._user_sessions.get(user_id)
        if sessions is None:
            return
        sessions.discard(token)
        if not sessions:
            del self._user_sessions[user_id]

    @staticmethod
    def _token_fingerprint(token: str) -> str:
        """Return a short, non-reversible identifier of ``token`` for logs.

        Logging a raw token prefix leaks part of the credential. A truncated
        SHA-256 digest still lets log lines be correlated without that leak.
        """
        return hashlib.sha256(token.encode()).hexdigest()[:12]

    def _sign(self, data: str) -> str:
        """Return the hex HMAC-SHA256 of ``data`` under ``self.secret_key``."""
        return hmac.new(
            self.secret_key.encode(),
            data.encode(),
            hashlib.sha256,
        ).hexdigest()

    def get_stats(self) -> dict[str, Any]:
        """Return token counts overall and per user."""
        return {
            "total_tokens": len(self._tokens),
            "total_users": len(self._user_sessions),
            "tokens_per_user": {
                user_id: len(tokens)
                for user_id, tokens in self._user_sessions.items()
            },
        }
# Lazily created, process-wide AuthManager shared by all callers.
_auth_manager: AuthManager | None = None


def get_auth_manager() -> AuthManager:
    """Return the shared AuthManager, constructing it on the first call.

    Subsequent calls return that same instance.
    """
    global _auth_manager
    manager = _auth_manager
    if manager is None:
        manager = AuthManager()
        _auth_manager = manager
    return manager