"""Web TUI 服务器
提供 Web 终端界面和 WebSocket 终端通信
"""
from __future__ import annotations

import asyncio
import codecs
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from minenasai.core import get_logger, get_settings, setup_logging
from minenasai.webtui.auth import get_auth_manager
from minenasai.webtui.ssh_manager import SSHSession, get_ssh_manager
logger = get_logger(__name__)

# Directory holding the web front-end's static files; it sits next to this
# module file.
STATIC_DIR = Path(__file__).parent.joinpath("static")
class TerminalConnection:
    """State for one browser terminal attached over a WebSocket.

    Holds the WebSocket, the server-assigned session id, the user id (empty
    until authentication sets it), the attached SSH session (``None`` until
    one is created) and an ``authenticated`` flag.
    """

    def __init__(
        self,
        websocket: WebSocket,
        session_id: str,
        user_id: str,
    ) -> None:
        self.websocket = websocket
        self.session_id = session_id
        self.user_id = user_id
        self.ssh_session: SSHSession | None = None
        self.authenticated = False
        # Terminal output arrives as arbitrary byte chunks, so a multi-byte
        # UTF-8 character can be split across two chunks. Decoding each chunk
        # on its own would turn both halves into U+FFFD; an incremental
        # decoder buffers the incomplete tail until the rest arrives.
        self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    async def send_json(self, data: dict[str, Any]) -> None:
        """Send a JSON message to the client.

        Send failures are logged rather than raised, so callers can treat
        this as best-effort.
        """
        try:
            await self.websocket.send_json(data)
        except Exception as e:
            logger.error("发送消息失败", error=str(e))

    async def send_output(self, data: bytes) -> None:
        """Send a chunk of raw terminal output as an ``output`` message.

        Args:
            data: Raw bytes from the SSH session. An incomplete trailing UTF-8
                sequence is held back and emitted together with the next
                chunk. Invalid bytes are replaced with U+FFFD.

        Nothing is sent when the chunk yields no complete characters. Send
        failures are logged, not raised.
        """
        text = self._decoder.decode(data)
        if not text:
            # Only a partial character so far; wait for the rest of it.
            return
        try:
            await self.websocket.send_json({
                "type": "output",
                "data": text,
            })
        except Exception as e:
            logger.error("发送输出失败", error=str(e))
class ConnectionManager:
    """Registry of live terminal connections, keyed by session id."""

    def __init__(self) -> None:
        self.connections: dict[str, TerminalConnection] = {}

    async def connect(
        self,
        websocket: WebSocket,
        session_id: str,
        user_id: str,
    ) -> TerminalConnection:
        """Accept the WebSocket handshake and register a new connection.

        Returns:
            The freshly created ``TerminalConnection``, stored under
            ``session_id``.
        """
        await websocket.accept()
        new_conn = TerminalConnection(websocket, session_id, user_id)
        self.connections[session_id] = new_conn
        logger.info("终端连接建立", session_id=session_id)
        return new_conn

    async def disconnect(self, session_id: str) -> None:
        """Unregister a connection and close its SSH session if it has one.

        An unknown ``session_id`` is tolerated; only the log line is emitted.
        """
        removed = self.connections.pop(session_id, None)
        if removed is not None and removed.ssh_session:
            await get_ssh_manager().close_session(session_id)
        logger.info("终端连接断开", session_id=session_id)

    def get_connection(self, session_id: str) -> TerminalConnection | None:
        """Return the connection registered under ``session_id``, or ``None``."""
        try:
            return self.connections[session_id]
        except KeyError:
            return None


