feat: AI聊天室多Agent协作讨论平台

- 实现Agent管理,支持AI辅助生成系统提示词
- 支持多个AI提供商(OpenRouter、智谱、MiniMax等)
- 实现聊天室和讨论引擎
- WebSocket实时消息推送
- 前端使用React + Ant Design
- 后端使用FastAPI + MongoDB

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Claude Code
2026-02-03 19:20:02 +08:00
commit edbddf855d
76 changed files with 14681 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
"""
业务服务模块
"""
from .ai_provider_service import AIProviderService
from .agent_service import AgentService
from .chatroom_service import ChatRoomService
from .message_router import MessageRouter
from .discussion_engine import DiscussionEngine
from .consensus_manager import ConsensusManager
from .mcp_service import MCPService
from .memory_service import MemoryService
__all__ = [
"AIProviderService",
"AgentService",
"ChatRoomService",
"MessageRouter",
"DiscussionEngine",
"ConsensusManager",
"MCPService",
"MemoryService",
]

View File

@@ -0,0 +1,438 @@
"""
Agent服务
管理AI代理的配置
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.agent import Agent
from services.ai_provider_service import AIProviderService
class AgentService:
"""
Agent服务类
负责Agent的CRUD操作
"""
@classmethod
async def create_agent(
cls,
name: str,
role: str,
system_prompt: str,
provider_id: str,
temperature: float = 0.7,
max_tokens: int = 2000,
capabilities: Optional[Dict[str, Any]] = None,
behavior: Optional[Dict[str, Any]] = None,
avatar: Optional[str] = None,
color: str = "#1890ff"
) -> Agent:
"""
创建新的Agent
Args:
name: Agent名称
role: 角色定义
system_prompt: 系统提示词
provider_id: 使用的AI接口ID
temperature: 温度参数
max_tokens: 最大token数
capabilities: 能力配置
behavior: 行为配置
avatar: 头像URL
color: 代表颜色
Returns:
创建的Agent文档
"""
# 验证AI接口存在
provider = await AIProviderService.get_provider(provider_id)
if not provider:
raise ValueError(f"AI接口不存在: {provider_id}")
# 生成唯一ID
agent_id = f"agent-{uuid.uuid4().hex[:8]}"
# 默认能力配置
default_capabilities = {
"memory_enabled": False,
"mcp_tools": [],
"skills": [],
"multimodal": False
}
if capabilities:
default_capabilities.update(capabilities)
# 默认行为配置
default_behavior = {
"speak_threshold": 0.5,
"max_speak_per_round": 2,
"speak_style": "balanced"
}
if behavior:
default_behavior.update(behavior)
# 创建文档
agent = Agent(
agent_id=agent_id,
name=name,
role=role,
system_prompt=system_prompt,
provider_id=provider_id,
temperature=temperature,
max_tokens=max_tokens,
capabilities=default_capabilities,
behavior=default_behavior,
avatar=avatar,
color=color,
enabled=True,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await agent.insert()
logger.info(f"创建Agent: {agent_id} ({name})")
return agent
@classmethod
async def get_agent(cls, agent_id: str) -> Optional[Agent]:
"""
获取指定Agent
Args:
agent_id: Agent ID
Returns:
Agent文档或None
"""
return await Agent.find_one(Agent.agent_id == agent_id)
@classmethod
async def get_all_agents(
cls,
enabled_only: bool = False
) -> List[Agent]:
"""
获取所有Agent
Args:
enabled_only: 是否只返回启用的Agent
Returns:
Agent列表
"""
if enabled_only:
return await Agent.find(Agent.enabled == True).to_list()
return await Agent.find_all().to_list()
@classmethod
async def get_agents_by_ids(
cls,
agent_ids: List[str]
) -> List[Agent]:
"""
根据ID列表获取多个Agent
Args:
agent_ids: Agent ID列表
Returns:
Agent列表
"""
return await Agent.find(
{"agent_id": {"$in": agent_ids}}
).to_list()
@classmethod
async def update_agent(
cls,
agent_id: str,
**kwargs
) -> Optional[Agent]:
"""
更新Agent配置
Args:
agent_id: Agent ID
**kwargs: 要更新的字段
Returns:
更新后的Agent或None
"""
agent = await cls.get_agent(agent_id)
if not agent:
return None
# 如果更新了provider_id验证其存在
if "provider_id" in kwargs:
provider = await AIProviderService.get_provider(kwargs["provider_id"])
if not provider:
raise ValueError(f"AI接口不存在: {kwargs['provider_id']}")
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(agent, key):
setattr(agent, key, value)
await agent.save()
logger.info(f"更新Agent: {agent_id}")
return agent
@classmethod
async def delete_agent(cls, agent_id: str) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
是否删除成功
"""
agent = await cls.get_agent(agent_id)
if not agent:
return False
await agent.delete()
logger.info(f"删除Agent: {agent_id}")
return True
@classmethod
async def test_agent(
cls,
agent_id: str,
test_message: str = "你好,请简单介绍一下你自己。"
) -> Dict[str, Any]:
"""
测试Agent对话
Args:
agent_id: Agent ID
test_message: 测试消息
Returns:
测试结果
"""
agent = await cls.get_agent(agent_id)
if not agent:
return {
"success": False,
"message": f"Agent不存在: {agent_id}"
}
if not agent.enabled:
return {
"success": False,
"message": "Agent已禁用"
}
# 构建消息
messages = [
{"role": "system", "content": agent.system_prompt},
{"role": "user", "content": test_message}
]
# 调用AI接口
response = await AIProviderService.chat(
provider_id=agent.provider_id,
messages=messages,
temperature=agent.temperature,
max_tokens=agent.max_tokens
)
if response.success:
return {
"success": True,
"message": "测试成功",
"response": response.content,
"model": response.model,
"tokens": response.total_tokens,
"latency_ms": response.latency_ms
}
else:
return {
"success": False,
"message": response.error
}
@classmethod
async def duplicate_agent(
cls,
agent_id: str,
new_name: Optional[str] = None
) -> Optional[Agent]:
"""
复制Agent
Args:
agent_id: 源Agent ID
new_name: 新Agent名称
Returns:
新创建的Agent或None
"""
source_agent = await cls.get_agent(agent_id)
if not source_agent:
return None
return await cls.create_agent(
name=new_name or f"{source_agent.name} (副本)",
role=source_agent.role,
system_prompt=source_agent.system_prompt,
provider_id=source_agent.provider_id,
temperature=source_agent.temperature,
max_tokens=source_agent.max_tokens,
capabilities=source_agent.capabilities,
behavior=source_agent.behavior,
avatar=source_agent.avatar,
color=source_agent.color
)
@classmethod
async def generate_system_prompt(
cls,
provider_id: str,
name: str,
role: str,
description: Optional[str] = None
) -> Dict[str, Any]:
"""
使用AI生成Agent系统提示词
Args:
provider_id: AI接口ID
name: Agent名称
role: 角色定位
description: 额外描述(可选)
Returns:
生成结果包含success和生成的prompt
"""
# 验证AI接口存在
provider = await AIProviderService.get_provider(provider_id)
if not provider:
return {
"success": False,
"message": f"AI接口不存在: {provider_id}"
}
# 构建生成提示词的请求
generate_prompt = f"""请为一个AI Agent编写系统提示词system prompt
Agent名称{name}
角色定位:{role}
{f'补充说明:{description}' if description else ''}
要求:
1. 提示词应简洁专业控制在200字以内
2. 明确该Agent的核心职责和专业领域
3. 说明在多Agent讨论中应该关注什么
4. 使用中文编写
5. 不要包含任何问候语或开场白,直接给出提示词内容
请直接输出系统提示词,不要有任何额外的解释或包装。"""
try:
messages = [{"role": "user", "content": generate_prompt}]
response = await AIProviderService.chat(
provider_id=provider_id,
messages=messages,
temperature=0.7,
max_tokens=1000
)
if response.success:
# 清理可能的包装文本
content = response.content.strip()
# 移除可能的markdown代码块标记
if content.startswith("```"):
lines = content.split("\n")
content = "\n".join(lines[1:])
if content.endswith("```"):
content = content[:-3]
content = content.strip()
return {
"success": True,
"prompt": content,
"model": response.model,
"tokens": response.total_tokens
}
else:
return {
"success": False,
"message": response.error or "生成失败"
}
except Exception as e:
logger.error(f"生成系统提示词失败: {e}")
return {
"success": False,
"message": f"生成失败: {str(e)}"
}
# Agent预设模板
AGENT_TEMPLATES = {
"product_manager": {
"name": "产品经理",
"role": "产品规划和需求分析专家",
"system_prompt": """你是一位经验丰富的产品经理,擅长:
- 分析用户需求和痛点
- 制定产品策略和路线图
- 平衡业务目标和用户体验
- 与团队协作推进产品迭代
在讨论中,你需要从产品角度出发,关注用户价值、商业可行性和优先级排序。
请用专业但易懂的语言表达观点。""",
"color": "#1890ff"
},
"developer": {
"name": "开发工程师",
"role": "技术实现和架构设计专家",
"system_prompt": """你是一位资深的软件开发工程师,擅长:
- 系统架构设计
- 代码实现和优化
- 技术方案评估
- 性能和安全考量
在讨论中,你需要从技术角度出发,关注实现可行性、技术债务和最佳实践。
请提供具体的技术建议和潜在风险评估。""",
"color": "#52c41a"
},
"designer": {
"name": "设计师",
"role": "用户体验和界面设计专家",
"system_prompt": """你是一位专业的UI/UX设计师擅长:
- 用户体验设计
- 界面视觉设计
- 交互流程优化
- 设计系统构建
在讨论中,你需要从设计角度出发,关注用户体验、视觉美感和交互流畅性。
请提供设计建议并考虑可用性和一致性。""",
"color": "#eb2f96"
},
"moderator": {
"name": "主持人",
"role": "讨论主持和共识判断专家",
"system_prompt": """你是讨论的主持人,负责:
- 引导讨论方向
- 总结各方观点
- 判断是否达成共识
- 提炼行动要点
在讨论中,你需要保持中立,促进有效沟通,并在适当时机总结讨论成果。
当各方观点趋于一致时,请明确指出并总结共识内容。""",
"color": "#722ed1"
}
}

