# AIChatRoom/backend/services/message_router.py
# (Captured from a code-hosting web view: 336 lines, 8.9 KiB, Python.)
"""
消息路由服务
管理消息的发送和广播
"""
import asyncio
import inspect
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set

from fastapi import WebSocket
from loguru import logger

from models.chatroom import ChatRoom
from models.message import Message, MessageType
@dataclass()
class WebSocketConnection:
    """Bookkeeping record describing one WebSocket connection.

    Attributes:
        websocket: The WebSocket object for this connection.
        room_id: ID of the chat room the connection belongs to.
        connected_at: Creation time of this record, taken from
            ``datetime.utcnow()`` (a naive datetime holding UTC).
    """

    websocket: WebSocket
    room_id: str
    # Filled in at instantiation time unless the caller supplies a value.
    connected_at: datetime = field(default_factory=datetime.utcnow)
class MessageRouter:
    """Message router.

    Tracks the WebSocket connections of every chat room and fans messages out
    to them. All state is stored on the class itself and every public method
    is a ``classmethod``, so the router behaves as a process-wide registry.
    """

    # Room connection map: room_id -> set of accepted WebSockets in that room.
    _room_connections: ClassVar[Dict[str, Set[WebSocket]]] = {}
    # Subscribers invoked with every Message persisted by
    # save_and_broadcast_message(); plain functions and coroutine functions
    # are both accepted.
    _message_callbacks: ClassVar[List[Callable]] = []

    @staticmethod
    def _timestamp() -> str:
        """Return ``datetime.utcnow()`` (naive UTC) as an ISO-8601 string."""
        return datetime.utcnow().isoformat()

    @classmethod
    async def connect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Accept a WebSocket handshake and register the socket under a room.

        Args:
            room_id: Chat room ID.
            websocket: The incoming WebSocket; ``accept()`` is awaited here.
        """
        await websocket.accept()
        connections = cls._room_connections.setdefault(room_id, set())
        connections.add(websocket)
        logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(connections)}")

    @classmethod
    async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket from a room.

        Calling this more than once for the same socket is harmless. The
        socket itself is not closed here, only forgotten.

        Args:
            room_id: Chat room ID.
            websocket: The WebSocket to remove.
        """
        connections = cls._room_connections.get(room_id)
        if connections is not None:
            connections.discard(websocket)
            # Drop empty rooms so get_all_room_ids() only reports active ones.
            if not connections:
                del cls._room_connections[room_id]
        logger.info(f"WebSocket连接断开: {room_id}")

    @classmethod
    async def broadcast_to_room(
        cls,
        room_id: str,
        message: Dict[str, Any]
    ) -> None:
        """
        Send a JSON-serialisable message to every connection in a room.

        Sends run concurrently. A failed send only disconnects that one socket
        (see _send_message) and never aborts the broadcast.

        Args:
            room_id: Chat room ID.
            message: Payload passed to ``websocket.send_json``.
        """
        connections = cls._room_connections.get(room_id)
        if not connections:
            return
        # Snapshot the set: _send_message may disconnect sockets (mutating the
        # live set) while the sends are in flight.
        await asyncio.gather(
            *(cls._send_message(room_id, ws, message) for ws in list(connections)),
            return_exceptions=True,
        )

    @classmethod
    async def _send_message(
        cls,
        room_id: str,
        websocket: WebSocket,
        message: Dict[str, Any]
    ) -> None:
        """
        Send ``message`` to one WebSocket, dropping it from the room on failure.

        Args:
            room_id: Chat room ID the socket is registered under.
            websocket: Target WebSocket.
            message: Payload passed to ``websocket.send_json``.
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            # Best effort: a failed send is treated as a dead connection.
            logger.warning(f"WebSocket发送失败: {e}")
            await cls.disconnect(room_id, websocket)

    @classmethod
    async def _dispatch_callbacks(cls, message: Message) -> None:
        """Invoke every registered callback with ``message``, isolating failures."""
        # Iterate a snapshot. Awaiting a callback yields to the event loop, and
        # another task may register/unregister callbacks meanwhile. Mutating
        # the list mid-iteration would otherwise skip or repeat entries.
        for callback in list(cls._message_callbacks):
            try:
                result = callback(message)
                # Support synchronous callbacks as well as coroutine functions.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                # One failing subscriber must not prevent the others from
                # running; keep the traceback for debugging.
                logger.exception(f"消息回调执行失败: {e}")

    @classmethod
    async def save_and_broadcast_message(
        cls,
        room_id: str,
        discussion_id: str,
        agent_id: Optional[str],
        content: str,
        message_type: str = MessageType.TEXT.value,
        round_num: int = 0,
        attachments: Optional[List[Dict[str, Any]]] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
        tool_results: Optional[List[Dict[str, Any]]] = None
    ) -> Message:
        """
        Persist a message, broadcast it to the room, then notify callbacks.

        Args:
            room_id: Chat room ID.
            discussion_id: Discussion ID.
            agent_id: ID of the sending agent, or None.
            content: Message text.
            message_type: Message type value (defaults to ``MessageType.TEXT``).
            round_num: Round number stored in the message's ``round`` field.
            attachments: Attachments (stored; not included in the broadcast).
            tool_calls: Tool calls (stored; not included in the broadcast).
            tool_results: Tool results (stored; not included in the broadcast).

        Returns:
            The inserted Message document.
        """
        message = Message(
            message_id=f"msg-{uuid.uuid4().hex[:12]}",
            room_id=room_id,
            discussion_id=discussion_id,
            agent_id=agent_id,
            content=content,
            message_type=message_type,
            attachments=attachments or [],
            round=round_num,
            token_count=len(content) // 4,  # rough estimate: ~4 chars per token
            tool_calls=tool_calls or [],
            tool_results=tool_results or [],
            created_at=datetime.utcnow()
        )
        await message.insert()

        broadcast_data = {
            "type": "message",
            "data": {
                "message_id": message.message_id,
                "room_id": message.room_id,
                "discussion_id": message.discussion_id,
                "agent_id": message.agent_id,
                "content": message.content,
                "message_type": message.message_type,
                "round": message.round,
                "created_at": message.created_at.isoformat()
            }
        }
        await cls.broadcast_to_room(room_id, broadcast_data)
        await cls._dispatch_callbacks(message)
        return message

    @classmethod
    async def broadcast_status(
        cls,
        room_id: str,
        status: str,
        data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Broadcast a status update to a room.

        Args:
            room_id: Chat room ID.
            status: Status type.
            data: Extra data (an empty dict when omitted).
        """
        await cls.broadcast_to_room(room_id, {
            "type": "status",
            "status": status,
            "data": data or {},
            "timestamp": cls._timestamp()
        })

    @classmethod
    async def broadcast_typing(
        cls,
        room_id: str,
        agent_id: str,
        is_typing: bool = True
    ) -> None:
        """
        Broadcast an agent's typing indicator to a room.

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.
            is_typing: Whether the agent is currently typing.
        """
        await cls.broadcast_to_room(room_id, {
            "type": "typing",
            "agent_id": agent_id,
            "is_typing": is_typing,
            "timestamp": cls._timestamp()
        })

    @classmethod
    async def broadcast_round_info(
        cls,
        room_id: str,
        round_num: int,
        total_rounds: int
    ) -> None:
        """
        Broadcast round progress to a room.

        Args:
            room_id: Chat room ID.
            round_num: Current round.
            total_rounds: Maximum number of rounds.
        """
        await cls.broadcast_to_room(room_id, {
            "type": "round",
            "round": round_num,
            "total_rounds": total_rounds,
            "timestamp": cls._timestamp()
        })

    @classmethod
    async def broadcast_error(
        cls,
        room_id: str,
        error: str,
        agent_id: Optional[str] = None
    ) -> None:
        """
        Broadcast an error notice to a room.

        Args:
            room_id: Chat room ID.
            error: Error message.
            agent_id: Related agent ID, if any.
        """
        await cls.broadcast_to_room(room_id, {
            "type": "error",
            "error": error,
            "agent_id": agent_id,
            "timestamp": cls._timestamp()
        })

    @classmethod
    def register_callback(cls, callback: Callable) -> None:
        """
        Register a message callback.

        Args:
            callback: Called with each saved Message. It may be a plain
                function or a coroutine function (its result is awaited).
        """
        cls._message_callbacks.append(callback)

    @classmethod
    def unregister_callback(cls, callback: Callable) -> None:
        """
        Unregister a previously registered message callback (no-op if absent).

        Args:
            callback: The callback to remove (first occurrence only).
        """
        try:
            cls._message_callbacks.remove(callback)
        except ValueError:
            pass

    @classmethod
    def get_connection_count(cls, room_id: str) -> int:
        """
        Return the number of connections in a room.

        Args:
            room_id: Chat room ID.

        Returns:
            Connection count (0 for unknown rooms).
        """
        return len(cls._room_connections.get(room_id, set()))

    @classmethod
    def get_all_room_ids(cls) -> List[str]:
        """
        Return the IDs of all rooms that currently have connections.

        Returns:
            List of room IDs.
        """
        return list(cls._room_connections)