# Module-level registry shared by the route handlers.
manager = ConnectionManager()
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application startup and shutdown.

    Startup: load settings, configure logging and log the configured port.
    Shutdown: call the SSH manager's ``close_all()``. This runs in a
    ``finally`` clause, so the cleanup also happens when the lifespan is
    exited through an exception (e.g. cancellation) rather than normally.

    Args:
        _app: FastAPI application instance (required by the lifespan
            signature; unused here).
    """
    settings = get_settings()
    setup_logging(settings.logging)
    logger.info("Web TUI 服务启动", port=settings.webtui.port)
    try:
        yield
    finally:
        # Tear down every SSH session still open.
        ssh_manager = get_ssh_manager()
        await ssh_manager.close_all()
        logger.info("Web TUI 服务关闭")
# The FastAPI application; startup and shutdown are delegated to `lifespan`.
app = FastAPI(
    title="MineNASAI Web TUI",
    description="Web 终端界面",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS configuration.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# permits credentialed cross-origin requests from any site (Starlette may
# reflect the request Origin in this case; verify). Consider restricting the
# origins through settings.
# NOTE(review): this module-level `settings` is not referenced in the code
# visible here; it may be used later in the module, otherwise it can go.
settings = get_settings()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the front-end's static files under /static, only if the directory
# exists next to this module.
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.get("/")
async def index() -> HTMLResponse:
    """Serve the web terminal's landing page.

    Returns ``static/index.html`` when that file exists. Otherwise it returns
    a minimal placeholder page.
    """
    index_file = STATIC_DIR / "index.html"
    if index_file.exists():
        return HTMLResponse(content=index_file.read_text(encoding="utf-8"))
    # Fallback when the front-end bundle is missing. The literal must stay on
    # one line: a plain string literal split across lines is a SyntaxError.
    return HTMLResponse(content="<h1>MineNASAI Web TUI</h1>")
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: always report the service as healthy."""
    return dict(status="healthy")
@app.get("/api/stats")
async def stats() -> dict[str, Any]:
    """Report runtime statistics.

    Combines the number of WebSocket terminals tracked by the module-level
    connection manager with the ``get_stats()`` output of the SSH and auth
    managers.
    """
    ssh = get_ssh_manager()
    auth = get_auth_manager()
    payload: dict[str, Any] = {}
    payload["connections"] = len(manager.connections)
    payload["ssh"] = ssh.get_stats()
    payload["auth"] = auth.get_stats()
    return payload
@app.post("/api/token")
async def generate_token(
    user_id: str = "anonymous",
    expires_in: int = 3600,
) -> dict[str, Any]:
    """Issue an access token (intended for testing).

    Args:
        user_id: Identity the token is issued for.
        expires_in: Lifetime handed to the auth manager; presumably seconds
            (TODO confirm against the auth manager).

    Returns:
        The new token along with the echoed ``user_id`` and ``expires_in``.
    """
    # NOTE(review): this endpoint has no authentication of its own, so any
    # client can mint a token for any user_id. Consider disabling it outside
    # development.
    issued = get_auth_manager().generate_token(user_id, expires_in)
    return dict(token=issued, user_id=user_id, expires_in=expires_in)
@app.websocket("/ws/terminal")
async def terminal_websocket(websocket: WebSocket) -> None:
    """Terminal WebSocket endpoint.

    Accepts the connection under a fresh UUID session id. It then reads JSON
    messages in a loop and dispatches on their ``type`` field:

    - ``auth``: authenticate and open the SSH session
    - ``input``: forward keystrokes to the SSH session
    - ``resize``: resize the remote terminal
    - ``pong``: heartbeat reply, ignored

    Unknown types and non-object frames get an ``error`` reply, and the loop
    keeps running. When the loop ends, whether by client disconnect or an
    unexpected error (which is logged), the connection is unregistered and its
    SSH session is released.
    """
    session_id = str(uuid.uuid4())
    conn: TerminalConnection | None = None
    try:
        conn = await manager.connect(websocket, session_id, "")
        while True:
            data = await websocket.receive_json()
            # A decoded frame can be any JSON value; only objects carry a
            # "type". Calling .get() on a list/str/number would raise and tear
            # down the whole session, so reject such frames instead.
            if not isinstance(data, dict):
                await conn.send_json({
                    "type": "error",
                    "message": "无效的消息格式",
                })
                continue

            msg_type = data.get("type")
            if msg_type == "auth":
                await handle_auth(conn, data)
            elif msg_type == "input":
                await handle_input(conn, data)
            elif msg_type == "resize":
                await handle_resize(conn, data)
            elif msg_type == "pong":
                pass  # heartbeat reply; nothing to do
            else:
                await conn.send_json({
                    "type": "error",
                    "message": f"未知消息类型: {msg_type}",
                })
    except WebSocketDisconnect:
        pass  # client closed the socket
    except Exception as e:
        logger.error("WebSocket 错误", error=str(e), session_id=session_id)
    finally:
        if conn:
            await manager.disconnect(session_id)