View File

@@ -0,0 +1,364 @@
"""
AI接口提供商服务
管理AI接口的配置和调用
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.ai_provider import AIProvider
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
from utils.encryption import encrypt_api_key, decrypt_api_key
from utils.rate_limiter import rate_limiter
class AIProviderService:
"""
AI接口提供商服务类
负责AI接口的CRUD操作和调用
"""
# 缓存适配器实例
_adapter_cache: Dict[str, BaseAdapter] = {}
@classmethod
async def create_provider(
cls,
provider_type: str,
name: str,
model: str,
api_key: str = "",
base_url: str = "",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
rate_limit: Optional[Dict[str, int]] = None,
timeout: int = 60,
extra_params: Optional[Dict[str, Any]] = None
) -> AIProvider:
"""
创建新的AI接口配置
Args:
provider_type: 提供商类型
name: 自定义名称
model: 模型名称
api_key: API密钥
base_url: API基础URL
use_proxy: 是否使用代理
proxy_config: 代理配置
rate_limit: 速率限制配置
timeout: 超时时间
extra_params: 额外参数
Returns:
创建的AIProvider文档
"""
# 验证提供商类型
try:
get_adapter(provider_type)
except ValueError as e:
raise ValueError(f"不支持的提供商类型: {provider_type}")
# 生成唯一ID
provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
# 加密API密钥
encrypted_key = encrypt_api_key(api_key) if api_key else ""
# 创建文档
provider = AIProvider(
provider_id=provider_id,
provider_type=provider_type,
name=name,
api_key=encrypted_key,
base_url=base_url,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config or {},
rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
timeout=timeout,
extra_params=extra_params or {},
enabled=True,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await provider.insert()
# 注册速率限制
rate_limiter.register(
provider_id,
provider.rate_limit.get("requests_per_minute", 60),
provider.rate_limit.get("tokens_per_minute", 100000)
)
logger.info(f"创建AI接口配置: {provider_id} ({name})")
return provider
@classmethod
async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
"""
获取指定AI接口配置
Args:
provider_id: 接口ID
Returns:
AIProvider文档或None
"""
return await AIProvider.find_one(AIProvider.provider_id == provider_id)
@classmethod
async def get_all_providers(
cls,
enabled_only: bool = False
) -> List[AIProvider]:
"""
获取所有AI接口配置
Args:
enabled_only: 是否只返回启用的接口
Returns:
AIProvider列表
"""
if enabled_only:
return await AIProvider.find(AIProvider.enabled == True).to_list()
return await AIProvider.find_all().to_list()
@classmethod
async def update_provider(
cls,
provider_id: str,
**kwargs
) -> Optional[AIProvider]:
"""
更新AI接口配置
Args:
provider_id: 接口ID
**kwargs: 要更新的字段
Returns:
更新后的AIProvider或None
"""
provider = await cls.get_provider(provider_id)
if not provider:
return None
# 如果更新了API密钥需要加密
if "api_key" in kwargs and kwargs["api_key"]:
kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(provider, key):
setattr(provider, key, value)
await provider.save()
# 清除适配器缓存
cls._adapter_cache.pop(provider_id, None)
# 更新速率限制
if "rate_limit" in kwargs:
rate_limiter.unregister(provider_id)
rate_limiter.register(
provider_id,
provider.rate_limit.get("requests_per_minute", 60),
provider.rate_limit.get("tokens_per_minute", 100000)
)
logger.info(f"更新AI接口配置: {provider_id}")
return provider
@classmethod
async def delete_provider(cls, provider_id: str) -> bool:
"""
删除AI接口配置
Args:
provider_id: 接口ID
Returns:
是否删除成功
"""
provider = await cls.get_provider(provider_id)
if not provider:
return False
await provider.delete()
# 清除缓存和速率限制
cls._adapter_cache.pop(provider_id, None)
rate_limiter.unregister(provider_id)
logger.info(f"删除AI接口配置: {provider_id}")
return True
@classmethod
async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
"""
获取AI接口的适配器实例
Args:
provider_id: 接口ID
Returns:
适配器实例或None
"""
# 检查缓存
if provider_id in cls._adapter_cache:
return cls._adapter_cache[provider_id]
provider = await cls.get_provider(provider_id)
if not provider or not provider.enabled:
return None
# 解密API密钥
api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""
# 创建适配器
adapter_class = get_adapter(provider.provider_type)
adapter = adapter_class(
api_key=api_key,
base_url=provider.base_url,
model=provider.model,
use_proxy=provider.use_proxy,
proxy_config=provider.proxy_config,
timeout=provider.timeout,
**provider.extra_params
)
# 缓存适配器
cls._adapter_cache[provider_id] = adapter
return adapter
@classmethod
async def chat(
cls,
provider_id: str,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 2000,
**kwargs
) -> AdapterResponse:
"""
调用AI接口进行对话
Args:
provider_id: 接口ID
messages: 消息列表 [{"role": "user", "content": "..."}]
temperature: 温度参数
max_tokens: 最大token数
**kwargs: 额外参数
Returns:
适配器响应
"""
adapter = await cls.get_adapter(provider_id)
if not adapter:
return AdapterResponse(
success=False,
error=f"AI接口不存在或未启用: {provider_id}"
)
# 检查速率限制
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
return AdapterResponse(
success=False,
error="请求频率超限,请稍后重试"
)
# 转换消息格式
chat_messages = [
ChatMessage(
role=m.get("role", "user"),
content=m.get("content", ""),
name=m.get("name")
)
for m in messages
]
# 调用适配器
response = await adapter.chat(
messages=chat_messages,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
)
return response
@classmethod
async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
"""
测试AI接口连接
Args:
provider_id: 接口ID
Returns:
测试结果
"""
adapter = await cls.get_adapter(provider_id)
if not adapter:
return {
"success": False,
"message": f"AI接口不存在或未启用: {provider_id}"
}
return await adapter.test_connection()
@classmethod
async def test_provider_config(
cls,
provider_type: str,
api_key: str,
base_url: str = "",
model: str = "",
use_proxy: bool = False,
proxy_config: Optional[Dict[str, Any]] = None,
timeout: int = 30,
**kwargs
) -> Dict[str, Any]:
"""
测试AI接口配置不保存
Args:
provider_type: 提供商类型
api_key: API密钥
base_url: API基础URL
model: 模型名称
use_proxy: 是否使用代理
proxy_config: 代理配置
timeout: 超时时间
**kwargs: 额外参数
Returns:
测试结果
"""
try:
adapter_class = get_adapter(provider_type)
except ValueError:
return {
"success": False,
"message": f"不支持的提供商类型: {provider_type}"
}
adapter = adapter_class(
api_key=api_key,
base_url=base_url,
model=model,
use_proxy=use_proxy,
proxy_config=proxy_config,
timeout=timeout,
**kwargs
)
return await adapter.test_connection()

View File

@@ -0,0 +1,357 @@
"""
聊天室服务
管理聊天室的创建和状态
"""
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from loguru import logger
from models.chatroom import ChatRoom, ChatRoomStatus
from models.message import Message
from services.agent_service import AgentService
class ChatRoomService:
"""
聊天室服务类
负责聊天室的CRUD操作
"""
@classmethod
async def create_chatroom(
cls,
name: str,
description: str = "",
agents: Optional[List[str]] = None,
moderator_agent_id: Optional[str] = None,
config: Optional[Dict[str, Any]] = None
) -> ChatRoom:
"""
创建新的聊天室
Args:
name: 聊天室名称
description: 描述
agents: Agent ID列表
moderator_agent_id: 主持人Agent ID
config: 聊天室配置
Returns:
创建的ChatRoom文档
"""
# 验证Agent存在
if agents:
existing_agents = await AgentService.get_agents_by_ids(agents)
existing_ids = {a.agent_id for a in existing_agents}
missing_ids = set(agents) - existing_ids
if missing_ids:
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
# 验证主持人Agent
if moderator_agent_id:
moderator = await AgentService.get_agent(moderator_agent_id)
if not moderator:
raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")
# 生成唯一ID
room_id = f"room-{uuid.uuid4().hex[:8]}"
# 默认配置
default_config = {
"max_rounds": 50,
"message_history_size": 20,
"consensus_threshold": 0.8,
"round_interval": 1.0,
"allow_user_interrupt": True
}
if config:
default_config.update(config)
# 创建文档
chatroom = ChatRoom(
room_id=room_id,
name=name,
description=description,
objective="",
agents=agents or [],
moderator_agent_id=moderator_agent_id,
config=default_config,
status=ChatRoomStatus.IDLE.value,
current_round=0,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await chatroom.insert()
logger.info(f"创建聊天室: {room_id} ({name})")
return chatroom
@classmethod
async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
"""
获取指定聊天室
Args:
room_id: 聊天室ID
Returns:
ChatRoom文档或None
"""
return await ChatRoom.find_one(ChatRoom.room_id == room_id)
@classmethod
async def get_all_chatrooms(cls) -> List[ChatRoom]:
"""
获取所有聊天室
Returns:
ChatRoom列表
"""
return await ChatRoom.find_all().to_list()
@classmethod
async def update_chatroom(
cls,
room_id: str,
**kwargs
) -> Optional[ChatRoom]:
"""
更新聊天室配置
Args:
room_id: 聊天室ID
**kwargs: 要更新的字段
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 验证Agent
if "agents" in kwargs:
existing_agents = await AgentService.get_agents_by_ids(kwargs["agents"])
existing_ids = {a.agent_id for a in existing_agents}
missing_ids = set(kwargs["agents"]) - existing_ids
if missing_ids:
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
# 验证主持人
if "moderator_agent_id" in kwargs and kwargs["moderator_agent_id"]:
moderator = await AgentService.get_agent(kwargs["moderator_agent_id"])
if not moderator:
raise ValueError(f"主持人Agent不存在: {kwargs['moderator_agent_id']}")
# 更新字段
kwargs["updated_at"] = datetime.utcnow()
for key, value in kwargs.items():
if hasattr(chatroom, key):
setattr(chatroom, key, value)
await chatroom.save()
logger.info(f"更新聊天室: {room_id}")
return chatroom
@classmethod
async def delete_chatroom(cls, room_id: str) -> bool:
"""
删除聊天室
Args:
room_id: 聊天室ID
Returns:
是否删除成功
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return False
# 删除相关消息
await Message.find(Message.room_id == room_id).delete()
await chatroom.delete()
logger.info(f"删除聊天室: {room_id}")
return True
@classmethod
async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
"""
向聊天室添加Agent
Args:
room_id: 聊天室ID
agent_id: Agent ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 验证Agent存在
agent = await AgentService.get_agent(agent_id)
if not agent:
raise ValueError(f"Agent不存在: {agent_id}")
# 添加Agent
if agent_id not in chatroom.agents:
chatroom.agents.append(agent_id)
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
"""
从聊天室移除Agent
Args:
room_id: 聊天室ID
agent_id: Agent ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
# 移除Agent
if agent_id in chatroom.agents:
chatroom.agents.remove(agent_id)
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def set_objective(
cls,
room_id: str,
objective: str
) -> Optional[ChatRoom]:
"""
设置讨论目标
Args:
room_id: 聊天室ID
objective: 讨论目标
Returns:
更新后的ChatRoom或None
"""
return await cls.update_chatroom(room_id, objective=objective)
@classmethod
async def update_status(
cls,
room_id: str,
status: ChatRoomStatus
) -> Optional[ChatRoom]:
"""
更新聊天室状态
Args:
room_id: 聊天室ID
status: 新状态
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
chatroom.status = status.value
chatroom.updated_at = datetime.utcnow()
if status == ChatRoomStatus.COMPLETED:
chatroom.completed_at = datetime.utcnow()
await chatroom.save()
logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
return chatroom
@classmethod
async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
"""
增加轮次计数
Args:
room_id: 聊天室ID
Returns:
更新后的ChatRoom或None
"""
chatroom = await cls.get_chatroom(room_id)
if not chatroom:
return None
chatroom.current_round += 1
chatroom.updated_at = datetime.utcnow()
await chatroom.save()
return chatroom
@classmethod
async def get_messages(
cls,
room_id: str,
limit: int = 50,
skip: int = 0,
discussion_id: Optional[str] = None
) -> List[Message]:
"""
获取聊天室消息历史
Args:
room_id: 聊天室ID
limit: 返回数量限制
skip: 跳过数量
discussion_id: 讨论ID可选
Returns:
消息列表
"""
query = {"room_id": room_id}
if discussion_id:
query["discussion_id"] = discussion_id
return await Message.find(query).sort(
"-created_at"
).skip(skip).limit(limit).to_list()
@classmethod
async def get_recent_messages(
cls,
room_id: str,
count: int = 20,
discussion_id: Optional[str] = None
) -> List[Message]:
"""
获取最近的消息
Args:
room_id: 聊天室ID
count: 消息数量
discussion_id: 讨论ID可选
Returns:
消息列表(按时间正序)
"""
messages = await cls.get_messages(
room_id,
limit=count,
discussion_id=discussion_id
)
return list(reversed(messages)) # 返回正序

