- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
336 lines
8.9 KiB
Python
336 lines
8.9 KiB
Python
"""
|
||
消息路由服务
|
||
管理消息的发送和广播
|
||
"""
|
||
import uuid
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import List, Dict, Any, Optional, Callable, Set
|
||
from dataclasses import dataclass, field
|
||
from loguru import logger
|
||
from fastapi import WebSocket
|
||
|
||
from models.message import Message, MessageType
|
||
from models.chatroom import ChatRoom
|
||
|
||
|
||
@dataclass()
class WebSocketConnection:
    """
    Bookkeeping record describing one WebSocket connection.

    Attributes:
        websocket: The FastAPI WebSocket object.
        room_id: ID of the chat room this connection belongs to.
        connected_at: When the record was created; produced by
            ``datetime.utcnow`` (a naive datetime expressed in UTC).
    """

    # NOTE(review): MessageRouter in this file stores bare WebSocket objects,
    # not this record — confirm whether it is still used elsewhere.
    websocket: WebSocket
    room_id: str
    # default_factory so each instance gets its own timestamp at construction
    connected_at: datetime = field(default_factory=datetime.utcnow)
|
||
|
||
|
||
class MessageRouter:
    """
    Message router: WebSocket connection registry and room broadcaster.

    All state lives in class attributes and every method is a classmethod,
    so the class itself is used as a shared registry rather than being
    instantiated. No locking is performed; this assumes all callers run on
    the same asyncio event loop.
    """

    # room_id -> set of WebSocket objects currently connected to that room
    _room_connections: Dict[str, Set[WebSocket]] = {}

    # Subscribers invoked with each Message saved by save_and_broadcast_message
    _message_callbacks: List[Callable] = []

    @classmethod
    async def connect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Accept a WebSocket handshake and register the socket under a room.

        Args:
            room_id: Chat room ID.
            websocket: The WebSocket instance to accept and track.
        """
        await websocket.accept()

        connections = cls._room_connections.setdefault(room_id, set())
        connections.add(websocket)

        logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(connections)}")

    @classmethod
    async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
        """
        Unregister a WebSocket from a room.

        Safe to call repeatedly or for an unknown room/socket. The room entry
        is removed once its last connection is gone. The socket itself is not
        closed here.

        Args:
            room_id: Chat room ID.
            websocket: The WebSocket instance to remove.
        """
        connections = cls._room_connections.get(room_id)
        if connections is not None:
            connections.discard(websocket)
            # Drop empty rooms so get_all_room_ids() only reports active ones
            if not connections:
                del cls._room_connections[room_id]

        logger.info(f"WebSocket连接断开: {room_id}")

    @classmethod
    async def broadcast_to_room(
        cls,
        room_id: str,
        message: Dict[str, Any]
    ) -> None:
        """
        Send a JSON payload to every connection in a room concurrently.

        Does nothing when the room has no connections. A failed send only
        unregisters the affected socket (see _send_message); it does not
        abort delivery to the other sockets.

        Args:
            room_id: Chat room ID.
            message: JSON-serializable payload passed to ``send_json``.
        """
        connections = cls._room_connections.get(room_id)
        if not connections:
            return

        # Iterate a snapshot: failed sends remove sockets from the live set
        await asyncio.gather(
            *(cls._send_message(room_id, ws, message) for ws in list(connections)),
            return_exceptions=True,
        )

    @classmethod
    async def _send_message(
        cls,
        room_id: str,
        websocket: WebSocket,
        message: Dict[str, Any]
    ) -> None:
        """
        Send a payload to a single WebSocket, unregistering it on failure.

        Errors are logged and swallowed deliberately so one broken client
        cannot break a room-wide broadcast.

        Args:
            room_id: Chat room ID (used to unregister the socket on failure).
            websocket: Target WebSocket instance.
            message: JSON-serializable payload.
        """
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket发送失败: {e}")
            # Treat any send failure as a dead connection and stop tracking it
            await cls.disconnect(room_id, websocket)

    @classmethod
    async def save_and_broadcast_message(
        cls,
        room_id: str,
        discussion_id: str,
        agent_id: Optional[str],
        content: str,
        message_type: str = MessageType.TEXT.value,
        round_num: int = 0,
        attachments: Optional[List[Dict[str, Any]]] = None,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
        tool_results: Optional[List[Dict[str, Any]]] = None
    ) -> Message:
        """
        Persist a new message, broadcast it to its room, then notify callbacks.

        Args:
            room_id: Chat room ID.
            discussion_id: Discussion ID.
            agent_id: ID of the sending agent, or None.
            content: Message text.
            message_type: Message type value; defaults to ``MessageType.TEXT.value``.
            round_num: Round number, stored as ``Message.round``.
            attachments: Attachments; None is stored as an empty list.
            tool_calls: Tool calls; None is stored as an empty list.
            tool_results: Tool results; None is stored as an empty list.

        Returns:
            The inserted Message document.
        """
        message = Message(
            message_id=f"msg-{uuid.uuid4().hex[:12]}",
            room_id=room_id,
            discussion_id=discussion_id,
            agent_id=agent_id,
            content=content,
            message_type=message_type,
            attachments=attachments or [],
            round=round_num,
            token_count=len(content) // 4,  # rough estimate: ~4 characters per token
            tool_calls=tool_calls or [],
            tool_results=tool_results or [],
            created_at=datetime.utcnow()
        )

        await message.insert()

        # The broadcast payload is a subset of the stored fields
        # (attachments, tool calls/results and token_count are not sent).
        broadcast_data = {
            "type": "message",
            "data": {
                "message_id": message.message_id,
                "room_id": message.room_id,
                "discussion_id": message.discussion_id,
                "agent_id": message.agent_id,
                "content": message.content,
                "message_type": message.message_type,
                "round": message.round,
                "created_at": message.created_at.isoformat()
            }
        }

        await cls.broadcast_to_room(room_id, broadcast_data)
        await cls._notify_callbacks(message)

        return message

    @classmethod
    async def _notify_callbacks(cls, message: Message) -> None:
        """
        Invoke every registered callback with a freshly saved message.

        Iterates over a snapshot of the callback list, so a callback may
        register or unregister callbacks (including itself) without causing
        others to be skipped. Both coroutine functions and plain functions are
        accepted: the return value is awaited only when it is awaitable. A
        failing callback is logged with its traceback and does not prevent
        the remaining callbacks from running.

        Args:
            message: The Message that was just inserted and broadcast.
        """
        import inspect  # local import: only needed for the awaitable check

        for callback in list(cls._message_callbacks):
            try:
                result = callback(message)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception(f"消息回调执行失败: {e}")

    @classmethod
    async def broadcast_status(
        cls,
        room_id: str,
        status: str,
        data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Broadcast a status update to a room.

        Payload: ``{"type": "status", "status", "data", "timestamp"}``.

        Args:
            room_id: Chat room ID.
            status: Status type.
            data: Extra data; None is sent as an empty dict.
        """
        message = {
            "type": "status",
            "status": status,
            "data": data or {},
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_typing(
        cls,
        room_id: str,
        agent_id: str,
        is_typing: bool = True
    ) -> None:
        """
        Broadcast an agent's typing indicator to a room.

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.
            is_typing: Whether the agent is currently typing.
        """
        message = {
            "type": "typing",
            "agent_id": agent_id,
            "is_typing": is_typing,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_round_info(
        cls,
        room_id: str,
        round_num: int,
        total_rounds: int
    ) -> None:
        """
        Broadcast discussion round progress to a room.

        Args:
            room_id: Chat room ID.
            round_num: Current round.
            total_rounds: Maximum number of rounds.
        """
        message = {
            "type": "round",
            "round": round_num,
            "total_rounds": total_rounds,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    async def broadcast_error(
        cls,
        room_id: str,
        error: str,
        agent_id: Optional[str] = None
    ) -> None:
        """
        Broadcast an error notification to a room.

        Args:
            room_id: Chat room ID.
            error: Error message.
            agent_id: Related agent ID, if any (sent as null otherwise).
        """
        message = {
            "type": "error",
            "error": error,
            "agent_id": agent_id,
            "timestamp": datetime.utcnow().isoformat()
        }

        await cls.broadcast_to_room(room_id, message)

    @classmethod
    def register_callback(cls, callback: Callable) -> None:
        """
        Register a message callback.

        Args:
            callback: Callable receiving the saved Message. It may be a
                coroutine function or a plain function. Registering the same
                callable twice makes it run twice per message.
        """
        cls._message_callbacks.append(callback)

    @classmethod
    def unregister_callback(cls, callback: Callable) -> None:
        """
        Unregister a message callback (no-op if it is not registered).

        Only the first matching registration is removed.

        Args:
            callback: The callable previously passed to register_callback.
        """
        if callback in cls._message_callbacks:
            cls._message_callbacks.remove(callback)

    @classmethod
    def get_connection_count(cls, room_id: str) -> int:
        """
        Return the number of connections currently registered for a room.

        Args:
            room_id: Chat room ID.

        Returns:
            Connection count (0 for an unknown room).
        """
        return len(cls._room_connections.get(room_id, set()))

    @classmethod
    def get_all_room_ids(cls) -> List[str]:
        """
        Return the IDs of all rooms that have at least one connection.

        Returns:
            List of room IDs.
        """
        return list(cls._room_connections.keys())
|