async def handle_auth(conn: TerminalConnection, data: dict[str, Any]) -> None:
    """Authenticate a terminal connection and open its SSH session.

    The message's ``token`` (default ``""``) is handled as follows:

    - the literal ``"anonymous"`` is accepted without verification
      (development mode) and yields user id ``"anonymous"``;
    - any other value goes to the auth manager's ``verify_token``. A ``None``
      result is answered with ``auth_error``, and nothing else changes.

    On success the connection is marked authenticated, ``auth_ok`` (carrying
    the session and user ids) is sent, and an SSH session is created.

    If the connection already has an SSH session, the request is rejected with
    an ``error`` message. Re-running the success path would create a second
    SSH session for the same session id and overwrite ``conn.ssh_session``.
    A connection whose earlier SSH setup failed (no session attached) can
    still retry.
    """
    if conn.ssh_session is not None:
        await conn.send_json({
            "type": "error",
            "message": "会话已认证",
        })
        return

    token = data.get("token", "")
    auth_manager = get_auth_manager()

    # Anonymous access (development mode).
    # NOTE(review): nothing visible here limits this to development; any
    # client sending "anonymous" gets an SSH session. Consider gating it
    # behind a configuration flag.
    if token == "anonymous":
        user_id = "anonymous"
    else:
        auth_token = auth_manager.verify_token(token)
        if auth_token is None:
            await conn.send_json({
                "type": "auth_error",
                "message": "无效的令牌",
            })
            return
        user_id = auth_token.user_id

    # Shared success path for both anonymous and token-based login.
    conn.authenticated = True
    conn.user_id = user_id
    await conn.send_json({
        "type": "auth_ok",
        "session_id": conn.session_id,
        "user_id": conn.user_id,
    })
    await create_ssh_session(conn)
async def create_ssh_session(conn: TerminalConnection) -> None:
    """Open an SSH session for ``conn`` and stream its output to the client.

    If the SSH manager returns no session, an ``error`` message is sent and
    nothing else happens. Otherwise the session is attached to the
    connection. An output callback is registered that forwards every chunk
    through ``conn.send_output``, and the session's reader is started.
    """
    ssh_manager = get_ssh_manager()
    session = await ssh_manager.create_session(conn.session_id)
    if session is None:
        await conn.send_json({
            "type": "error",
            "message": "SSH 连接失败",
        })
        return

    conn.ssh_session = session

    # The event loop keeps only weak references to tasks, so a discarded
    # create_task() result can be garbage-collected before it finishes.
    # Hold a strong reference until each send is done.
    pending_sends: set[asyncio.Task[None]] = set()

    def on_output(data: bytes) -> None:
        # NOTE(review): asyncio.create_task needs a running loop in the
        # calling thread, so this assumes SSHSession invokes the callback on
        # the event-loop thread; confirm.
        task = asyncio.create_task(conn.send_output(data))
        pending_sends.add(task)
        task.add_done_callback(pending_sends.discard)

    session.set_output_callback(on_output)
    # Start pumping SSH output into the callback.
    await session.start_reading()
async def handle_input(conn: TerminalConnection, data: dict[str, Any]) -> None:
    """Forward the client's keystrokes (``data`` field, default ``""``) to SSH.

    Writes nothing and replies with an ``error`` message when the connection
    is not authenticated, or when it has no SSH session yet.
    """
    problem: str | None = None
    if not conn.authenticated:
        problem = "未认证"
    elif conn.ssh_session is None:
        problem = "SSH 会话未建立"

    if problem is not None:
        await conn.send_json({"type": "error", "message": problem})
        return

    await conn.ssh_session.write(data.get("data", ""))
async def handle_resize(conn: TerminalConnection, data: dict[str, Any]) -> None:
    """Resize the remote terminal to the client's dimensions.

    Reads ``cols`` (default 80) and ``rows`` (default 24) from the message and
    coerces them to ``int``. If both are positive, they are forwarded to the
    SSH session and acknowledged with ``resize_ok``. Unparsable or
    non-positive values get an ``error`` reply, and the session is not
    touched. Does nothing when no SSH session is attached yet.
    """
    if conn.ssh_session is None:
        return

    # The values come straight from client JSON and may be strings, null,
    # floats (including NaN/Infinity) or negatives; validate before use.
    try:
        cols = int(data.get("cols", 80))
        rows = int(data.get("rows", 24))
    except (TypeError, ValueError, OverflowError):
        cols = rows = 0
    if cols <= 0 or rows <= 0:
        await conn.send_json({
            "type": "error",
            "message": "无效的终端尺寸",
        })
        return

    await conn.ssh_session.resize(cols, rows)
    await conn.send_json({
        "type": "resize_ok",
        "cols": cols,
        "rows": rows,
    })