View File

@@ -0,0 +1,227 @@
"""
共识管理器
判断讨论是否达成共识
"""
import json
from typing import Dict, Any, Optional
from loguru import logger
from models.agent import Agent
from models.chatroom import ChatRoom
from services.ai_provider_service import AIProviderService
class ConsensusManager:
"""
共识管理器
使用主持人Agent判断讨论共识
"""
# 共识判断提示词模板
CONSENSUS_PROMPT = """你是讨论的主持人,负责判断讨论是否达成共识。
讨论目标:{objective}
对话历史:
{history}
请仔细分析对话内容,判断:
1. 参与者是否对核心问题达成一致意见?
2. 是否还有重要分歧未解决?
3. 讨论结果是否足够明确和可执行?
请以JSON格式回复不要包含任何其他文字
{{
"consensus_reached": true或false,
"confidence": 0到1之间的数字,
"summary": "讨论结果摘要,简洁概括达成的共识或当前状态",
"action_items": ["具体的行动项列表"],
"unresolved_issues": ["未解决的问题列表"],
"key_decisions": ["关键决策列表"]
}}
注意:
- consensus_reached为true表示核心问题已有明确结论
- confidence表示你对共识判断的信心程度
- 如果讨论仍有争议或不够深入应该返回false
- action_items应该是具体可执行的任务
- 请确保返回有效的JSON格式"""
@classmethod
async def check_consensus(
cls,
moderator: Agent,
context: "DiscussionContext",
chatroom: ChatRoom
) -> Dict[str, Any]:
"""
检查是否达成共识
Args:
moderator: 主持人Agent
context: 讨论上下文
chatroom: 聊天室
Returns:
共识判断结果
"""
from services.discussion_engine import DiscussionContext
# 构建历史记录
history_text = ""
for msg in context.messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
else:
history_text += f"[系统]: {msg.content}\n\n"
if not history_text:
return {
"consensus_reached": False,
"confidence": 0,
"summary": "讨论尚未开始",
"action_items": [],
"unresolved_issues": [],
"key_decisions": []
}
# 构建提示词
prompt = cls.CONSENSUS_PROMPT.format(
objective=context.objective,
history=history_text
)
try:
# 调用主持人Agent的AI接口
response = await AIProviderService.chat(
provider_id=moderator.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=0.3, # 使用较低温度以获得更一致的结果
max_tokens=1000
)
if not response.success:
logger.error(f"共识判断失败: {response.error}")
return cls._default_result("AI接口调用失败")
# 解析JSON响应
content = response.content.strip()
# 尝试提取JSON部分
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 尝试提取JSON块
import re
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError:
logger.warning(f"无法解析共识判断结果: {content}")
return cls._default_result("无法解析AI响应")
else:
return cls._default_result("AI响应格式错误")
# 验证和规范化结果
return cls._normalize_result(result)
except Exception as e:
logger.error(f"共识判断异常: {e}")
return cls._default_result(str(e))
@classmethod
async def generate_summary(
cls,
moderator: Agent,
context: "DiscussionContext"
) -> str:
"""
生成讨论摘要
Args:
moderator: 主持人Agent
context: 讨论上下文
Returns:
讨论摘要
"""
from services.discussion_engine import DiscussionContext
# 构建历史记录
history_text = ""
for msg in context.messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
prompt = f"""请为以下讨论生成一份简洁的摘要。
讨论目标:{context.objective}
对话记录:
{history_text}
请提供:
1. 讨论的主要观点和结论
2. 参与者的立场和建议
3. 最终的决策或共识(如果有)
摘要应该简洁明了控制在300字以内。"""
try:
response = await AIProviderService.chat(
provider_id=moderator.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=0.5,
max_tokens=500
)
if response.success:
return response.content.strip()
else:
return "无法生成摘要"
except Exception as e:
logger.error(f"生成摘要异常: {e}")
return "生成摘要时发生错误"
@classmethod
def _default_result(cls, error: str = "") -> Dict[str, Any]:
"""
返回默认结果
Args:
error: 错误信息
Returns:
默认共识结果
"""
return {
"consensus_reached": False,
"confidence": 0,
"summary": error if error else "共识判断失败",
"action_items": [],
"unresolved_issues": [],
"key_decisions": []
}
@classmethod
def _normalize_result(cls, result: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化共识结果
Args:
result: 原始结果
Returns:
规范化的结果
"""
return {
"consensus_reached": bool(result.get("consensus_reached", False)),
"confidence": max(0, min(1, float(result.get("confidence", 0)))),
"summary": str(result.get("summary", "")),
"action_items": list(result.get("action_items", [])),
"unresolved_issues": list(result.get("unresolved_issues", [])),
"key_decisions": list(result.get("key_decisions", []))
}

View File

@@ -0,0 +1,589 @@
"""
讨论引擎
实现自由讨论的核心逻辑
"""
import uuid
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from loguru import logger
from models.chatroom import ChatRoom, ChatRoomStatus
from models.agent import Agent
from models.message import Message, MessageType
from models.discussion_result import DiscussionResult
from services.ai_provider_service import AIProviderService
from services.agent_service import AgentService
from services.chatroom_service import ChatRoomService
from services.message_router import MessageRouter
from services.consensus_manager import ConsensusManager
@dataclass
class DiscussionContext:
"""讨论上下文"""
discussion_id: str
room_id: str
objective: str
current_round: int = 0
messages: List[Message] = field(default_factory=list)
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
def add_message(self, message: Message) -> None:
"""添加消息到上下文"""
self.messages.append(message)
if message.agent_id:
self.agent_speak_counts[message.agent_id] = \
self.agent_speak_counts.get(message.agent_id, 0) + 1
def get_recent_messages(self, count: int = 20) -> List[Message]:
"""获取最近的消息"""
return self.messages[-count:] if len(self.messages) > count else self.messages
def get_agent_speak_count(self, agent_id: str) -> int:
"""获取Agent在当前轮次的发言次数"""
return self.agent_speak_counts.get(agent_id, 0)
def reset_round_counts(self) -> None:
"""重置轮次发言计数"""
self.agent_speak_counts.clear()
class DiscussionEngine:
"""
讨论引擎
实现多Agent自由讨论的核心逻辑
"""
# 活跃的讨论: room_id -> DiscussionContext
_active_discussions: Dict[str, DiscussionContext] = {}
# 停止信号
_stop_signals: Dict[str, bool] = {}
@classmethod
async def start_discussion(
cls,
room_id: str,
objective: str
) -> Optional[DiscussionResult]:
"""
启动讨论
Args:
room_id: 聊天室ID
objective: 讨论目标
Returns:
讨论结果
"""
# 获取聊天室
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom:
raise ValueError(f"聊天室不存在: {room_id}")
if not chatroom.agents:
raise ValueError("聊天室没有Agent参与")
if not objective:
raise ValueError("讨论目标不能为空")
# 检查是否已有活跃讨论
if room_id in cls._active_discussions:
raise ValueError("聊天室已有进行中的讨论")
# 创建讨论
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
# 创建讨论结果记录
discussion_result = DiscussionResult(
discussion_id=discussion_id,
room_id=room_id,
objective=objective,
status="in_progress",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
await discussion_result.insert()
# 创建讨论上下文
context = DiscussionContext(
discussion_id=discussion_id,
room_id=room_id,
objective=objective
)
cls._active_discussions[room_id] = context
cls._stop_signals[room_id] = False
# 更新聊天室状态
await ChatRoomService.update_chatroom(
room_id,
status=ChatRoomStatus.ACTIVE.value,
objective=objective,
current_discussion_id=discussion_id,
current_round=0
)
# 广播讨论开始
await MessageRouter.broadcast_status(room_id, "discussion_started", {
"discussion_id": discussion_id,
"objective": objective
})
# 发送系统消息
await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=discussion_id,
agent_id=None,
content=f"讨论开始\n\n目标:{objective}",
message_type=MessageType.SYSTEM.value,
round_num=0
)
logger.info(f"讨论开始: {room_id} - {objective}")
# 运行讨论循环
try:
result = await cls._run_discussion_loop(chatroom, context)
return result
except Exception as e:
logger.error(f"讨论异常: {e}")
await cls._handle_discussion_error(room_id, discussion_id, str(e))
raise
finally:
# 清理
cls._active_discussions.pop(room_id, None)
cls._stop_signals.pop(room_id, None)
@classmethod
async def stop_discussion(cls, room_id: str) -> bool:
"""
停止讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
if room_id not in cls._active_discussions:
return False
cls._stop_signals[room_id] = True
logger.info(f"收到停止讨论信号: {room_id}")
return True
@classmethod
async def pause_discussion(cls, room_id: str) -> bool:
"""
暂停讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
if room_id not in cls._active_discussions:
return False
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
await MessageRouter.broadcast_status(room_id, "discussion_paused")
logger.info(f"讨论暂停: {room_id}")
return True
@classmethod
async def resume_discussion(cls, room_id: str) -> bool:
"""
恢复讨论
Args:
room_id: 聊天室ID
Returns:
是否成功
"""
chatroom = await ChatRoomService.get_chatroom(room_id)
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
return False
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
logger.info(f"讨论恢复: {room_id}")
return True
@classmethod
async def _run_discussion_loop(
cls,
chatroom: ChatRoom,
context: DiscussionContext
) -> DiscussionResult:
"""
运行讨论循环
Args:
chatroom: 聊天室
context: 讨论上下文
Returns:
讨论结果
"""
room_id = chatroom.room_id
config = chatroom.get_config()
# 获取所有Agent
agents = await AgentService.get_agents_by_ids(chatroom.agents)
agent_map = {a.agent_id: a for a in agents}
# 获取主持人(用于共识判断)
moderator = None
if chatroom.moderator_agent_id:
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
consecutive_no_speak = 0 # 连续无人发言的轮次
while context.current_round < config.max_rounds:
# 检查停止信号
if cls._stop_signals.get(room_id, False):
break
# 检查暂停状态
current_chatroom = await ChatRoomService.get_chatroom(room_id)
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
await asyncio.sleep(1)
continue
# 增加轮次
context.current_round += 1
context.reset_round_counts()
# 广播轮次信息
await MessageRouter.broadcast_round_info(
room_id,
context.current_round,
config.max_rounds
)
# 更新聊天室轮次
await ChatRoomService.update_chatroom(
room_id,
current_round=context.current_round
)
# 本轮是否有人发言
round_has_message = False
# 遍历所有Agent判断是否发言
for agent_id in chatroom.agents:
agent = agent_map.get(agent_id)
if not agent or not agent.enabled:
continue
# 检查本轮发言次数限制
behavior = agent.get_behavior()
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
continue
# 判断是否发言
should_speak, content = await cls._should_agent_speak(
agent, context, chatroom
)
if should_speak and content:
# 广播输入状态
await MessageRouter.broadcast_typing(room_id, agent_id, True)
# 保存并广播消息
message = await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=context.discussion_id,
agent_id=agent_id,
content=content,
message_type=MessageType.TEXT.value,
round_num=context.current_round
)
# 更新上下文
context.add_message(message)
round_has_message = True
# 广播输入结束
await MessageRouter.broadcast_typing(room_id, agent_id, False)
# 轮次间隔
await asyncio.sleep(config.round_interval)
# 检查是否需要共识判断
if round_has_message and moderator:
consecutive_no_speak = 0
# 每隔几轮检查一次共识
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
if consensus_result.get("consensus_reached", False):
confidence = consensus_result.get("confidence", 0)
if confidence >= config.consensus_threshold:
# 达成共识,结束讨论
return await cls._finalize_discussion(
context,
consensus_result,
"consensus"
)
else:
consecutive_no_speak += 1
# 连续多轮无人发言,检查共识或结束
if consecutive_no_speak >= 3:
if moderator:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
return await cls._finalize_discussion(
context,
consensus_result,
"no_more_discussion"
)
else:
return await cls._finalize_discussion(
context,
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
"no_more_discussion"
)
# 达到最大轮次
if moderator:
consensus_result = await ConsensusManager.check_consensus(
moderator, context, chatroom
)
else:
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
return await cls._finalize_discussion(
context,
consensus_result,
"max_rounds"
)
@classmethod
async def _should_agent_speak(
cls,
agent: Agent,
context: DiscussionContext,
chatroom: ChatRoom
) -> tuple[bool, str]:
"""
判断Agent是否应该发言
Args:
agent: Agent实例
context: 讨论上下文
chatroom: 聊天室
Returns:
(是否发言, 发言内容)
"""
# 构建判断提示词
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
history_text = ""
for msg in recent_messages:
if msg.agent_id:
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
else:
history_text += f"[系统]: {msg.content}\n\n"
prompt = f"""你是{agent.name},角色是{agent.role}
{agent.system_prompt}
当前讨论目标:{context.objective}
对话历史:
{history_text if history_text else "(还没有对话)"}
当前是第{context.current_round}轮讨论。
请根据你的角色判断:
1. 你是否有新的观点或建议要分享?
2. 你是否需要回应其他人的观点?
3. 当前讨论是否需要你的专业意见?
如果你认为需要发言,请直接给出你的发言内容。
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"
注意:
- 请保持发言简洁有力每次发言控制在200字以内
- 避免重复已经说过的内容
- 如果已经达成共识或接近共识可以选择PASS"""
try:
# 调用AI接口
response = await AIProviderService.chat(
provider_id=agent.provider_id,
messages=[{"role": "user", "content": prompt}],
temperature=agent.temperature,
max_tokens=agent.max_tokens
)
if not response.success:
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
return False, ""
content = response.content.strip()
# 判断是否PASS
if content.upper() == "PASS" or content.upper().startswith("PASS"):
return False, ""
return True, content
except Exception as e:
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
return False, ""
@classmethod
async def _finalize_discussion(
cls,
context: DiscussionContext,
consensus_result: Dict[str, Any],
end_reason: str
) -> DiscussionResult:
"""
完成讨论,保存结果
Args:
context: 讨论上下文
consensus_result: 共识判断结果
end_reason: 结束原因
Returns:
讨论结果
"""
room_id = context.room_id
# 获取讨论结果记录
discussion_result = await DiscussionResult.find_one(
DiscussionResult.discussion_id == context.discussion_id
)
if discussion_result:
# 更新统计
discussion_result.update_stats(
total_rounds=context.current_round,
total_messages=len(context.messages),
agent_contributions=context.agent_speak_counts
)
# 标记完成
discussion_result.mark_completed(
consensus_reached=consensus_result.get("consensus_reached", False),
confidence=consensus_result.get("confidence", 0),
summary=consensus_result.get("summary", ""),
action_items=consensus_result.get("action_items", []),
unresolved_issues=consensus_result.get("unresolved_issues", []),
end_reason=end_reason
)
await discussion_result.save()
# 更新聊天室状态
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
# 发送系统消息
summary_text = f"""讨论结束
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
置信度:{consensus_result.get("confidence", 0):.0%}
摘要:{consensus_result.get("summary", "")}
行动项:
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or ""}
未解决问题:
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or ""}
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
await MessageRouter.save_and_broadcast_message(
room_id=room_id,
discussion_id=context.discussion_id,
agent_id=None,
content=summary_text,
message_type=MessageType.SYSTEM.value,
round_num=context.current_round
)
# 广播讨论结束
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
"discussion_id": context.discussion_id,
"consensus_reached": consensus_result.get("consensus_reached", False),
"end_reason": end_reason
})
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
return discussion_result
@classmethod
async def _handle_discussion_error(
cls,
room_id: str,
discussion_id: str,
error: str
) -> None:
"""
处理讨论错误
Args:
room_id: 聊天室ID
discussion_id: 讨论ID
error: 错误信息
"""
# 更新聊天室状态
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
# 更新讨论结果
discussion_result = await DiscussionResult.find_one(
DiscussionResult.discussion_id == discussion_id
)
if discussion_result:
discussion_result.status = "failed"
discussion_result.end_reason = f"error: {error}"
discussion_result.updated_at = datetime.utcnow()
await discussion_result.save()
# 广播错误
await MessageRouter.broadcast_error(room_id, error)
@classmethod
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
"""
获取活跃的讨论上下文
Args:
room_id: 聊天室ID
Returns:
讨论上下文或None
"""
return cls._active_discussions.get(room_id)
@classmethod
def is_discussion_active(cls, room_id: str) -> bool:
"""
检查是否有活跃讨论
Args:
room_id: 聊天室ID
Returns:
是否活跃
"""
return room_id in cls._active_discussions

