- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
590 lines
19 KiB
Python
590 lines
19 KiB
Python
"""
|
||
讨论引擎
|
||
实现自由讨论的核心逻辑
|
||
"""
|
||
import uuid
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import List, Dict, Any, Optional
|
||
from dataclasses import dataclass, field
|
||
from loguru import logger
|
||
|
||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||
from models.agent import Agent
|
||
from models.message import Message, MessageType
|
||
from models.discussion_result import DiscussionResult
|
||
from services.ai_provider_service import AIProviderService
|
||
from services.agent_service import AgentService
|
||
from services.chatroom_service import ChatRoomService
|
||
from services.message_router import MessageRouter
|
||
from services.consensus_manager import ConsensusManager
|
||
|
||
|
||
@dataclass
class DiscussionContext:
    """In-memory state of a single running discussion.

    The context stores the messages recorded through :meth:`add_message`
    and a per-agent message counter that the caller can clear with
    :meth:`reset_round_counts` (the engine does so at the start of a round).
    """

    discussion_id: str
    room_id: str
    objective: str
    # Number of the round currently in progress (0 before the first round).
    current_round: int = 0
    # Messages recorded via add_message(), oldest first.
    messages: List["Message"] = field(default_factory=list)
    # agent_id -> messages recorded since the last reset_round_counts().
    agent_speak_counts: Dict[str, int] = field(default_factory=dict)

    def add_message(self, message: "Message") -> None:
        """Append *message* and bump the counter of its author.

        Messages without an ``agent_id`` are stored but do not affect
        the per-agent counters.
        """
        self.messages.append(message)
        if message.agent_id:
            self.agent_speak_counts[message.agent_id] = (
                self.agent_speak_counts.get(message.agent_id, 0) + 1
            )

    def get_recent_messages(self, count: int = 20) -> List["Message"]:
        """Return up to *count* of the most recent messages, oldest first.

        The result is always a new list, so callers may modify it without
        touching the context. A non-positive *count* yields an empty list.
        """
        if count <= 0:
            # ``messages[-0:]`` would be the whole history, not "nothing".
            return []
        return self.messages[-count:]

    def get_agent_speak_count(self, agent_id: str) -> int:
        """Return how many messages *agent_id* has recorded since the last reset."""
        return self.agent_speak_counts.get(agent_id, 0)

    def reset_round_counts(self) -> None:
        """Clear all per-agent message counters."""
        self.agent_speak_counts.clear()
