"""
Message routing service.

Manages per-room WebSocket connections and the sending/broadcasting of
messages to them.
"""
import uuid
import asyncio
import inspect
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable, Set, ClassVar
from dataclasses import dataclass, field
from loguru import logger
from fastapi import WebSocket
from models.message import Message, MessageType
from models.chatroom import ChatRoom


@dataclass
class WebSocketConnection:
    """WebSocket connection information."""

    websocket: WebSocket
    room_id: str
    # Set from datetime.utcnow(), i.e. a naive datetime holding UTC time.
    connected_at: datetime = field(default_factory=datetime.utcnow)


class MessageRouter:
    """
    Message router.

    Tracks WebSocket connections per chat room and broadcasts messages to
    them. All state lives in class attributes and every method is a
    classmethod, so the state is shared by all users of the class.
    """

    # Room connection map: room_id -> set of connected WebSockets.
    _room_connections: ClassVar[Dict[str, Set[WebSocket]]] = {}

    # Message callbacks: lets external code subscribe to saved messages.
    _message_callbacks: ClassVar[List[Callable]] = []

    @classmethod
    async def connect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Accept a WebSocket connection and register it with a chat room.

        Args:
            room_id: Chat room ID.
            websocket: WebSocket instance; ``accept()`` is awaited before
                the socket is registered.
        """
        await websocket.accept()

        if room_id not in cls._room_connections:
            cls._room_connections[room_id] = set()

        cls._room_connections[room_id].add(websocket)
        logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")

    @classmethod
    async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket from a chat room.

        The socket itself is not closed here; it is only removed from the
        routing table. Unknown rooms/sockets are ignored.

        Args:
            room_id: Chat room ID.
            websocket: WebSocket instance.
        """
        if room_id in cls._room_connections:
            cls._room_connections[room_id].discard(websocket)

            # Drop rooms that no longer have any connections.
            if not cls._room_connections[room_id]:
                del cls._room_connections[room_id]

        logger.info(f"WebSocket连接断开: {room_id}")

    @classmethod
    async def broadcast_to_room(
        cls,
        room_id: str,
        message: Dict[str, Any]
    ) -> None:
        """
        Send a message to every connection in a chat room concurrently.

        Args:
            room_id: Chat room ID. Does nothing if the room has no
                registered connections.
            message: JSON-serializable message payload.
        """
        if room_id not in cls._room_connections:
            return

        # Snapshot the set: _send_message may remove failed sockets from the
        # live set while the sends are in flight.
        connections = cls._room_connections[room_id].copy()

        # Send concurrently.
        tasks = []
        for websocket in connections:
            tasks.append(cls._send_message(room_id, websocket, message))

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    @classmethod
    async def _send_message(
        cls,
        room_id: str,
        websocket: WebSocket,
        message: Dict[str, Any]
    ) -> None:
        """
        Send a message to a single WebSocket.

        On any send failure the error is logged and the socket is
        unregistered from the room.

        Args:
            room_id: Chat room ID.
            websocket: WebSocket instance.
            message: JSON-serializable message payload.
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket发送失败: {e}")
            # Remove the broken connection.
            await cls.disconnect(room_id, websocket)

    @classmethod
    async def save_and_broadcast_message(
        cls,
        room_id: str,
        discussion_id: str,
        agent_id: Optional[str],
        content: str,
        message_type: str = MessageType.TEXT.value,
        round_num: int = 0,
        attachments: Optional[List[Dict[str, Any]]] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
        tool_results: Optional[List[Dict[str, Any]]] = None
    ) -> Message:
        """
        Persist a message, broadcast it to the room, and notify callbacks.

        Args:
            room_id: Chat room ID.
            discussion_id: Discussion ID.
            agent_id: ID of the sending agent, if any.
            content: Message text.
            message_type: Message type.
            round_num: Round number.
            attachments: Attachments (defaults to an empty list).
            tool_calls: Tool calls (defaults to an empty list).
            tool_results: Tool results (defaults to an empty list).

        Returns:
            The saved Message document.
        """
        # Build the message.
        message = Message(
            message_id=f"msg-{uuid.uuid4().hex[:12]}",
            room_id=room_id,
            discussion_id=discussion_id,
            agent_id=agent_id,
            content=content,
            message_type=message_type,
            attachments=attachments or [],
            round=round_num,
            token_count=len(content) // 4,  # rough estimate: ~4 chars per token
            tool_calls=tool_calls or [],
            tool_results=tool_results or [],
            created_at=datetime.utcnow()
        )
        await message.insert()

        # Build the broadcast payload. Note it does not include attachments,
        # tool_calls or tool_results.
        broadcast_data = {
            "type": "message",
            "data": {
                "message_id": message.message_id,
                "room_id": message.room_id,
                "discussion_id": message.discussion_id,
                "agent_id": message.agent_id,
                "content": message.content,
                "message_type": message.message_type,
                "round": message.round,
                "created_at": message.created_at.isoformat()
            }
        }

        # Broadcast the message.
        await cls.broadcast_to_room(room_id, broadcast_data)

        # Fire callbacks. Iterate over a snapshot so a callback that
        # (un)registers callbacks cannot make this loop skip entries.
        for callback in list(cls._message_callbacks):
            try:
                result = callback(message)
                # Support both plain functions and coroutine functions.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                # Same ERROR-level message as before, now with traceback.
                logger.exception(f"消息回调执行失败: {e}")

        return message

    @classmethod
    async def broadcast_status(
        cls,
        room_id: str,
        status: str,
        data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Broadcast a status update.

        Args:
            room_id: Chat room ID.
            status: Status type.
            data: Extra data (defaults to an empty dict).
        """
        message = {
            "type": "status",
            "status": status,
            "data": data or {},
            "timestamp": datetime.utcnow().isoformat()
        }
        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_typing(
        cls,
        room_id: str,
        agent_id: str,
        is_typing: bool = True
    ) -> None:
        """
        Broadcast an agent's typing indicator.

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.
            is_typing: Whether the agent is currently typing.
        """
        message = {
            "type": "typing",
            "agent_id": agent_id,
            "is_typing": is_typing,
            "timestamp": datetime.utcnow().isoformat()
        }
        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_round_info(
        cls,
        room_id: str,
        round_num: int,
        total_rounds: int
    ) -> None:
        """
        Broadcast round information.

        Args:
            room_id: Chat room ID.
            round_num: Current round.
            total_rounds: Maximum number of rounds.
        """
        message = {
            "type": "round",
            "round": round_num,
            "total_rounds": total_rounds,
            "timestamp": datetime.utcnow().isoformat()
        }
        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_error(
        cls,
        room_id: str,
        error: str,
        agent_id: Optional[str] = None
    ) -> None:
        """
        Broadcast an error message.

        Args:
            room_id: Chat room ID.
            error: Error text.
            agent_id: Related agent ID, if any.
        """
        message = {
            "type": "error",
            "error": error,
            "agent_id": agent_id,
            "timestamp": datetime.utcnow().isoformat()
        }
        await cls.broadcast_to_room(room_id, message)

    @classmethod
    def register_callback(cls, callback: Callable) -> None:
        """
        Register a message callback.

        Args:
            callback: Called with the saved Message after each
                save_and_broadcast_message; may be sync or async.
        """
        cls._message_callbacks.append(callback)

    @classmethod
    def unregister_callback(cls, callback: Callable) -> None:
        """
        Unregister a message callback (no-op if it is not registered).

        Args:
            callback: The callback to remove.
        """
        if callback in cls._message_callbacks:
            cls._message_callbacks.remove(callback)

    @classmethod
    def get_connection_count(cls, room_id: str) -> int:
        """
        Return the number of connections registered for a room.

        Args:
            room_id: Chat room ID.

        Returns:
            Connection count (0 for unknown rooms).
        """
        return len(cls._room_connections.get(room_id, set()))

    @classmethod
    def get_all_room_ids(cls) -> List[str]:
        """
        Return the IDs of all rooms that currently have connections.

        Returns:
            List of room IDs.
        """
        return list(cls._room_connections.keys())