View File

@@ -0,0 +1,252 @@
"""
MCP服务
管理MCP工具的集成和调用
"""
import json
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
from loguru import logger
class MCPService:
"""
MCP工具服务
集成MCP服务器提供工具调用能力
"""
# MCP服务器配置目录
MCP_CONFIG_DIR = Path(os.getenv("CURSOR_MCP_DIR", "~/.cursor/mcps")).expanduser()
# 已注册的工具: server_name -> List[tool_info]
_registered_tools: Dict[str, List[Dict[str, Any]]] = {}
# Agent工具映射: agent_id -> List[tool_name]
_agent_tools: Dict[str, List[str]] = {}
@classmethod
async def initialize(cls) -> None:
"""
初始化MCP服务
扫描并注册可用的MCP工具
"""
logger.info("初始化MCP服务...")
if not cls.MCP_CONFIG_DIR.exists():
logger.warning(f"MCP配置目录不存在: {cls.MCP_CONFIG_DIR}")
return
# 扫描MCP服务器目录
for server_dir in cls.MCP_CONFIG_DIR.iterdir():
if server_dir.is_dir():
await cls._scan_server(server_dir)
logger.info(f"MCP服务初始化完成已注册 {len(cls._registered_tools)} 个服务器")
@classmethod
async def _scan_server(cls, server_dir: Path) -> None:
"""
扫描MCP服务器目录
Args:
server_dir: 服务器目录
"""
server_name = server_dir.name
tools_dir = server_dir / "tools"
if not tools_dir.exists():
return
tools = []
for tool_file in tools_dir.glob("*.json"):
try:
with open(tool_file, "r", encoding="utf-8") as f:
tool_info = json.load(f)
tool_info["_file"] = str(tool_file)
tools.append(tool_info)
except Exception as e:
logger.warning(f"加载MCP工具配置失败: {tool_file} - {e}")
if tools:
cls._registered_tools[server_name] = tools
logger.debug(f"注册MCP服务器: {server_name}, 工具数: {len(tools)}")
@classmethod
def list_servers(cls) -> List[str]:
"""
列出所有可用的MCP服务器
Returns:
服务器名称列表
"""
return list(cls._registered_tools.keys())
@classmethod
def list_tools(cls, server: Optional[str] = None) -> List[Dict[str, Any]]:
"""
列出可用的MCP工具
Args:
server: 服务器名称(可选,不指定则返回所有)
Returns:
工具信息列表
"""
if server:
return cls._registered_tools.get(server, [])
# 返回所有工具
all_tools = []
for server_name, tools in cls._registered_tools.items():
for tool in tools:
tool_copy = tool.copy()
tool_copy["server"] = server_name
all_tools.append(tool_copy)
return all_tools
@classmethod
def get_tool(cls, server: str, tool_name: str) -> Optional[Dict[str, Any]]:
"""
获取指定工具的信息
Args:
server: 服务器名称
tool_name: 工具名称
Returns:
工具信息或None
"""
tools = cls._registered_tools.get(server, [])
for tool in tools:
if tool.get("name") == tool_name:
return tool
return None
@classmethod
async def call_tool(
cls,
server: str,
tool_name: str,
arguments: Dict[str, Any]
) -> Dict[str, Any]:
"""
调用MCP工具
Args:
server: 服务器名称
tool_name: 工具名称
arguments: 工具参数
Returns:
调用结果
"""
tool = cls.get_tool(server, tool_name)
if not tool:
return {
"success": False,
"error": f"工具不存在: {server}/{tool_name}"
}
# TODO: 实际的MCP工具调用逻辑
# 这里需要根据MCP协议实现工具调用
# 目前返回模拟结果
logger.info(f"调用MCP工具: {server}/{tool_name}, 参数: {arguments}")
return {
"success": True,
"result": f"MCP工具调用: {tool_name}",
"tool": tool_name,
"server": server,
"arguments": arguments
}
@classmethod
def register_tool_for_agent(
cls,
agent_id: str,
tool_name: str
) -> bool:
"""
为Agent注册可用工具
Args:
agent_id: Agent ID
tool_name: 工具名称(格式: server/tool_name
Returns:
是否注册成功
"""
if agent_id not in cls._agent_tools:
cls._agent_tools[agent_id] = []
if tool_name not in cls._agent_tools[agent_id]:
cls._agent_tools[agent_id].append(tool_name)
return True
return False
@classmethod
def unregister_tool_for_agent(
cls,
agent_id: str,
tool_name: str
) -> bool:
"""
为Agent注销工具
Args:
agent_id: Agent ID
tool_name: 工具名称
Returns:
是否注销成功
"""
if agent_id in cls._agent_tools:
if tool_name in cls._agent_tools[agent_id]:
cls._agent_tools[agent_id].remove(tool_name)
return True
return False
@classmethod
def get_agent_tools(cls, agent_id: str) -> List[str]:
"""
获取Agent可用的工具列表
Args:
agent_id: Agent ID
Returns:
工具名称列表
"""
return cls._agent_tools.get(agent_id, [])
@classmethod
def get_tools_for_prompt(cls, agent_id: str) -> str:
"""
获取用于提示词的工具描述
Args:
agent_id: Agent ID
Returns:
工具描述文本
"""
tool_names = cls.get_agent_tools(agent_id)
if not tool_names:
return ""
descriptions = []
for full_name in tool_names:
parts = full_name.split("/", 1)
if len(parts) == 2:
server, tool_name = parts
tool = cls.get_tool(server, tool_name)
if tool:
desc = tool.get("description", "无描述")
descriptions.append(f"- {tool_name}: {desc}")
if not descriptions:
return ""
return "你可以使用以下工具:\n" + "\n".join(descriptions)

