feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
335
backend/services/message_router.py
Normal file
335
backend/services/message_router.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
消息路由服务
|
||||
管理消息的发送和广播
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Callable, Set
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from fastapi import WebSocket
|
||||
|
||||
from models.message import Message, MessageType
|
||||
from models.chatroom import ChatRoom
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""WebSocket连接信息"""
|
||||
websocket: WebSocket
|
||||
room_id: str
|
||||
connected_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""
|
||||
消息路由器
|
||||
管理WebSocket连接和消息广播
|
||||
"""
|
||||
|
||||
# 房间连接映射: room_id -> Set[WebSocket]
|
||||
_room_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
# 消息回调: 用于外部订阅消息
|
||||
_message_callbacks: List[Callable] = []
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
建立WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if room_id not in cls._room_connections:
|
||||
cls._room_connections[room_id] = set()
|
||||
|
||||
cls._room_connections[room_id].add(websocket)
|
||||
|
||||
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
|
||||
|
||||
@classmethod
|
||||
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
断开WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
if room_id in cls._room_connections:
|
||||
cls._room_connections[room_id].discard(websocket)
|
||||
|
||||
# 清理空房间
|
||||
if not cls._room_connections[room_id]:
|
||||
del cls._room_connections[room_id]
|
||||
|
||||
logger.info(f"WebSocket连接断开: {room_id}")
|
||||
|
||||
@classmethod
|
||||
async def broadcast_to_room(
|
||||
cls,
|
||||
room_id: str,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向聊天室广播消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
message: 消息内容
|
||||
"""
|
||||
if room_id not in cls._room_connections:
|
||||
return
|
||||
|
||||
# 获取所有连接
|
||||
connections = cls._room_connections[room_id].copy()
|
||||
|
||||
# 并发发送
|
||||
tasks = []
|
||||
for websocket in connections:
|
||||
tasks.append(cls._send_message(room_id, websocket, message))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@classmethod
|
||||
async def _send_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
websocket: WebSocket,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向单个WebSocket发送消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
message: 消息内容
|
||||
"""
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
# 移除断开的连接
|
||||
await cls.disconnect(room_id, websocket)
|
||||
|
||||
@classmethod
|
||||
async def save_and_broadcast_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
agent_id: Optional[str],
|
||||
content: str,
|
||||
message_type: str = MessageType.TEXT.value,
|
||||
round_num: int = 0,
|
||||
attachments: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_results: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
保存消息并广播
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
agent_id: 发送Agent ID
|
||||
content: 消息内容
|
||||
message_type: 消息类型
|
||||
round_num: 轮次号
|
||||
attachments: 附件
|
||||
tool_calls: 工具调用
|
||||
tool_results: 工具结果
|
||||
|
||||
Returns:
|
||||
保存的Message文档
|
||||
"""
|
||||
# 创建消息
|
||||
message = Message(
|
||||
message_id=f"msg-{uuid.uuid4().hex[:12]}",
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
attachments=attachments or [],
|
||||
round=round_num,
|
||||
token_count=len(content) // 4, # 粗略估计
|
||||
tool_calls=tool_calls or [],
|
||||
tool_results=tool_results or [],
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await message.insert()
|
||||
|
||||
# 构建广播消息
|
||||
broadcast_data = {
|
||||
"type": "message",
|
||||
"data": {
|
||||
"message_id": message.message_id,
|
||||
"room_id": message.room_id,
|
||||
"discussion_id": message.discussion_id,
|
||||
"agent_id": message.agent_id,
|
||||
"content": message.content,
|
||||
"message_type": message.message_type,
|
||||
"round": message.round,
|
||||
"created_at": message.created_at.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# 广播消息
|
||||
await cls.broadcast_to_room(room_id, broadcast_data)
|
||||
|
||||
# 触发回调
|
||||
for callback in cls._message_callbacks:
|
||||
try:
|
||||
await callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"消息回调执行失败: {e}")
|
||||
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
async def broadcast_status(
|
||||
cls,
|
||||
room_id: str,
|
||||
status: str,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播状态更新
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
status: 状态类型
|
||||
data: 附加数据
|
||||
"""
|
||||
message = {
|
||||
"type": "status",
|
||||
"status": status,
|
||||
"data": data or {},
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_typing(
|
||||
cls,
|
||||
room_id: str,
|
||||
agent_id: str,
|
||||
is_typing: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
广播Agent输入状态
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
is_typing: 是否正在输入
|
||||
"""
|
||||
message = {
|
||||
"type": "typing",
|
||||
"agent_id": agent_id,
|
||||
"is_typing": is_typing,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_round_info(
|
||||
cls,
|
||||
room_id: str,
|
||||
round_num: int,
|
||||
total_rounds: int
|
||||
) -> None:
|
||||
"""
|
||||
广播轮次信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
round_num: 当前轮次
|
||||
total_rounds: 最大轮次
|
||||
"""
|
||||
message = {
|
||||
"type": "round",
|
||||
"round": round_num,
|
||||
"total_rounds": total_rounds,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
error: str,
|
||||
agent_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播错误信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
error: 错误信息
|
||||
agent_id: 相关Agent ID
|
||||
"""
|
||||
message = {
|
||||
"type": "error",
|
||||
"error": error,
|
||||
"agent_id": agent_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
def register_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注册消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收Message参数
|
||||
"""
|
||||
cls._message_callbacks.append(callback)
|
||||
|
||||
@classmethod
|
||||
def unregister_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注销消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数
|
||||
"""
|
||||
if callback in cls._message_callbacks:
|
||||
cls._message_callbacks.remove(callback)
|
||||
|
||||
@classmethod
|
||||
def get_connection_count(cls, room_id: str) -> int:
|
||||
"""
|
||||
获取房间连接数
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
连接数
|
||||
"""
|
||||
return len(cls._room_connections.get(room_id, set()))
|
||||
|
||||
@classmethod
|
||||
def get_all_room_ids(cls) -> List[str]:
|
||||
"""
|
||||
获取所有活跃房间ID
|
||||
|
||||
Returns:
|
||||
房间ID列表
|
||||
"""
|
||||
return list(cls._room_connections.keys())
|
||||
Reference in New Issue
Block a user