feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
589
backend/services/discussion_engine.py
Normal file
589
backend/services/discussion_engine.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
讨论引擎
|
||||
实现自由讨论的核心逻辑
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||
from models.agent import Agent
|
||||
from models.message import Message, MessageType
|
||||
from models.discussion_result import DiscussionResult
|
||||
from services.ai_provider_service import AIProviderService
|
||||
from services.agent_service import AgentService
|
||||
from services.chatroom_service import ChatRoomService
|
||||
from services.message_router import MessageRouter
|
||||
from services.consensus_manager import ConsensusManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscussionContext:
|
||||
"""讨论上下文"""
|
||||
discussion_id: str
|
||||
room_id: str
|
||||
objective: str
|
||||
current_round: int = 0
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到上下文"""
|
||||
self.messages.append(message)
|
||||
if message.agent_id:
|
||||
self.agent_speak_counts[message.agent_id] = \
|
||||
self.agent_speak_counts.get(message.agent_id, 0) + 1
|
||||
|
||||
def get_recent_messages(self, count: int = 20) -> List[Message]:
|
||||
"""获取最近的消息"""
|
||||
return self.messages[-count:] if len(self.messages) > count else self.messages
|
||||
|
||||
def get_agent_speak_count(self, agent_id: str) -> int:
|
||||
"""获取Agent在当前轮次的发言次数"""
|
||||
return self.agent_speak_counts.get(agent_id, 0)
|
||||
|
||||
def reset_round_counts(self) -> None:
|
||||
"""重置轮次发言计数"""
|
||||
self.agent_speak_counts.clear()
|
||||
|
||||
|
||||
class DiscussionEngine:
|
||||
"""
|
||||
讨论引擎
|
||||
实现多Agent自由讨论的核心逻辑
|
||||
"""
|
||||
|
||||
# 活跃的讨论: room_id -> DiscussionContext
|
||||
_active_discussions: Dict[str, DiscussionContext] = {}
|
||||
|
||||
# 停止信号
|
||||
_stop_signals: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
async def start_discussion(
|
||||
cls,
|
||||
room_id: str,
|
||||
objective: str
|
||||
) -> Optional[DiscussionResult]:
|
||||
"""
|
||||
启动讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
objective: 讨论目标
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
# 获取聊天室
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
raise ValueError(f"聊天室不存在: {room_id}")
|
||||
|
||||
if not chatroom.agents:
|
||||
raise ValueError("聊天室没有Agent参与")
|
||||
|
||||
if not objective:
|
||||
raise ValueError("讨论目标不能为空")
|
||||
|
||||
# 检查是否已有活跃讨论
|
||||
if room_id in cls._active_discussions:
|
||||
raise ValueError("聊天室已有进行中的讨论")
|
||||
|
||||
# 创建讨论
|
||||
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建讨论结果记录
|
||||
discussion_result = DiscussionResult(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective,
|
||||
status="in_progress",
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
await discussion_result.insert()
|
||||
|
||||
# 创建讨论上下文
|
||||
context = DiscussionContext(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective
|
||||
)
|
||||
cls._active_discussions[room_id] = context
|
||||
cls._stop_signals[room_id] = False
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
status=ChatRoomStatus.ACTIVE.value,
|
||||
objective=objective,
|
||||
current_discussion_id=discussion_id,
|
||||
current_round=0
|
||||
)
|
||||
|
||||
# 广播讨论开始
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_started", {
|
||||
"discussion_id": discussion_id,
|
||||
"objective": objective
|
||||
})
|
||||
|
||||
# 发送系统消息
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=None,
|
||||
content=f"讨论开始\n\n目标:{objective}",
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=0
|
||||
)
|
||||
|
||||
logger.info(f"讨论开始: {room_id} - {objective}")
|
||||
|
||||
# 运行讨论循环
|
||||
try:
|
||||
result = await cls._run_discussion_loop(chatroom, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"讨论异常: {e}")
|
||||
await cls._handle_discussion_error(room_id, discussion_id, str(e))
|
||||
raise
|
||||
finally:
|
||||
# 清理
|
||||
cls._active_discussions.pop(room_id, None)
|
||||
cls._stop_signals.pop(room_id, None)
|
||||
|
||||
@classmethod
|
||||
async def stop_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
停止讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
cls._stop_signals[room_id] = True
|
||||
logger.info(f"收到停止讨论信号: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def pause_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
暂停讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_paused")
|
||||
|
||||
logger.info(f"讨论暂停: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def resume_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
恢复讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
|
||||
|
||||
logger.info(f"讨论恢复: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def _run_discussion_loop(
|
||||
cls,
|
||||
chatroom: ChatRoom,
|
||||
context: DiscussionContext
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
运行讨论循环
|
||||
|
||||
Args:
|
||||
chatroom: 聊天室
|
||||
context: 讨论上下文
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = chatroom.room_id
|
||||
config = chatroom.get_config()
|
||||
|
||||
# 获取所有Agent
|
||||
agents = await AgentService.get_agents_by_ids(chatroom.agents)
|
||||
agent_map = {a.agent_id: a for a in agents}
|
||||
|
||||
# 获取主持人(用于共识判断)
|
||||
moderator = None
|
||||
if chatroom.moderator_agent_id:
|
||||
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
|
||||
|
||||
consecutive_no_speak = 0 # 连续无人发言的轮次
|
||||
|
||||
while context.current_round < config.max_rounds:
|
||||
# 检查停止信号
|
||||
if cls._stop_signals.get(room_id, False):
|
||||
break
|
||||
|
||||
# 检查暂停状态
|
||||
current_chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# 增加轮次
|
||||
context.current_round += 1
|
||||
context.reset_round_counts()
|
||||
|
||||
# 广播轮次信息
|
||||
await MessageRouter.broadcast_round_info(
|
||||
room_id,
|
||||
context.current_round,
|
||||
config.max_rounds
|
||||
)
|
||||
|
||||
# 更新聊天室轮次
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
current_round=context.current_round
|
||||
)
|
||||
|
||||
# 本轮是否有人发言
|
||||
round_has_message = False
|
||||
|
||||
# 遍历所有Agent,判断是否发言
|
||||
for agent_id in chatroom.agents:
|
||||
agent = agent_map.get(agent_id)
|
||||
if not agent or not agent.enabled:
|
||||
continue
|
||||
|
||||
# 检查本轮发言次数限制
|
||||
behavior = agent.get_behavior()
|
||||
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
|
||||
continue
|
||||
|
||||
# 判断是否发言
|
||||
should_speak, content = await cls._should_agent_speak(
|
||||
agent, context, chatroom
|
||||
)
|
||||
|
||||
if should_speak and content:
|
||||
# 广播输入状态
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, True)
|
||||
|
||||
# 保存并广播消息
|
||||
message = await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=MessageType.TEXT.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 更新上下文
|
||||
context.add_message(message)
|
||||
round_has_message = True
|
||||
|
||||
# 广播输入结束
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, False)
|
||||
|
||||
# 轮次间隔
|
||||
await asyncio.sleep(config.round_interval)
|
||||
|
||||
# 检查是否需要共识判断
|
||||
if round_has_message and moderator:
|
||||
consecutive_no_speak = 0
|
||||
|
||||
# 每隔几轮检查一次共识
|
||||
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
|
||||
if consensus_result.get("consensus_reached", False):
|
||||
confidence = consensus_result.get("confidence", 0)
|
||||
if confidence >= config.consensus_threshold:
|
||||
# 达成共识,结束讨论
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"consensus"
|
||||
)
|
||||
else:
|
||||
consecutive_no_speak += 1
|
||||
|
||||
# 连续多轮无人发言,检查共识或结束
|
||||
if consecutive_no_speak >= 3:
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"no_more_discussion"
|
||||
)
|
||||
else:
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
|
||||
"no_more_discussion"
|
||||
)
|
||||
|
||||
# 达到最大轮次
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
else:
|
||||
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
|
||||
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"max_rounds"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _should_agent_speak(
|
||||
cls,
|
||||
agent: Agent,
|
||||
context: DiscussionContext,
|
||||
chatroom: ChatRoom
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断Agent是否应该发言
|
||||
|
||||
Args:
|
||||
agent: Agent实例
|
||||
context: 讨论上下文
|
||||
chatroom: 聊天室
|
||||
|
||||
Returns:
|
||||
(是否发言, 发言内容)
|
||||
"""
|
||||
# 构建判断提示词
|
||||
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
|
||||
|
||||
history_text = ""
|
||||
for msg in recent_messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
else:
|
||||
history_text += f"[系统]: {msg.content}\n\n"
|
||||
|
||||
prompt = f"""你是{agent.name},角色是{agent.role}。
|
||||
|
||||
{agent.system_prompt}
|
||||
|
||||
当前讨论目标:{context.objective}
|
||||
|
||||
对话历史:
|
||||
{history_text if history_text else "(还没有对话)"}
|
||||
|
||||
当前是第{context.current_round}轮讨论。
|
||||
|
||||
请根据你的角色判断:
|
||||
1. 你是否有新的观点或建议要分享?
|
||||
2. 你是否需要回应其他人的观点?
|
||||
3. 当前讨论是否需要你的专业意见?
|
||||
|
||||
如果你认为需要发言,请直接给出你的发言内容。
|
||||
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"。
|
||||
|
||||
注意:
|
||||
- 请保持发言简洁有力,每次发言控制在200字以内
|
||||
- 避免重复已经说过的内容
|
||||
- 如果已经达成共识或接近共识,可以选择PASS"""
|
||||
|
||||
try:
|
||||
# 调用AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=agent.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
|
||||
return False, ""
|
||||
|
||||
content = response.content.strip()
|
||||
|
||||
# 判断是否PASS
|
||||
if content.upper() == "PASS" or content.upper().startswith("PASS"):
|
||||
return False, ""
|
||||
|
||||
return True, content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
|
||||
return False, ""
|
||||
|
||||
@classmethod
|
||||
async def _finalize_discussion(
|
||||
cls,
|
||||
context: DiscussionContext,
|
||||
consensus_result: Dict[str, Any],
|
||||
end_reason: str
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
完成讨论,保存结果
|
||||
|
||||
Args:
|
||||
context: 讨论上下文
|
||||
consensus_result: 共识判断结果
|
||||
end_reason: 结束原因
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = context.room_id
|
||||
|
||||
# 获取讨论结果记录
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == context.discussion_id
|
||||
)
|
||||
|
||||
if discussion_result:
|
||||
# 更新统计
|
||||
discussion_result.update_stats(
|
||||
total_rounds=context.current_round,
|
||||
total_messages=len(context.messages),
|
||||
agent_contributions=context.agent_speak_counts
|
||||
)
|
||||
|
||||
# 标记完成
|
||||
discussion_result.mark_completed(
|
||||
consensus_reached=consensus_result.get("consensus_reached", False),
|
||||
confidence=consensus_result.get("confidence", 0),
|
||||
summary=consensus_result.get("summary", ""),
|
||||
action_items=consensus_result.get("action_items", []),
|
||||
unresolved_issues=consensus_result.get("unresolved_issues", []),
|
||||
end_reason=end_reason
|
||||
)
|
||||
|
||||
await discussion_result.save()
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
|
||||
|
||||
# 发送系统消息
|
||||
summary_text = f"""讨论结束
|
||||
|
||||
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
|
||||
置信度:{consensus_result.get("confidence", 0):.0%}
|
||||
|
||||
摘要:{consensus_result.get("summary", "无")}
|
||||
|
||||
行动项:
|
||||
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or "无"}
|
||||
|
||||
未解决问题:
|
||||
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or "无"}
|
||||
|
||||
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
|
||||
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=None,
|
||||
content=summary_text,
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 广播讨论结束
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
|
||||
"discussion_id": context.discussion_id,
|
||||
"consensus_reached": consensus_result.get("consensus_reached", False),
|
||||
"end_reason": end_reason
|
||||
})
|
||||
|
||||
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
|
||||
|
||||
return discussion_result
|
||||
|
||||
@classmethod
|
||||
async def _handle_discussion_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""
|
||||
处理讨论错误
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
error: 错误信息
|
||||
"""
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
|
||||
|
||||
# 更新讨论结果
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == discussion_id
|
||||
)
|
||||
if discussion_result:
|
||||
discussion_result.status = "failed"
|
||||
discussion_result.end_reason = f"error: {error}"
|
||||
discussion_result.updated_at = datetime.utcnow()
|
||||
await discussion_result.save()
|
||||
|
||||
# 广播错误
|
||||
await MessageRouter.broadcast_error(room_id, error)
|
||||
|
||||
@classmethod
|
||||
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
|
||||
"""
|
||||
获取活跃的讨论上下文
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
讨论上下文或None
|
||||
"""
|
||||
return cls._active_discussions.get(room_id)
|
||||
|
||||
@classmethod
|
||||
def is_discussion_active(cls, room_id: str) -> bool:
|
||||
"""
|
||||
检查是否有活跃讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否活跃
|
||||
"""
|
||||
return room_id in cls._active_discussions
|
||||
Reference in New Issue
Block a user