View File

@@ -0,0 +1,416 @@
"""
记忆服务
管理Agent的记忆存储和检索
"""
import uuid
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
import numpy as np
from loguru import logger
from models.agent_memory import AgentMemory, MemoryType
class MemoryService:
"""
Agent记忆服务
提供记忆的存储、检索和管理功能
"""
# 嵌入模型(延迟加载)
_embedding_model = None
@classmethod
def _get_embedding_model(cls):
"""
获取嵌入模型实例(延迟加载)
"""
if cls._embedding_model is None:
try:
from sentence_transformers import SentenceTransformer
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
logger.info("嵌入模型加载成功")
except Exception as e:
logger.warning(f"嵌入模型加载失败: {e}")
return None
return cls._embedding_model
@classmethod
async def create_memory(
cls,
agent_id: str,
content: str,
memory_type: str = MemoryType.SHORT_TERM.value,
importance: float = 0.5,
source_room_id: Optional[str] = None,
source_discussion_id: Optional[str] = None,
tags: Optional[List[str]] = None,
expires_in_hours: Optional[int] = None
) -> AgentMemory:
"""
创建新的记忆
Args:
agent_id: Agent ID
content: 记忆内容
memory_type: 记忆类型
importance: 重要性评分
source_room_id: 来源聊天室
source_discussion_id: 来源讨论
tags: 标签
expires_in_hours: 过期时间(小时)
Returns:
创建的AgentMemory文档
"""
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
# 生成向量嵌入
embedding = await cls._generate_embedding(content)
# 生成摘要
summary = content[:100] + "..." if len(content) > 100 else content
# 计算过期时间
expires_at = None
if expires_in_hours:
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
memory = AgentMemory(
memory_id=memory_id,
agent_id=agent_id,
memory_type=memory_type,
content=content,
summary=summary,
embedding=embedding,
importance=importance,
source_room_id=source_room_id,
source_discussion_id=source_discussion_id,
tags=tags or [],
created_at=datetime.utcnow(),
last_accessed=datetime.utcnow(),
expires_at=expires_at
)
await memory.insert()
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
return memory
@classmethod
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
"""
获取指定记忆
Args:
memory_id: 记忆ID
Returns:
AgentMemory文档或None
"""
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
@classmethod
async def get_agent_memories(
cls,
agent_id: str,
memory_type: Optional[str] = None,
limit: int = 50
) -> List[AgentMemory]:
"""
获取Agent的记忆列表
Args:
agent_id: Agent ID
memory_type: 记忆类型(可选)
limit: 返回数量限制
Returns:
记忆列表
"""
query = {"agent_id": agent_id}
if memory_type:
query["memory_type"] = memory_type
return await AgentMemory.find(query).sort(
"-importance", "-last_accessed"
).limit(limit).to_list()
@classmethod
async def search_memories(
cls,
agent_id: str,
query: str,
limit: int = 10,
memory_type: Optional[str] = None,
min_relevance: float = 0.3
) -> List[Dict[str, Any]]:
"""
搜索相关记忆
Args:
agent_id: Agent ID
query: 查询文本
limit: 返回数量
memory_type: 记忆类型(可选)
min_relevance: 最小相关性阈值
Returns:
带相关性分数的记忆列表
"""
# 生成查询向量
query_embedding = await cls._generate_embedding(query)
if not query_embedding:
# 无法生成向量时,使用文本匹配
return await cls._text_search(agent_id, query, limit, memory_type)
# 获取Agent的所有记忆
filter_query = {"agent_id": agent_id}
if memory_type:
filter_query["memory_type"] = memory_type
memories = await AgentMemory.find(filter_query).to_list()
# 计算相似度
results = []
for memory in memories:
if memory.is_expired():
continue
if memory.embedding:
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
relevance = memory.calculate_relevance_score(similarity)
if relevance >= min_relevance:
results.append({
"memory": memory,
"similarity": similarity,
"relevance": relevance
})
# 按相关性排序
results.sort(key=lambda x: x["relevance"], reverse=True)
# 更新访问记录
for item in results[:limit]:
memory = item["memory"]
memory.access()
await memory.save()
return results[:limit]
@classmethod
async def update_memory(
cls,
memory_id: str,
**kwargs
) -> Optional[AgentMemory]:
"""
更新记忆
Args:
memory_id: 记忆ID
**kwargs: 要更新的字段
Returns:
更新后的AgentMemory或None
"""
memory = await cls.get_memory(memory_id)
if not memory:
return None
# 如果更新了内容,重新生成嵌入
if "content" in kwargs:
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
for key, value in kwargs.items():
if hasattr(memory, key):
setattr(memory, key, value)
await memory.save()
return memory
@classmethod
async def delete_memory(cls, memory_id: str) -> bool:
"""
删除记忆
Args:
memory_id: 记忆ID
Returns:
是否删除成功
"""
memory = await cls.get_memory(memory_id)
if not memory:
return False
await memory.delete()
return True
@classmethod
async def delete_agent_memories(
cls,
agent_id: str,
memory_type: Optional[str] = None
) -> int:
"""
删除Agent的记忆
Args:
agent_id: Agent ID
memory_type: 记忆类型(可选)
Returns:
删除的数量
"""
query = {"agent_id": agent_id}
if memory_type:
query["memory_type"] = memory_type
result = await AgentMemory.find(query).delete()
return result.deleted_count if result else 0
@classmethod
async def cleanup_expired_memories(cls) -> int:
"""
清理过期的记忆
Returns:
清理的数量
"""
now = datetime.utcnow()
result = await AgentMemory.find(
{"expires_at": {"$lt": now}}
).delete()
count = result.deleted_count if result else 0
if count > 0:
logger.info(f"清理了 {count} 条过期记忆")
return count
@classmethod
async def consolidate_memories(
cls,
agent_id: str,
min_importance: float = 0.7,
max_age_days: int = 30
) -> None:
"""
整合记忆(将重要的短期记忆转为长期记忆)
Args:
agent_id: Agent ID
min_importance: 最小重要性阈值
max_age_days: 最大年龄(天)
"""
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
# 查找符合条件的短期记忆
memories = await AgentMemory.find({
"agent_id": agent_id,
"memory_type": MemoryType.SHORT_TERM.value,
"importance": {"$gte": min_importance},
"created_at": {"$lt": cutoff_date}
}).to_list()
for memory in memories:
memory.memory_type = MemoryType.LONG_TERM.value
memory.expires_at = None # 长期记忆不过期
await memory.save()
if memories:
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
@classmethod
async def _generate_embedding(cls, text: str) -> List[float]:
"""
生成文本的向量嵌入
Args:
text: 文本内容
Returns:
向量嵌入列表
"""
model = cls._get_embedding_model()
if model is None:
return []
try:
embedding = model.encode(text, convert_to_numpy=True)
return embedding.tolist()
except Exception as e:
logger.warning(f"生成嵌入失败: {e}")
return []
@classmethod
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
"""
计算余弦相似度
Args:
vec1: 向量1
vec2: 向量2
Returns:
相似度 (0-1)
"""
if not vec1 or not vec2:
return 0.0
try:
a = np.array(vec1)
b = np.array(vec2)
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
return float(max(0, similarity))
except Exception:
return 0.0
@classmethod
async def _text_search(
cls,
agent_id: str,
query: str,
limit: int,
memory_type: Optional[str]
) -> List[Dict[str, Any]]:
"""
文本搜索(后备方案)
Args:
agent_id: Agent ID
query: 查询文本
limit: 返回数量
memory_type: 记忆类型
Returns:
记忆列表
"""
filter_query = {"agent_id": agent_id}
if memory_type:
filter_query["memory_type"] = memory_type
# 简单的文本匹配
memories = await AgentMemory.find(filter_query).to_list()
results = []
query_lower = query.lower()
for memory in memories:
if memory.is_expired():
continue
content_lower = memory.content.lower()
if query_lower in content_lower:
# 计算简单的匹配分数
score = len(query_lower) / len(content_lower)
results.append({
"memory": memory,
"similarity": score,
"relevance": score * memory.importance
})
results.sort(key=lambda x: x["relevance"], reverse=True)
return results[:limit]

