336 lines
8.9 KiB
Python
336 lines
8.9 KiB
Python
|
|
"""
|
|||
|
|
消息路由服务
|
|||
|
|
管理消息的发送和广播
|
|||
|
|
"""
|
|||
|
|
import asyncio
import inspect
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable, Set

from fastapi import WebSocket
from loguru import logger

from models.message import Message, MessageType
from models.chatroom import ChatRoom
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class WebSocketConnection:
    """Bookkeeping record for one WebSocket attached to a chat room.

    Attributes:
        websocket: The FastAPI WebSocket object.
        room_id: ID of the chat room this socket belongs to.
        connected_at: Creation time of the record, taken from
            ``datetime.utcnow()`` (a naive datetime in UTC).
    """

    websocket: WebSocket
    room_id: str
    # Evaluated lazily per instance so each record gets its own timestamp.
    connected_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MessageRouter:
    """Message router.

    Tracks WebSocket connections per chat room and broadcasts JSON payloads
    to every socket in a room. All state is stored in class-level attributes
    and every public method is a classmethod, so the class acts as a single
    process-wide router rather than being instantiated.
    """

    # room_id -> set of WebSockets currently registered for that room.
    _room_connections: Dict[str, Set[WebSocket]] = {}

    # Subscribers invoked with every Message saved by
    # save_and_broadcast_message (see register_callback).
    _message_callbacks: List[Callable] = []

    @classmethod
    async def connect(cls, room_id: str, websocket: WebSocket) -> None:
        """Accept a WebSocket handshake and register the socket in a room.

        The socket is only registered after ``accept()`` completes, so a
        failed handshake leaves no entry behind.

        Args:
            room_id: Chat room ID.
            websocket: WebSocket instance to accept and register.
        """
        await websocket.accept()

        connections = cls._room_connections.setdefault(room_id, set())
        connections.add(websocket)

        logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(connections)}")

    @classmethod
    async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
        """Unregister a WebSocket from a room.

        Safe to call for an unknown room or an already-removed socket. This
        only updates the registry; it does not close the socket. When the
        last socket leaves, the room entry is deleted so empty rooms do not
        accumulate.

        Args:
            room_id: Chat room ID.
            websocket: WebSocket instance to remove.
        """
        connections = cls._room_connections.get(room_id)
        if connections is not None:
            connections.discard(websocket)

            # Clean up empty rooms.
            if not connections:
                del cls._room_connections[room_id]

        logger.info(f"WebSocket连接断开: {room_id}")

    @classmethod
    async def broadcast_to_room(
        cls,
        room_id: str,
        message: Dict[str, Any]
    ) -> None:
        """Send a JSON-serializable payload to every socket in a room.

        Sends run concurrently. A failed send is handled per socket by
        ``_send_message`` (logged and the socket unregistered), so one broken
        client does not prevent delivery to the others.

        Args:
            room_id: Chat room ID; unknown rooms are a no-op.
            message: Payload passed to ``websocket.send_json``.
        """
        connections = cls._room_connections.get(room_id)
        if not connections:
            return

        # Iterate over a snapshot: failed sends call disconnect(), which
        # mutates the live set while the gather is in flight.
        await asyncio.gather(
            *(cls._send_message(room_id, ws, message) for ws in list(connections)),
            return_exceptions=True,
        )

    @classmethod
    async def _send_message(
        cls,
        room_id: str,
        websocket: WebSocket,
        message: Dict[str, Any]
    ) -> None:
        """Send a payload to a single WebSocket, dropping it on failure.

        Args:
            room_id: Chat room ID the socket is registered under.
            websocket: Target WebSocket.
            message: Payload passed to ``websocket.send_json``.
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket发送失败: {e}")
            # Remove the connection that could not be written to.
            await cls.disconnect(room_id, websocket)

    @classmethod
    async def _notify_callbacks(cls, message: Message) -> None:
        """Invoke every registered callback with ``message``.

        Iterates over a snapshot of the callback list, so callbacks that are
        registered or unregistered while this loop is running (including by a
        callback itself, or by another coroutine while a callback is being
        awaited) cannot cause other callbacks to be skipped. Both plain
        functions and coroutine functions are supported: the return value is
        awaited only if it is awaitable. A failing callback is logged with
        its traceback and does not stop the remaining callbacks.
        """
        for callback in list(cls._message_callbacks):
            try:
                result = callback(message)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception(f"消息回调执行失败: {e}")

    @classmethod
    async def save_and_broadcast_message(
        cls,
        room_id: str,
        discussion_id: str,
        agent_id: Optional[str],
        content: str,
        message_type: str = MessageType.TEXT.value,
        round_num: int = 0,
        attachments: Optional[List[Dict[str, Any]]] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
        tool_results: Optional[List[Dict[str, Any]]] = None
    ) -> Message:
        """Persist a new message, broadcast it to the room, then run callbacks.

        Args:
            room_id: Chat room ID.
            discussion_id: Discussion ID.
            agent_id: ID of the sending agent, or None.
            content: Message text.
            message_type: Message type value (defaults to MessageType.TEXT).
            round_num: Round number, stored in the ``round`` field.
            attachments: Attachments; None is stored as an empty list.
            tool_calls: Tool calls; None is stored as an empty list.
            tool_results: Tool results; None is stored as an empty list.

        Returns:
            The saved Message document.

        Note:
            The broadcast payload carries only the basic message fields;
            attachments, tool calls/results and token_count are persisted but
            not included in the WebSocket payload.
        """
        message = Message(
            message_id=f"msg-{uuid.uuid4().hex[:12]}",
            room_id=room_id,
            discussion_id=discussion_id,
            agent_id=agent_id,
            content=content,
            message_type=message_type,
            attachments=attachments or [],
            round=round_num,
            # Rough estimate (characters / 4), not a real tokenizer count.
            token_count=len(content) // 4,
            tool_calls=tool_calls or [],
            tool_results=tool_results or [],
            created_at=datetime.utcnow()
        )

        await message.insert()

        broadcast_data = {
            "type": "message",
            "data": {
                "message_id": message.message_id,
                "room_id": message.room_id,
                "discussion_id": message.discussion_id,
                "agent_id": message.agent_id,
                "content": message.content,
                "message_type": message.message_type,
                "round": message.round,
                "created_at": message.created_at.isoformat()
            }
        }

        await cls.broadcast_to_room(room_id, broadcast_data)
        await cls._notify_callbacks(message)

        return message

    @classmethod
    async def broadcast_status(
        cls,
        room_id: str,
        status: str,
        data: Optional[Dict[str, Any]] = None
    ) -> None:
        """Broadcast a status update to a room.

        Args:
            room_id: Chat room ID.
            status: Status type.
            data: Extra data; None is sent as an empty dict.
        """
        message = {
            "type": "status",
            "status": status,
            "data": data or {},
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_typing(
        cls,
        room_id: str,
        agent_id: str,
        is_typing: bool = True
    ) -> None:
        """Broadcast an agent's typing indicator to a room.

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.
            is_typing: Whether the agent is currently typing.
        """
        message = {
            "type": "typing",
            "agent_id": agent_id,
            "is_typing": is_typing,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_round_info(
        cls,
        room_id: str,
        round_num: int,
        total_rounds: int
    ) -> None:
        """Broadcast discussion round progress to a room.

        Args:
            room_id: Chat room ID.
            round_num: Current round.
            total_rounds: Maximum number of rounds.
        """
        message = {
            "type": "round",
            "round": round_num,
            "total_rounds": total_rounds,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_error(
        cls,
        room_id: str,
        error: str,
        agent_id: Optional[str] = None
    ) -> None:
        """Broadcast an error notice to a room.

        Args:
            room_id: Chat room ID.
            error: Error message.
            agent_id: Related agent ID, if any (sent as null otherwise).
        """
        message = {
            "type": "error",
            "error": error,
            "agent_id": agent_id,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    def register_callback(cls, callback: Callable) -> None:
        """Register a message callback.

        Args:
            callback: Called with each saved Message. May be a plain function
                or a coroutine function; awaitable results are awaited.
                Registering the same callback twice makes it run twice.
        """
        cls._message_callbacks.append(callback)

    @classmethod
    def unregister_callback(cls, callback: Callable) -> None:
        """Unregister a message callback.

        Removes one registration (the first match); unknown callbacks are
        ignored.

        Args:
            callback: Callback previously passed to register_callback.
        """
        if callback in cls._message_callbacks:
            cls._message_callbacks.remove(callback)

    @classmethod
    def get_connection_count(cls, room_id: str) -> int:
        """Return the number of sockets registered for a room.

        Args:
            room_id: Chat room ID.

        Returns:
            Connection count (0 for unknown rooms).
        """
        return len(cls._room_connections.get(room_id, set()))

    @classmethod
    def get_all_room_ids(cls) -> List[str]:
        """Return the IDs of all rooms with at least one registered socket.

        Returns:
            List of room IDs.
        """
        return list(cls._room_connections.keys())
|