|
||
|
||
|
||
class DiscussionEngine:
    """Runs free-form discussions between the agents of a chat room.

    All state lives on the class and every entry point is a classmethod or
    staticmethod, so the engine is used without being instantiated.
    """

    # Active discussions: room_id -> DiscussionContext
    _active_discussions: Dict[str, "DiscussionContext"] = {}

    # Stop requests: room_id -> True once stop_discussion() has been called
    _stop_signals: Dict[str, bool] = {}

    # The moderator is asked for consensus on every N-th round ...
    _CONSENSUS_CHECK_INTERVAL = 3
    # ... and on every round once at most this many rounds remain.
    _FINAL_ROUNDS_WINDOW = 5
    # Consecutive rounds without any message after which the discussion ends.
    _MAX_SILENT_ROUNDS = 3

    @classmethod
    async def start_discussion(
        cls,
        room_id: str,
        objective: str
    ) -> Optional["DiscussionResult"]:
        """Start a discussion in a chat room and run it to completion.

        Args:
            room_id: Chat room ID.
            objective: Goal of the discussion.

        Returns:
            The value returned by the discussion loop (the stored result).

        Raises:
            ValueError: If the room does not exist, has no agents, the
                objective is empty, or the room already has a discussion
                registered as active.
            Exception: Any error raised while running the discussion is
                passed to ``_handle_discussion_error`` and re-raised.
        """
        chatroom = await ChatRoomService.get_chatroom(room_id)
        if not chatroom:
            raise ValueError(f"聊天室不存在: {room_id}")

        if not chatroom.agents:
            raise ValueError("聊天室没有Agent参与")

        if not objective:
            raise ValueError("讨论目标不能为空")

        if room_id in cls._active_discussions:
            raise ValueError("聊天室已有进行中的讨论")

        discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
        context = DiscussionContext(
            discussion_id=discussion_id,
            room_id=room_id,
            objective=objective
        )

        # Register right after the check, with no await in between, so a
        # second start for the same room cannot slip past the check above.
        cls._active_discussions[room_id] = context
        cls._stop_signals[room_id] = False

        # Everything after registration runs inside try/finally: a failure
        # during setup must not leave the room marked as active forever.
        try:
            now = datetime.utcnow()
            discussion_result = DiscussionResult(
                discussion_id=discussion_id,
                room_id=room_id,
                objective=objective,
                status="in_progress",
                created_at=now,
                updated_at=now
            )
            await discussion_result.insert()

            await ChatRoomService.update_chatroom(
                room_id,
                status=ChatRoomStatus.ACTIVE.value,
                objective=objective,
                current_discussion_id=discussion_id,
                current_round=0
            )

            await MessageRouter.broadcast_status(room_id, "discussion_started", {
                "discussion_id": discussion_id,
                "objective": objective
            })

            # Opening system message
            await MessageRouter.save_and_broadcast_message(
                room_id=room_id,
                discussion_id=discussion_id,
                agent_id=None,
                content=f"讨论开始\n\n目标:{objective}",
                message_type=MessageType.SYSTEM.value,
                round_num=0
            )

            logger.info(f"讨论开始: {room_id} - {objective}")

            return await cls._run_discussion_loop(chatroom, context)
        except Exception as e:
            logger.error(f"讨论异常: {e}")
            await cls._handle_discussion_error(room_id, discussion_id, str(e))
            raise
        finally:
            cls._active_discussions.pop(room_id, None)
            cls._stop_signals.pop(room_id, None)

    @classmethod
    async def stop_discussion(cls, room_id: str) -> bool:
        """Request the active discussion of a room to stop.

        Only sets a flag; the discussion loop checks it and finishes.

        Args:
            room_id: Chat room ID.

        Returns:
            True if a stop was requested, False if the room has no
            active discussion.
        """
        if room_id not in cls._active_discussions:
            return False

        cls._stop_signals[room_id] = True
        logger.info(f"收到停止讨论信号: {room_id}")
        return True

    @classmethod
    async def pause_discussion(cls, room_id: str) -> bool:
        """Pause the active discussion of a room.

        Sets the chat room status to PAUSED and broadcasts the change; the
        discussion loop polls that status before starting a round.

        Args:
            room_id: Chat room ID.

        Returns:
            True on success, False if the room has no active discussion.
        """
        if room_id not in cls._active_discussions:
            return False

        await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
        await MessageRouter.broadcast_status(room_id, "discussion_paused")

        logger.info(f"讨论暂停: {room_id}")
        return True

    @classmethod
    async def resume_discussion(cls, room_id: str) -> bool:
        """Resume a paused discussion.

        Args:
            room_id: Chat room ID.

        Returns:
            True on success, False if the room does not exist or its
            status is not PAUSED.
        """
        chatroom = await ChatRoomService.get_chatroom(room_id)
        if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
            return False

        await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
        await MessageRouter.broadcast_status(room_id, "discussion_resumed")

        logger.info(f"讨论恢复: {room_id}")
        return True

    @classmethod
    async def _run_discussion_loop(
        cls,
        chatroom: "ChatRoom",
        context: "DiscussionContext"
    ) -> Optional["DiscussionResult"]:
        """Run discussion rounds until an end condition is met.

        The loop ends when the moderator reports consensus with enough
        confidence ("consensus"), when several consecutive rounds produce
        no message ("no_more_discussion"), when a stop was requested
        ("stopped"), or when the configured round limit is reached
        ("max_rounds").

        Args:
            chatroom: Chat room the discussion runs in.
            context: Discussion context to update.

        Returns:
            The value returned by ``_finalize_discussion``.
        """
        room_id = chatroom.room_id
        config = chatroom.get_config()

        agents = await AgentService.get_agents_by_ids(chatroom.agents)
        agent_map = {a.agent_id: a for a in agents}

        # Optional moderator, used for consensus checks
        moderator = None
        if chatroom.moderator_agent_id:
            moderator = await AgentService.get_agent(chatroom.moderator_agent_id)

        consecutive_no_speak = 0  # consecutive rounds without any message
        stopped = False

        while context.current_round < config.max_rounds:
            if cls._stop_signals.get(room_id, False):
                stopped = True
                break

            # While paused, poll the stored room status once per second.
            current_chatroom = await ChatRoomService.get_chatroom(room_id)
            if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
                await asyncio.sleep(1)
                continue

            context.current_round += 1
            context.reset_round_counts()

            await MessageRouter.broadcast_round_info(
                room_id,
                context.current_round,
                config.max_rounds
            )
            await ChatRoomService.update_chatroom(
                room_id,
                current_round=context.current_round
            )

            round_has_message = False

            for agent_id in chatroom.agents:
                # Honour a stop request without waiting for the remaining
                # agents of this round.
                if cls._stop_signals.get(room_id, False):
                    break

                agent = agent_map.get(agent_id)
                if not agent or not agent.enabled:
                    continue

                # Per-round speaking limit
                behavior = agent.get_behavior()
                if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
                    continue

                should_speak, content = await cls._should_agent_speak(
                    agent, context, chatroom
                )
                if not (should_speak and content):
                    continue

                await MessageRouter.broadcast_typing(room_id, agent_id, True)
                try:
                    message = await MessageRouter.save_and_broadcast_message(
                        room_id=room_id,
                        discussion_id=context.discussion_id,
                        agent_id=agent_id,
                        content=content,
                        message_type=MessageType.TEXT.value,
                        round_num=context.current_round
                    )
                finally:
                    # Always clear the typing indicator, even if saving failed.
                    await MessageRouter.broadcast_typing(room_id, agent_id, False)

                context.add_message(message)
                round_has_message = True

            if cls._stop_signals.get(room_id, False):
                stopped = True
                break

            # Pause between rounds
            await asyncio.sleep(config.round_interval)

            if round_has_message:
                # Any message resets the silence counter, with or without
                # a moderator.
                consecutive_no_speak = 0

                consensus_due = (
                    context.current_round % cls._CONSENSUS_CHECK_INTERVAL == 0
                    or context.current_round >= config.max_rounds - cls._FINAL_ROUNDS_WINDOW
                )
                if moderator and consensus_due:
                    consensus_result = await ConsensusManager.check_consensus(
                        moderator, context, chatroom
                    )
                    if consensus_result.get("consensus_reached", False):
                        # A missing or None confidence counts as 0.
                        confidence = consensus_result.get("confidence") or 0
                        if confidence >= config.consensus_threshold:
                            return await cls._finalize_discussion(
                                context,
                                consensus_result,
                                "consensus"
                            )
                continue

            consecutive_no_speak += 1
            if consecutive_no_speak >= cls._MAX_SILENT_ROUNDS:
                if moderator:
                    consensus_result = await ConsensusManager.check_consensus(
                        moderator, context, chatroom
                    )
                else:
                    consensus_result = {"consensus_reached": False, "summary": "讨论结束,无明确共识"}
                return await cls._finalize_discussion(
                    context,
                    consensus_result,
                    "no_more_discussion"
                )

        # Loop left through a stop request or by reaching the round limit.
        if stopped:
            end_reason = "stopped"
            fallback_summary = "讨论已被手动停止"
        else:
            end_reason = "max_rounds"
            fallback_summary = "达到最大轮次限制"

        if moderator:
            consensus_result = await ConsensusManager.check_consensus(
                moderator, context, chatroom
            )
        else:
            consensus_result = {"consensus_reached": False, "summary": fallback_summary}

        return await cls._finalize_discussion(
            context,
            consensus_result,
            end_reason
        )

    @staticmethod
    def _is_pass_response(content: str) -> bool:
        """Return True if *content* is a "PASS" reply rather than a statement.

        Leading quotes or brackets are ignored and the comparison is
        case-insensitive. "PASS" must not be followed directly by an ASCII
        letter or digit, so "PASS", '"PASS"' and "PASS, nothing to add" count
        as passing while "Password ..." does not.
        """
        text = content.strip().lstrip("\"'“”‘’「『[(*`").upper()
        if not text.startswith("PASS"):
            return False
        following = text[4:5]  # "" when the reply is exactly PASS
        return not (following.isascii() and following.isalnum())

    @classmethod
    async def _should_agent_speak(
        cls,
        agent: "Agent",
        context: "DiscussionContext",
        chatroom: "ChatRoom"
    ) -> tuple[bool, str]:
        """Ask the agent's model whether it wants to speak, and what to say.

        Any failure (unsuccessful response or exception) is logged and
        treated as "does not speak".

        Args:
            agent: Agent to query.
            context: Discussion context providing the message history.
            chatroom: Chat room whose config sets the history size.

        Returns:
            ``(True, content)`` if the agent speaks, otherwise ``(False, "")``.
        """
        recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)

        history_text = "".join(
            f"[{msg.agent_id}]: {msg.content}\n\n" if msg.agent_id
            else f"[系统]: {msg.content}\n\n"
            for msg in recent_messages
        )
        history_block = history_text if history_text else "(还没有对话)"

        prompt = f"""你是{agent.name},角色是{agent.role}。

{agent.system_prompt}

当前讨论目标:{context.objective}

对话历史:
{history_block}

当前是第{context.current_round}轮讨论。

请根据你的角色判断:
1. 你是否有新的观点或建议要分享?
2. 你是否需要回应其他人的观点?
3. 当前讨论是否需要你的专业意见?

如果你认为需要发言,请直接给出你的发言内容。
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"。

注意:
- 请保持发言简洁有力,每次发言控制在200字以内
- 避免重复已经说过的内容
- 如果已经达成共识或接近共识,可以选择PASS"""

        try:
            response = await AIProviderService.chat(
                provider_id=agent.provider_id,
                messages=[{"role": "user", "content": prompt}],
                temperature=agent.temperature,
                max_tokens=agent.max_tokens
            )

            if not response.success:
                logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
                return False, ""

            # A None content is treated like an empty reply.
            content = (response.content or "").strip()

            if not content or cls._is_pass_response(content):
                return False, ""

            return True, content

        except Exception as e:
            # Best effort: a failing agent just skips its turn.
            logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
            return False, ""

    @classmethod
    async def _finalize_discussion(
        cls,
        context: "DiscussionContext",
        consensus_result: Dict[str, Any],
        end_reason: str
    ) -> Optional["DiscussionResult"]:
        """Finish a discussion: store the result and announce the end.

        Args:
            context: Discussion context.
            consensus_result: Outcome of the consensus check; the keys read
                are ``consensus_reached``, ``confidence``, ``summary``,
                ``action_items`` and ``unresolved_issues``.
            end_reason: Why the discussion ended.

        Returns:
            The updated result record, or None if no record with the
            discussion ID was found.
        """
        room_id = context.room_id

        # Missing or None values fall back to neutral defaults so that an
        # incomplete consensus result cannot break finalization.
        confidence = consensus_result.get("confidence") or 0
        action_items = consensus_result.get("action_items") or []
        unresolved_issues = consensus_result.get("unresolved_issues") or []

        # The per-agent counters on the context are cleared every round, so
        # overall contributions are counted from the recorded messages.
        contributions: Dict[str, int] = {}
        for msg in context.messages:
            if msg.agent_id:
                contributions[msg.agent_id] = contributions.get(msg.agent_id, 0) + 1

        discussion_result = await DiscussionResult.find_one(
            DiscussionResult.discussion_id == context.discussion_id
        )

        if discussion_result:
            discussion_result.update_stats(
                total_rounds=context.current_round,
                total_messages=len(context.messages),
                agent_contributions=contributions
            )

            discussion_result.mark_completed(
                consensus_reached=consensus_result.get("consensus_reached", False),
                confidence=confidence,
                summary=consensus_result.get("summary", ""),
                action_items=action_items,
                unresolved_issues=unresolved_issues,
                end_reason=end_reason
            )

            await discussion_result.save()
        else:
            logger.warning(f"未找到讨论结果记录: {context.discussion_id}")

        await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)

        # Closing system message
        outcome = "达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"
        # f-strings accept list items that are not plain strings.
        action_lines = "\n".join(f"- {item}" for item in action_items) or "无"
        issue_lines = "\n".join(f"- {issue}" for issue in unresolved_issues) or "无"

        summary_text = f"""讨论结束

结果:{outcome}
置信度:{confidence:.0%}

摘要:{consensus_result.get("summary", "无")}

行动项:
{action_lines}

未解决问题:
{issue_lines}

共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""

        await MessageRouter.save_and_broadcast_message(
            room_id=room_id,
            discussion_id=context.discussion_id,
            agent_id=None,
            content=summary_text,
            message_type=MessageType.SYSTEM.value,
            round_num=context.current_round
        )

        await MessageRouter.broadcast_status(room_id, "discussion_completed", {
            "discussion_id": context.discussion_id,
            "consensus_reached": consensus_result.get("consensus_reached", False),
            "end_reason": end_reason
        })

        logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")

        return discussion_result

    @classmethod
    async def _handle_discussion_error(
        cls,
        room_id: str,
        discussion_id: str,
        error: str
    ) -> None:
        """Record a failed discussion and notify the room.

        Sets the chat room status to ERROR, marks the result record (if
        found) as failed, and broadcasts the error.

        Args:
            room_id: Chat room ID.
            discussion_id: Discussion ID.
            error: Error message.
        """
        await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)

        discussion_result = await DiscussionResult.find_one(
            DiscussionResult.discussion_id == discussion_id
        )
        if discussion_result:
            discussion_result.status = "failed"
            discussion_result.end_reason = f"error: {error}"
            discussion_result.updated_at = datetime.utcnow()
            await discussion_result.save()

        await MessageRouter.broadcast_error(room_id, error)

    @classmethod
    def get_active_discussion(cls, room_id: str) -> Optional["DiscussionContext"]:
        """Return the context of the room's active discussion, if any.

        Args:
            room_id: Chat room ID.

        Returns:
            The discussion context, or None if the room has no active
            discussion.
        """
        return cls._active_discussions.get(room_id)

    @classmethod
    def is_discussion_active(cls, room_id: str) -> bool:
        """Tell whether the room has an active discussion.

        Args:
            room_id: Chat room ID.

        Returns:
            True if a discussion is registered for the room.
        """
        return room_id in cls._active_discussions
|