View File

@@ -0,0 +1,335 @@
"""
消息路由服务
管理消息的发送和广播
"""
import uuid
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable, Set
from dataclasses import dataclass, field
from loguru import logger
from fastapi import WebSocket
from models.message import Message, MessageType
from models.chatroom import ChatRoom
@dataclass
class WebSocketConnection:
"""WebSocket连接信息"""
websocket: WebSocket
room_id: str
connected_at: datetime = field(default_factory=datetime.utcnow)
class MessageRouter:
"""
消息路由器
管理WebSocket连接和消息广播
"""
# 房间连接映射: room_id -> Set[WebSocket]
_room_connections: Dict[str, Set[WebSocket]] = {}
# 消息回调: 用于外部订阅消息
_message_callbacks: List[Callable] = []
@classmethod
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
"""
建立WebSocket连接
Args:
room_id: 聊天室ID
websocket: WebSocket实例
"""
await websocket.accept()
if room_id not in cls._room_connections:
cls._room_connections[room_id] = set()
cls._room_connections[room_id].add(websocket)
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
@classmethod
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
"""
断开WebSocket连接
Args:
room_id: 聊天室ID
websocket: WebSocket实例
"""
if room_id in cls._room_connections:
cls._room_connections[room_id].discard(websocket)
# 清理空房间
if not cls._room_connections[room_id]:
del cls._room_connections[room_id]
logger.info(f"WebSocket连接断开: {room_id}")
@classmethod
async def broadcast_to_room(
cls,
room_id: str,
message: Dict[str, Any]
) -> None:
"""
向聊天室广播消息
Args:
room_id: 聊天室ID
message: 消息内容
"""
if room_id not in cls._room_connections:
return
# 获取所有连接
connections = cls._room_connections[room_id].copy()
# 并发发送
tasks = []
for websocket in connections:
tasks.append(cls._send_message(room_id, websocket, message))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@classmethod
async def _send_message(
cls,
room_id: str,
websocket: WebSocket,
message: Dict[str, Any]
) -> None:
"""
向单个WebSocket发送消息
Args:
room_id: 聊天室ID
websocket: WebSocket实例
message: 消息内容
"""
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"WebSocket发送失败: {e}")
# 移除断开的连接
await cls.disconnect(room_id, websocket)
@classmethod
async def save_and_broadcast_message(
cls,
room_id: str,
discussion_id: str,
agent_id: Optional[str],
content: str,
message_type: str = MessageType.TEXT.value,
round_num: int = 0,
attachments: Optional[List[Dict[str, Any]]] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
tool_results: Optional[List[Dict[str, Any]]] = None
) -> Message:
"""
保存消息并广播
Args:
room_id: 聊天室ID
discussion_id: 讨论ID
agent_id: 发送Agent ID
content: 消息内容
message_type: 消息类型
round_num: 轮次号
attachments: 附件
tool_calls: 工具调用
tool_results: 工具结果
Returns:
保存的Message文档
"""
# 创建消息
message = Message(
message_id=f"msg-{uuid.uuid4().hex[:12]}",
room_id=room_id,
discussion_id=discussion_id,
agent_id=agent_id,
content=content,
message_type=message_type,
attachments=attachments or [],
round=round_num,
token_count=len(content) // 4, # 粗略估计
tool_calls=tool_calls or [],
tool_results=tool_results or [],
created_at=datetime.utcnow()
)
await message.insert()
# 构建广播消息
broadcast_data = {
"type": "message",
"data": {
"message_id": message.message_id,
"room_id": message.room_id,
"discussion_id": message.discussion_id,
"agent_id": message.agent_id,
"content": message.content,
"message_type": message.message_type,
"round": message.round,
"created_at": message.created_at.isoformat()
}
}
# 广播消息
await cls.broadcast_to_room(room_id, broadcast_data)
# 触发回调
for callback in cls._message_callbacks:
try:
await callback(message)
except Exception as e:
logger.error(f"消息回调执行失败: {e}")
return message
@classmethod
async def broadcast_status(
cls,
room_id: str,
status: str,
data: Optional[Dict[str, Any]] = None
) -> None:
"""
广播状态更新
Args:
room_id: 聊天室ID
status: 状态类型
data: 附加数据
"""
message = {
"type": "status",
"status": status,
"data": data or {},
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_typing(
cls,
room_id: str,
agent_id: str,
is_typing: bool = True
) -> None:
"""
广播Agent输入状态
Args:
room_id: 聊天室ID
agent_id: Agent ID
is_typing: 是否正在输入
"""
message = {
"type": "typing",
"agent_id": agent_id,
"is_typing": is_typing,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_round_info(
cls,
room_id: str,
round_num: int,
total_rounds: int
) -> None:
"""
广播轮次信息
Args:
room_id: 聊天室ID
round_num: 当前轮次
total_rounds: 最大轮次
"""
message = {
"type": "round",
"round": round_num,
"total_rounds": total_rounds,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
async def broadcast_error(
cls,
room_id: str,
error: str,
agent_id: Optional[str] = None
) -> None:
"""
广播错误信息
Args:
room_id: 聊天室ID
error: 错误信息
agent_id: 相关Agent ID
"""
message = {
"type": "error",
"error": error,
"agent_id": agent_id,
"timestamp": datetime.utcnow().isoformat()
}
await cls.broadcast_to_room(room_id, message)
@classmethod
def register_callback(cls, callback: Callable) -> None:
"""
注册消息回调
Args:
callback: 回调函数接收Message参数
"""
cls._message_callbacks.append(callback)
@classmethod
def unregister_callback(cls, callback: Callable) -> None:
"""
注销消息回调
Args:
callback: 回调函数
"""
if callback in cls._message_callbacks:
cls._message_callbacks.remove(callback)
@classmethod
def get_connection_count(cls, room_id: str) -> int:
"""
获取房间连接数
Args:
room_id: 聊天室ID
Returns:
连接数
"""
return len(cls._room_connections.get(room_id, set()))
@classmethod
def get_all_room_ids(cls) -> List[str]:
"""
获取所有活跃房间ID
Returns:
房间ID列表
"""
return list(cls._room_connections.keys())