feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
22
backend/services/__init__.py
Normal file
22
backend/services/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
业务服务模块
|
||||
"""
|
||||
from .ai_provider_service import AIProviderService
|
||||
from .agent_service import AgentService
|
||||
from .chatroom_service import ChatRoomService
|
||||
from .message_router import MessageRouter
|
||||
from .discussion_engine import DiscussionEngine
|
||||
from .consensus_manager import ConsensusManager
|
||||
from .mcp_service import MCPService
|
||||
from .memory_service import MemoryService
|
||||
|
||||
__all__ = [
|
||||
"AIProviderService",
|
||||
"AgentService",
|
||||
"ChatRoomService",
|
||||
"MessageRouter",
|
||||
"DiscussionEngine",
|
||||
"ConsensusManager",
|
||||
"MCPService",
|
||||
"MemoryService",
|
||||
]
|
||||
438
backend/services/agent_service.py
Normal file
438
backend/services/agent_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
Agent服务
|
||||
管理AI代理的配置
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.agent import Agent
|
||||
from services.ai_provider_service import AIProviderService
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""
|
||||
Agent服务类
|
||||
负责Agent的CRUD操作
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_agent(
|
||||
cls,
|
||||
name: str,
|
||||
role: str,
|
||||
system_prompt: str,
|
||||
provider_id: str,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
capabilities: Optional[Dict[str, Any]] = None,
|
||||
behavior: Optional[Dict[str, Any]] = None,
|
||||
avatar: Optional[str] = None,
|
||||
color: str = "#1890ff"
|
||||
) -> Agent:
|
||||
"""
|
||||
创建新的Agent
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
role: 角色定义
|
||||
system_prompt: 系统提示词
|
||||
provider_id: 使用的AI接口ID
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
capabilities: 能力配置
|
||||
behavior: 行为配置
|
||||
avatar: 头像URL
|
||||
color: 代表颜色
|
||||
|
||||
Returns:
|
||||
创建的Agent文档
|
||||
"""
|
||||
# 验证AI接口存在
|
||||
provider = await AIProviderService.get_provider(provider_id)
|
||||
if not provider:
|
||||
raise ValueError(f"AI接口不存在: {provider_id}")
|
||||
|
||||
# 生成唯一ID
|
||||
agent_id = f"agent-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 默认能力配置
|
||||
default_capabilities = {
|
||||
"memory_enabled": False,
|
||||
"mcp_tools": [],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
}
|
||||
if capabilities:
|
||||
default_capabilities.update(capabilities)
|
||||
|
||||
# 默认行为配置
|
||||
default_behavior = {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
}
|
||||
if behavior:
|
||||
default_behavior.update(behavior)
|
||||
|
||||
# 创建文档
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
name=name,
|
||||
role=role,
|
||||
system_prompt=system_prompt,
|
||||
provider_id=provider_id,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
capabilities=default_capabilities,
|
||||
behavior=default_behavior,
|
||||
avatar=avatar,
|
||||
color=color,
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await agent.insert()
|
||||
|
||||
logger.info(f"创建Agent: {agent_id} ({name})")
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
async def get_agent(cls, agent_id: str) -> Optional[Agent]:
|
||||
"""
|
||||
获取指定Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent文档或None
|
||||
"""
|
||||
return await Agent.find_one(Agent.agent_id == agent_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_agents(
|
||||
cls,
|
||||
enabled_only: bool = False
|
||||
) -> List[Agent]:
|
||||
"""
|
||||
获取所有Agent
|
||||
|
||||
Args:
|
||||
enabled_only: 是否只返回启用的Agent
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
if enabled_only:
|
||||
return await Agent.find(Agent.enabled == True).to_list()
|
||||
return await Agent.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def get_agents_by_ids(
|
||||
cls,
|
||||
agent_ids: List[str]
|
||||
) -> List[Agent]:
|
||||
"""
|
||||
根据ID列表获取多个Agent
|
||||
|
||||
Args:
|
||||
agent_ids: Agent ID列表
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
return await Agent.find(
|
||||
{"agent_id": {"$in": agent_ids}}
|
||||
).to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
**kwargs
|
||||
) -> Optional[Agent]:
|
||||
"""
|
||||
更新Agent配置
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的Agent或None
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return None
|
||||
|
||||
# 如果更新了provider_id,验证其存在
|
||||
if "provider_id" in kwargs:
|
||||
provider = await AIProviderService.get_provider(kwargs["provider_id"])
|
||||
if not provider:
|
||||
raise ValueError(f"AI接口不存在: {kwargs['provider_id']}")
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(agent, key):
|
||||
setattr(agent, key, value)
|
||||
|
||||
await agent.save()
|
||||
|
||||
logger.info(f"更新Agent: {agent_id}")
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
async def delete_agent(cls, agent_id: str) -> bool:
|
||||
"""
|
||||
删除Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return False
|
||||
|
||||
await agent.delete()
|
||||
|
||||
logger.info(f"删除Agent: {agent_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def test_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
test_message: str = "你好,请简单介绍一下你自己。"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试Agent对话
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
test_message: 测试消息
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
agent = await cls.get_agent(agent_id)
|
||||
if not agent:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Agent不存在: {agent_id}"
|
||||
}
|
||||
|
||||
if not agent.enabled:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Agent已禁用"
|
||||
}
|
||||
|
||||
# 构建消息
|
||||
messages = [
|
||||
{"role": "system", "content": agent.system_prompt},
|
||||
{"role": "user", "content": test_message}
|
||||
]
|
||||
|
||||
# 调用AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=agent.provider_id,
|
||||
messages=messages,
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "测试成功",
|
||||
"response": response.content,
|
||||
"model": response.model,
|
||||
"tokens": response.total_tokens,
|
||||
"latency_ms": response.latency_ms
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def duplicate_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
new_name: Optional[str] = None
|
||||
) -> Optional[Agent]:
|
||||
"""
|
||||
复制Agent
|
||||
|
||||
Args:
|
||||
agent_id: 源Agent ID
|
||||
new_name: 新Agent名称
|
||||
|
||||
Returns:
|
||||
新创建的Agent或None
|
||||
"""
|
||||
source_agent = await cls.get_agent(agent_id)
|
||||
if not source_agent:
|
||||
return None
|
||||
|
||||
return await cls.create_agent(
|
||||
name=new_name or f"{source_agent.name} (副本)",
|
||||
role=source_agent.role,
|
||||
system_prompt=source_agent.system_prompt,
|
||||
provider_id=source_agent.provider_id,
|
||||
temperature=source_agent.temperature,
|
||||
max_tokens=source_agent.max_tokens,
|
||||
capabilities=source_agent.capabilities,
|
||||
behavior=source_agent.behavior,
|
||||
avatar=source_agent.avatar,
|
||||
color=source_agent.color
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def generate_system_prompt(
|
||||
cls,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
role: str,
|
||||
description: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
使用AI生成Agent系统提示词
|
||||
|
||||
Args:
|
||||
provider_id: AI接口ID
|
||||
name: Agent名称
|
||||
role: 角色定位
|
||||
description: 额外描述(可选)
|
||||
|
||||
Returns:
|
||||
生成结果,包含success和生成的prompt
|
||||
"""
|
||||
# 验证AI接口存在
|
||||
provider = await AIProviderService.get_provider(provider_id)
|
||||
if not provider:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"AI接口不存在: {provider_id}"
|
||||
}
|
||||
|
||||
# 构建生成提示词的请求
|
||||
generate_prompt = f"""请为一个AI Agent编写系统提示词(system prompt)。
|
||||
|
||||
Agent名称:{name}
|
||||
角色定位:{role}
|
||||
{f'补充说明:{description}' if description else ''}
|
||||
|
||||
要求:
|
||||
1. 提示词应简洁专业,控制在200字以内
|
||||
2. 明确该Agent的核心职责和专业领域
|
||||
3. 说明在多Agent讨论中应该关注什么
|
||||
4. 使用中文编写
|
||||
5. 不要包含任何问候语或开场白,直接给出提示词内容
|
||||
|
||||
请直接输出系统提示词,不要有任何额外的解释或包装。"""
|
||||
|
||||
try:
|
||||
messages = [{"role": "user", "content": generate_prompt}]
|
||||
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=provider_id,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if response.success:
|
||||
# 清理可能的包装文本
|
||||
content = response.content.strip()
|
||||
# 移除可能的markdown代码块标记
|
||||
if content.startswith("```"):
|
||||
lines = content.split("\n")
|
||||
content = "\n".join(lines[1:])
|
||||
if content.endswith("```"):
|
||||
content = content[:-3]
|
||||
content = content.strip()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"prompt": content,
|
||||
"model": response.model,
|
||||
"tokens": response.total_tokens
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": response.error or "生成失败"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"生成系统提示词失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"生成失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Agent预设模板
|
||||
AGENT_TEMPLATES = {
|
||||
"product_manager": {
|
||||
"name": "产品经理",
|
||||
"role": "产品规划和需求分析专家",
|
||||
"system_prompt": """你是一位经验丰富的产品经理,擅长:
|
||||
- 分析用户需求和痛点
|
||||
- 制定产品策略和路线图
|
||||
- 平衡业务目标和用户体验
|
||||
- 与团队协作推进产品迭代
|
||||
|
||||
在讨论中,你需要从产品角度出发,关注用户价值、商业可行性和优先级排序。
|
||||
请用专业但易懂的语言表达观点。""",
|
||||
"color": "#1890ff"
|
||||
},
|
||||
"developer": {
|
||||
"name": "开发工程师",
|
||||
"role": "技术实现和架构设计专家",
|
||||
"system_prompt": """你是一位资深的软件开发工程师,擅长:
|
||||
- 系统架构设计
|
||||
- 代码实现和优化
|
||||
- 技术方案评估
|
||||
- 性能和安全考量
|
||||
|
||||
在讨论中,你需要从技术角度出发,关注实现可行性、技术债务和最佳实践。
|
||||
请提供具体的技术建议和潜在风险评估。""",
|
||||
"color": "#52c41a"
|
||||
},
|
||||
"designer": {
|
||||
"name": "设计师",
|
||||
"role": "用户体验和界面设计专家",
|
||||
"system_prompt": """你是一位专业的UI/UX设计师,擅长:
|
||||
- 用户体验设计
|
||||
- 界面视觉设计
|
||||
- 交互流程优化
|
||||
- 设计系统构建
|
||||
|
||||
在讨论中,你需要从设计角度出发,关注用户体验、视觉美感和交互流畅性。
|
||||
请提供设计建议并考虑可用性和一致性。""",
|
||||
"color": "#eb2f96"
|
||||
},
|
||||
"moderator": {
|
||||
"name": "主持人",
|
||||
"role": "讨论主持和共识判断专家",
|
||||
"system_prompt": """你是讨论的主持人,负责:
|
||||
- 引导讨论方向
|
||||
- 总结各方观点
|
||||
- 判断是否达成共识
|
||||
- 提炼行动要点
|
||||
|
||||
在讨论中,你需要保持中立,促进有效沟通,并在适当时机总结讨论成果。
|
||||
当各方观点趋于一致时,请明确指出并总结共识内容。""",
|
||||
"color": "#722ed1"
|
||||
}
|
||||
}
|
||||
364
backend/services/ai_provider_service.py
Normal file
364
backend/services/ai_provider_service.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
AI接口提供商服务
|
||||
管理AI接口的配置和调用
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.ai_provider import AIProvider
|
||||
from adapters import get_adapter, BaseAdapter, ChatMessage, AdapterResponse
|
||||
from utils.encryption import encrypt_api_key, decrypt_api_key
|
||||
from utils.rate_limiter import rate_limiter
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
"""
|
||||
AI接口提供商服务类
|
||||
负责AI接口的CRUD操作和调用
|
||||
"""
|
||||
|
||||
# 缓存适配器实例
|
||||
_adapter_cache: Dict[str, BaseAdapter] = {}
|
||||
|
||||
@classmethod
|
||||
async def create_provider(
|
||||
cls,
|
||||
provider_type: str,
|
||||
name: str,
|
||||
model: str,
|
||||
api_key: str = "",
|
||||
base_url: str = "",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
rate_limit: Optional[Dict[str, int]] = None,
|
||||
timeout: int = 60,
|
||||
extra_params: Optional[Dict[str, Any]] = None
|
||||
) -> AIProvider:
|
||||
"""
|
||||
创建新的AI接口配置
|
||||
|
||||
Args:
|
||||
provider_type: 提供商类型
|
||||
name: 自定义名称
|
||||
model: 模型名称
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
rate_limit: 速率限制配置
|
||||
timeout: 超时时间
|
||||
extra_params: 额外参数
|
||||
|
||||
Returns:
|
||||
创建的AIProvider文档
|
||||
"""
|
||||
# 验证提供商类型
|
||||
try:
|
||||
get_adapter(provider_type)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"不支持的提供商类型: {provider_type}")
|
||||
|
||||
# 生成唯一ID
|
||||
provider_id = f"{provider_type}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 加密API密钥
|
||||
encrypted_key = encrypt_api_key(api_key) if api_key else ""
|
||||
|
||||
# 创建文档
|
||||
provider = AIProvider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
name=name,
|
||||
api_key=encrypted_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config or {},
|
||||
rate_limit=rate_limit or {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||
timeout=timeout,
|
||||
extra_params=extra_params or {},
|
||||
enabled=True,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await provider.insert()
|
||||
|
||||
# 注册速率限制
|
||||
rate_limiter.register(
|
||||
provider_id,
|
||||
provider.rate_limit.get("requests_per_minute", 60),
|
||||
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
logger.info(f"创建AI接口配置: {provider_id} ({name})")
|
||||
return provider
|
||||
|
||||
@classmethod
|
||||
async def get_provider(cls, provider_id: str) -> Optional[AIProvider]:
|
||||
"""
|
||||
获取指定AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
AIProvider文档或None
|
||||
"""
|
||||
return await AIProvider.find_one(AIProvider.provider_id == provider_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_providers(
|
||||
cls,
|
||||
enabled_only: bool = False
|
||||
) -> List[AIProvider]:
|
||||
"""
|
||||
获取所有AI接口配置
|
||||
|
||||
Args:
|
||||
enabled_only: 是否只返回启用的接口
|
||||
|
||||
Returns:
|
||||
AIProvider列表
|
||||
"""
|
||||
if enabled_only:
|
||||
return await AIProvider.find(AIProvider.enabled == True).to_list()
|
||||
return await AIProvider.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_provider(
|
||||
cls,
|
||||
provider_id: str,
|
||||
**kwargs
|
||||
) -> Optional[AIProvider]:
|
||||
"""
|
||||
更新AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的AIProvider或None
|
||||
"""
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
# 如果更新了API密钥,需要加密
|
||||
if "api_key" in kwargs and kwargs["api_key"]:
|
||||
kwargs["api_key"] = encrypt_api_key(kwargs["api_key"])
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(provider, key):
|
||||
setattr(provider, key, value)
|
||||
|
||||
await provider.save()
|
||||
|
||||
# 清除适配器缓存
|
||||
cls._adapter_cache.pop(provider_id, None)
|
||||
|
||||
# 更新速率限制
|
||||
if "rate_limit" in kwargs:
|
||||
rate_limiter.unregister(provider_id)
|
||||
rate_limiter.register(
|
||||
provider_id,
|
||||
provider.rate_limit.get("requests_per_minute", 60),
|
||||
provider.rate_limit.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
logger.info(f"更新AI接口配置: {provider_id}")
|
||||
return provider
|
||||
|
||||
@classmethod
|
||||
async def delete_provider(cls, provider_id: str) -> bool:
|
||||
"""
|
||||
删除AI接口配置
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider:
|
||||
return False
|
||||
|
||||
await provider.delete()
|
||||
|
||||
# 清除缓存和速率限制
|
||||
cls._adapter_cache.pop(provider_id, None)
|
||||
rate_limiter.unregister(provider_id)
|
||||
|
||||
logger.info(f"删除AI接口配置: {provider_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def get_adapter(cls, provider_id: str) -> Optional[BaseAdapter]:
|
||||
"""
|
||||
获取AI接口的适配器实例
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
适配器实例或None
|
||||
"""
|
||||
# 检查缓存
|
||||
if provider_id in cls._adapter_cache:
|
||||
return cls._adapter_cache[provider_id]
|
||||
|
||||
provider = await cls.get_provider(provider_id)
|
||||
if not provider or not provider.enabled:
|
||||
return None
|
||||
|
||||
# 解密API密钥
|
||||
api_key = decrypt_api_key(provider.api_key) if provider.api_key else ""
|
||||
|
||||
# 创建适配器
|
||||
adapter_class = get_adapter(provider.provider_type)
|
||||
adapter = adapter_class(
|
||||
api_key=api_key,
|
||||
base_url=provider.base_url,
|
||||
model=provider.model,
|
||||
use_proxy=provider.use_proxy,
|
||||
proxy_config=provider.proxy_config,
|
||||
timeout=provider.timeout,
|
||||
**provider.extra_params
|
||||
)
|
||||
|
||||
# 缓存适配器
|
||||
cls._adapter_cache[provider_id] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
@classmethod
|
||||
async def chat(
|
||||
cls,
|
||||
provider_id: str,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
**kwargs
|
||||
) -> AdapterResponse:
|
||||
"""
|
||||
调用AI接口进行对话
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
messages: 消息列表 [{"role": "user", "content": "..."}]
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
适配器响应
|
||||
"""
|
||||
adapter = await cls.get_adapter(provider_id)
|
||||
if not adapter:
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error=f"AI接口不存在或未启用: {provider_id}"
|
||||
)
|
||||
|
||||
# 检查速率限制
|
||||
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
|
||||
if not await rate_limiter.acquire_wait(provider_id, estimated_tokens):
|
||||
return AdapterResponse(
|
||||
success=False,
|
||||
error="请求频率超限,请稍后重试"
|
||||
)
|
||||
|
||||
# 转换消息格式
|
||||
chat_messages = [
|
||||
ChatMessage(
|
||||
role=m.get("role", "user"),
|
||||
content=m.get("content", ""),
|
||||
name=m.get("name")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
# 调用适配器
|
||||
response = await adapter.chat(
|
||||
messages=chat_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
async def test_provider(cls, provider_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
测试AI接口连接
|
||||
|
||||
Args:
|
||||
provider_id: 接口ID
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
adapter = await cls.get_adapter(provider_id)
|
||||
if not adapter:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"AI接口不存在或未启用: {provider_id}"
|
||||
}
|
||||
|
||||
return await adapter.test_connection()
|
||||
|
||||
@classmethod
|
||||
async def test_provider_config(
|
||||
cls,
|
||||
provider_type: str,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
model: str = "",
|
||||
use_proxy: bool = False,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: int = 30,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试AI接口配置(不保存)
|
||||
|
||||
Args:
|
||||
provider_type: 提供商类型
|
||||
api_key: API密钥
|
||||
base_url: API基础URL
|
||||
model: 模型名称
|
||||
use_proxy: 是否使用代理
|
||||
proxy_config: 代理配置
|
||||
timeout: 超时时间
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
try:
|
||||
adapter_class = get_adapter(provider_type)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"不支持的提供商类型: {provider_type}"
|
||||
}
|
||||
|
||||
adapter = adapter_class(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
use_proxy=use_proxy,
|
||||
proxy_config=proxy_config,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return await adapter.test_connection()
|
||||
357
backend/services/chatroom_service.py
Normal file
357
backend/services/chatroom_service.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
聊天室服务
|
||||
管理聊天室的创建和状态
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||
from models.message import Message
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
class ChatRoomService:
|
||||
"""
|
||||
聊天室服务类
|
||||
负责聊天室的CRUD操作
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_chatroom(
|
||||
cls,
|
||||
name: str,
|
||||
description: str = "",
|
||||
agents: Optional[List[str]] = None,
|
||||
moderator_agent_id: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
) -> ChatRoom:
|
||||
"""
|
||||
创建新的聊天室
|
||||
|
||||
Args:
|
||||
name: 聊天室名称
|
||||
description: 描述
|
||||
agents: Agent ID列表
|
||||
moderator_agent_id: 主持人Agent ID
|
||||
config: 聊天室配置
|
||||
|
||||
Returns:
|
||||
创建的ChatRoom文档
|
||||
"""
|
||||
# 验证Agent存在
|
||||
if agents:
|
||||
existing_agents = await AgentService.get_agents_by_ids(agents)
|
||||
existing_ids = {a.agent_id for a in existing_agents}
|
||||
missing_ids = set(agents) - existing_ids
|
||||
if missing_ids:
|
||||
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||
|
||||
# 验证主持人Agent
|
||||
if moderator_agent_id:
|
||||
moderator = await AgentService.get_agent(moderator_agent_id)
|
||||
if not moderator:
|
||||
raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")
|
||||
|
||||
# 生成唯一ID
|
||||
room_id = f"room-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8,
|
||||
"round_interval": 1.0,
|
||||
"allow_user_interrupt": True
|
||||
}
|
||||
if config:
|
||||
default_config.update(config)
|
||||
|
||||
# 创建文档
|
||||
chatroom = ChatRoom(
|
||||
room_id=room_id,
|
||||
name=name,
|
||||
description=description,
|
||||
objective="",
|
||||
agents=agents or [],
|
||||
moderator_agent_id=moderator_agent_id,
|
||||
config=default_config,
|
||||
status=ChatRoomStatus.IDLE.value,
|
||||
current_round=0,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await chatroom.insert()
|
||||
|
||||
logger.info(f"创建聊天室: {room_id} ({name})")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
获取指定聊天室
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
ChatRoom文档或None
|
||||
"""
|
||||
return await ChatRoom.find_one(ChatRoom.room_id == room_id)
|
||||
|
||||
@classmethod
|
||||
async def get_all_chatrooms(cls) -> List[ChatRoom]:
|
||||
"""
|
||||
获取所有聊天室
|
||||
|
||||
Returns:
|
||||
ChatRoom列表
|
||||
"""
|
||||
return await ChatRoom.find_all().to_list()
|
||||
|
||||
@classmethod
|
||||
async def update_chatroom(
|
||||
cls,
|
||||
room_id: str,
|
||||
**kwargs
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
更新聊天室配置
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 验证Agent
|
||||
if "agents" in kwargs:
|
||||
existing_agents = await AgentService.get_agents_by_ids(kwargs["agents"])
|
||||
existing_ids = {a.agent_id for a in existing_agents}
|
||||
missing_ids = set(kwargs["agents"]) - existing_ids
|
||||
if missing_ids:
|
||||
raise ValueError(f"Agent不存在: {', '.join(missing_ids)}")
|
||||
|
||||
# 验证主持人
|
||||
if "moderator_agent_id" in kwargs and kwargs["moderator_agent_id"]:
|
||||
moderator = await AgentService.get_agent(kwargs["moderator_agent_id"])
|
||||
if not moderator:
|
||||
raise ValueError(f"主持人Agent不存在: {kwargs['moderator_agent_id']}")
|
||||
|
||||
# 更新字段
|
||||
kwargs["updated_at"] = datetime.utcnow()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(chatroom, key):
|
||||
setattr(chatroom, key, value)
|
||||
|
||||
await chatroom.save()
|
||||
|
||||
logger.info(f"更新聊天室: {room_id}")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def delete_chatroom(cls, room_id: str) -> bool:
|
||||
"""
|
||||
删除聊天室
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return False
|
||||
|
||||
# 删除相关消息
|
||||
await Message.find(Message.room_id == room_id).delete()
|
||||
|
||||
await chatroom.delete()
|
||||
|
||||
logger.info(f"删除聊天室: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
向聊天室添加Agent
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 验证Agent存在
|
||||
agent = await AgentService.get_agent(agent_id)
|
||||
if not agent:
|
||||
raise ValueError(f"Agent不存在: {agent_id}")
|
||||
|
||||
# 添加Agent
|
||||
if agent_id not in chatroom.agents:
|
||||
chatroom.agents.append(agent_id)
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
从聊天室移除Agent
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
# 移除Agent
|
||||
if agent_id in chatroom.agents:
|
||||
chatroom.agents.remove(agent_id)
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def set_objective(
|
||||
cls,
|
||||
room_id: str,
|
||||
objective: str
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
设置讨论目标
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
objective: 讨论目标
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
return await cls.update_chatroom(room_id, objective=objective)
|
||||
|
||||
@classmethod
|
||||
async def update_status(
|
||||
cls,
|
||||
room_id: str,
|
||||
status: ChatRoomStatus
|
||||
) -> Optional[ChatRoom]:
|
||||
"""
|
||||
更新聊天室状态
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
chatroom.status = status.value
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
|
||||
if status == ChatRoomStatus.COMPLETED:
|
||||
chatroom.completed_at = datetime.utcnow()
|
||||
|
||||
await chatroom.save()
|
||||
|
||||
logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
|
||||
"""
|
||||
增加轮次计数
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
更新后的ChatRoom或None
|
||||
"""
|
||||
chatroom = await cls.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
return None
|
||||
|
||||
chatroom.current_round += 1
|
||||
chatroom.updated_at = datetime.utcnow()
|
||||
await chatroom.save()
|
||||
|
||||
return chatroom
|
||||
|
||||
@classmethod
|
||||
async def get_messages(
|
||||
cls,
|
||||
room_id: str,
|
||||
limit: int = 50,
|
||||
skip: int = 0,
|
||||
discussion_id: Optional[str] = None
|
||||
) -> List[Message]:
|
||||
"""
|
||||
获取聊天室消息历史
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
limit: 返回数量限制
|
||||
skip: 跳过数量
|
||||
discussion_id: 讨论ID(可选)
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
query = {"room_id": room_id}
|
||||
if discussion_id:
|
||||
query["discussion_id"] = discussion_id
|
||||
|
||||
return await Message.find(query).sort(
|
||||
"-created_at"
|
||||
).skip(skip).limit(limit).to_list()
|
||||
|
||||
@classmethod
|
||||
async def get_recent_messages(
|
||||
cls,
|
||||
room_id: str,
|
||||
count: int = 20,
|
||||
discussion_id: Optional[str] = None
|
||||
) -> List[Message]:
|
||||
"""
|
||||
获取最近的消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
count: 消息数量
|
||||
discussion_id: 讨论ID(可选)
|
||||
|
||||
Returns:
|
||||
消息列表(按时间正序)
|
||||
"""
|
||||
messages = await cls.get_messages(
|
||||
room_id,
|
||||
limit=count,
|
||||
discussion_id=discussion_id
|
||||
)
|
||||
return list(reversed(messages)) # 返回正序
|
||||
227
backend/services/consensus_manager.py
Normal file
227
backend/services/consensus_manager.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
共识管理器
|
||||
判断讨论是否达成共识
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from models.agent import Agent
|
||||
from models.chatroom import ChatRoom
|
||||
from services.ai_provider_service import AIProviderService
|
||||
|
||||
|
||||
class ConsensusManager:
|
||||
"""
|
||||
共识管理器
|
||||
使用主持人Agent判断讨论共识
|
||||
"""
|
||||
|
||||
# 共识判断提示词模板
|
||||
CONSENSUS_PROMPT = """你是讨论的主持人,负责判断讨论是否达成共识。
|
||||
|
||||
讨论目标:{objective}
|
||||
|
||||
对话历史:
|
||||
{history}
|
||||
|
||||
请仔细分析对话内容,判断:
|
||||
1. 参与者是否对核心问题达成一致意见?
|
||||
2. 是否还有重要分歧未解决?
|
||||
3. 讨论结果是否足够明确和可执行?
|
||||
|
||||
请以JSON格式回复(不要包含任何其他文字):
|
||||
{{
|
||||
"consensus_reached": true或false,
|
||||
"confidence": 0到1之间的数字,
|
||||
"summary": "讨论结果摘要,简洁概括达成的共识或当前状态",
|
||||
"action_items": ["具体的行动项列表"],
|
||||
"unresolved_issues": ["未解决的问题列表"],
|
||||
"key_decisions": ["关键决策列表"]
|
||||
}}
|
||||
|
||||
注意:
|
||||
- consensus_reached为true表示核心问题已有明确结论
|
||||
- confidence表示你对共识判断的信心程度
|
||||
- 如果讨论仍有争议或不够深入,应该返回false
|
||||
- action_items应该是具体可执行的任务
|
||||
- 请确保返回有效的JSON格式"""
|
||||
|
||||
@classmethod
|
||||
async def check_consensus(
|
||||
cls,
|
||||
moderator: Agent,
|
||||
context: "DiscussionContext",
|
||||
chatroom: ChatRoom
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
检查是否达成共识
|
||||
|
||||
Args:
|
||||
moderator: 主持人Agent
|
||||
context: 讨论上下文
|
||||
chatroom: 聊天室
|
||||
|
||||
Returns:
|
||||
共识判断结果
|
||||
"""
|
||||
from services.discussion_engine import DiscussionContext
|
||||
|
||||
# 构建历史记录
|
||||
history_text = ""
|
||||
for msg in context.messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
else:
|
||||
history_text += f"[系统]: {msg.content}\n\n"
|
||||
|
||||
if not history_text:
|
||||
return {
|
||||
"consensus_reached": False,
|
||||
"confidence": 0,
|
||||
"summary": "讨论尚未开始",
|
||||
"action_items": [],
|
||||
"unresolved_issues": [],
|
||||
"key_decisions": []
|
||||
}
|
||||
|
||||
# 构建提示词
|
||||
prompt = cls.CONSENSUS_PROMPT.format(
|
||||
objective=context.objective,
|
||||
history=history_text
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用主持人Agent的AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=moderator.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3, # 使用较低温度以获得更一致的结果
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
logger.error(f"共识判断失败: {response.error}")
|
||||
return cls._default_result("AI接口调用失败")
|
||||
|
||||
# 解析JSON响应
|
||||
content = response.content.strip()
|
||||
|
||||
# 尝试提取JSON部分
|
||||
try:
|
||||
# 尝试直接解析
|
||||
result = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
# 尝试提取JSON块
|
||||
import re
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析共识判断结果: {content}")
|
||||
return cls._default_result("无法解析AI响应")
|
||||
else:
|
||||
return cls._default_result("AI响应格式错误")
|
||||
|
||||
# 验证和规范化结果
|
||||
return cls._normalize_result(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"共识判断异常: {e}")
|
||||
return cls._default_result(str(e))
|
||||
|
||||
@classmethod
|
||||
async def generate_summary(
|
||||
cls,
|
||||
moderator: Agent,
|
||||
context: "DiscussionContext"
|
||||
) -> str:
|
||||
"""
|
||||
生成讨论摘要
|
||||
|
||||
Args:
|
||||
moderator: 主持人Agent
|
||||
context: 讨论上下文
|
||||
|
||||
Returns:
|
||||
讨论摘要
|
||||
"""
|
||||
from services.discussion_engine import DiscussionContext
|
||||
|
||||
# 构建历史记录
|
||||
history_text = ""
|
||||
for msg in context.messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
|
||||
prompt = f"""请为以下讨论生成一份简洁的摘要。
|
||||
|
||||
讨论目标:{context.objective}
|
||||
|
||||
对话记录:
|
||||
{history_text}
|
||||
|
||||
请提供:
|
||||
1. 讨论的主要观点和结论
|
||||
2. 参与者的立场和建议
|
||||
3. 最终的决策或共识(如果有)
|
||||
|
||||
摘要应该简洁明了,控制在300字以内。"""
|
||||
|
||||
try:
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=moderator.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.5,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
if response.success:
|
||||
return response.content.strip()
|
||||
else:
|
||||
return "无法生成摘要"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成摘要异常: {e}")
|
||||
return "生成摘要时发生错误"
|
||||
|
||||
@classmethod
|
||||
def _default_result(cls, error: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
返回默认结果
|
||||
|
||||
Args:
|
||||
error: 错误信息
|
||||
|
||||
Returns:
|
||||
默认共识结果
|
||||
"""
|
||||
return {
|
||||
"consensus_reached": False,
|
||||
"confidence": 0,
|
||||
"summary": error if error else "共识判断失败",
|
||||
"action_items": [],
|
||||
"unresolved_issues": [],
|
||||
"key_decisions": []
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _normalize_result(cls, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化共识结果
|
||||
|
||||
Args:
|
||||
result: 原始结果
|
||||
|
||||
Returns:
|
||||
规范化的结果
|
||||
"""
|
||||
return {
|
||||
"consensus_reached": bool(result.get("consensus_reached", False)),
|
||||
"confidence": max(0, min(1, float(result.get("confidence", 0)))),
|
||||
"summary": str(result.get("summary", "")),
|
||||
"action_items": list(result.get("action_items", [])),
|
||||
"unresolved_issues": list(result.get("unresolved_issues", [])),
|
||||
"key_decisions": list(result.get("key_decisions", []))
|
||||
}
|
||||
589
backend/services/discussion_engine.py
Normal file
589
backend/services/discussion_engine.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
讨论引擎
|
||||
实现自由讨论的核心逻辑
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
from models.chatroom import ChatRoom, ChatRoomStatus
|
||||
from models.agent import Agent
|
||||
from models.message import Message, MessageType
|
||||
from models.discussion_result import DiscussionResult
|
||||
from services.ai_provider_service import AIProviderService
|
||||
from services.agent_service import AgentService
|
||||
from services.chatroom_service import ChatRoomService
|
||||
from services.message_router import MessageRouter
|
||||
from services.consensus_manager import ConsensusManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscussionContext:
|
||||
"""讨论上下文"""
|
||||
discussion_id: str
|
||||
room_id: str
|
||||
objective: str
|
||||
current_round: int = 0
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
agent_speak_counts: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到上下文"""
|
||||
self.messages.append(message)
|
||||
if message.agent_id:
|
||||
self.agent_speak_counts[message.agent_id] = \
|
||||
self.agent_speak_counts.get(message.agent_id, 0) + 1
|
||||
|
||||
def get_recent_messages(self, count: int = 20) -> List[Message]:
|
||||
"""获取最近的消息"""
|
||||
return self.messages[-count:] if len(self.messages) > count else self.messages
|
||||
|
||||
def get_agent_speak_count(self, agent_id: str) -> int:
|
||||
"""获取Agent在当前轮次的发言次数"""
|
||||
return self.agent_speak_counts.get(agent_id, 0)
|
||||
|
||||
def reset_round_counts(self) -> None:
|
||||
"""重置轮次发言计数"""
|
||||
self.agent_speak_counts.clear()
|
||||
|
||||
|
||||
class DiscussionEngine:
|
||||
"""
|
||||
讨论引擎
|
||||
实现多Agent自由讨论的核心逻辑
|
||||
"""
|
||||
|
||||
# 活跃的讨论: room_id -> DiscussionContext
|
||||
_active_discussions: Dict[str, DiscussionContext] = {}
|
||||
|
||||
# 停止信号
|
||||
_stop_signals: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
async def start_discussion(
|
||||
cls,
|
||||
room_id: str,
|
||||
objective: str
|
||||
) -> Optional[DiscussionResult]:
|
||||
"""
|
||||
启动讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
objective: 讨论目标
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
# 获取聊天室
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom:
|
||||
raise ValueError(f"聊天室不存在: {room_id}")
|
||||
|
||||
if not chatroom.agents:
|
||||
raise ValueError("聊天室没有Agent参与")
|
||||
|
||||
if not objective:
|
||||
raise ValueError("讨论目标不能为空")
|
||||
|
||||
# 检查是否已有活跃讨论
|
||||
if room_id in cls._active_discussions:
|
||||
raise ValueError("聊天室已有进行中的讨论")
|
||||
|
||||
# 创建讨论
|
||||
discussion_id = f"disc-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建讨论结果记录
|
||||
discussion_result = DiscussionResult(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective,
|
||||
status="in_progress",
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
await discussion_result.insert()
|
||||
|
||||
# 创建讨论上下文
|
||||
context = DiscussionContext(
|
||||
discussion_id=discussion_id,
|
||||
room_id=room_id,
|
||||
objective=objective
|
||||
)
|
||||
cls._active_discussions[room_id] = context
|
||||
cls._stop_signals[room_id] = False
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
status=ChatRoomStatus.ACTIVE.value,
|
||||
objective=objective,
|
||||
current_discussion_id=discussion_id,
|
||||
current_round=0
|
||||
)
|
||||
|
||||
# 广播讨论开始
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_started", {
|
||||
"discussion_id": discussion_id,
|
||||
"objective": objective
|
||||
})
|
||||
|
||||
# 发送系统消息
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=None,
|
||||
content=f"讨论开始\n\n目标:{objective}",
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=0
|
||||
)
|
||||
|
||||
logger.info(f"讨论开始: {room_id} - {objective}")
|
||||
|
||||
# 运行讨论循环
|
||||
try:
|
||||
result = await cls._run_discussion_loop(chatroom, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"讨论异常: {e}")
|
||||
await cls._handle_discussion_error(room_id, discussion_id, str(e))
|
||||
raise
|
||||
finally:
|
||||
# 清理
|
||||
cls._active_discussions.pop(room_id, None)
|
||||
cls._stop_signals.pop(room_id, None)
|
||||
|
||||
@classmethod
|
||||
async def stop_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
停止讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
cls._stop_signals[room_id] = True
|
||||
logger.info(f"收到停止讨论信号: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def pause_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
暂停讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if room_id not in cls._active_discussions:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.PAUSED)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_paused")
|
||||
|
||||
logger.info(f"讨论暂停: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def resume_discussion(cls, room_id: str) -> bool:
|
||||
"""
|
||||
恢复讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if not chatroom or chatroom.status != ChatRoomStatus.PAUSED.value:
|
||||
return False
|
||||
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ACTIVE)
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_resumed")
|
||||
|
||||
logger.info(f"讨论恢复: {room_id}")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def _run_discussion_loop(
|
||||
cls,
|
||||
chatroom: ChatRoom,
|
||||
context: DiscussionContext
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
运行讨论循环
|
||||
|
||||
Args:
|
||||
chatroom: 聊天室
|
||||
context: 讨论上下文
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = chatroom.room_id
|
||||
config = chatroom.get_config()
|
||||
|
||||
# 获取所有Agent
|
||||
agents = await AgentService.get_agents_by_ids(chatroom.agents)
|
||||
agent_map = {a.agent_id: a for a in agents}
|
||||
|
||||
# 获取主持人(用于共识判断)
|
||||
moderator = None
|
||||
if chatroom.moderator_agent_id:
|
||||
moderator = await AgentService.get_agent(chatroom.moderator_agent_id)
|
||||
|
||||
consecutive_no_speak = 0 # 连续无人发言的轮次
|
||||
|
||||
while context.current_round < config.max_rounds:
|
||||
# 检查停止信号
|
||||
if cls._stop_signals.get(room_id, False):
|
||||
break
|
||||
|
||||
# 检查暂停状态
|
||||
current_chatroom = await ChatRoomService.get_chatroom(room_id)
|
||||
if current_chatroom and current_chatroom.status == ChatRoomStatus.PAUSED.value:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# 增加轮次
|
||||
context.current_round += 1
|
||||
context.reset_round_counts()
|
||||
|
||||
# 广播轮次信息
|
||||
await MessageRouter.broadcast_round_info(
|
||||
room_id,
|
||||
context.current_round,
|
||||
config.max_rounds
|
||||
)
|
||||
|
||||
# 更新聊天室轮次
|
||||
await ChatRoomService.update_chatroom(
|
||||
room_id,
|
||||
current_round=context.current_round
|
||||
)
|
||||
|
||||
# 本轮是否有人发言
|
||||
round_has_message = False
|
||||
|
||||
# 遍历所有Agent,判断是否发言
|
||||
for agent_id in chatroom.agents:
|
||||
agent = agent_map.get(agent_id)
|
||||
if not agent or not agent.enabled:
|
||||
continue
|
||||
|
||||
# 检查本轮发言次数限制
|
||||
behavior = agent.get_behavior()
|
||||
if context.get_agent_speak_count(agent_id) >= behavior.max_speak_per_round:
|
||||
continue
|
||||
|
||||
# 判断是否发言
|
||||
should_speak, content = await cls._should_agent_speak(
|
||||
agent, context, chatroom
|
||||
)
|
||||
|
||||
if should_speak and content:
|
||||
# 广播输入状态
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, True)
|
||||
|
||||
# 保存并广播消息
|
||||
message = await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=MessageType.TEXT.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 更新上下文
|
||||
context.add_message(message)
|
||||
round_has_message = True
|
||||
|
||||
# 广播输入结束
|
||||
await MessageRouter.broadcast_typing(room_id, agent_id, False)
|
||||
|
||||
# 轮次间隔
|
||||
await asyncio.sleep(config.round_interval)
|
||||
|
||||
# 检查是否需要共识判断
|
||||
if round_has_message and moderator:
|
||||
consecutive_no_speak = 0
|
||||
|
||||
# 每隔几轮检查一次共识
|
||||
if context.current_round % 3 == 0 or context.current_round >= config.max_rounds - 5:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
|
||||
if consensus_result.get("consensus_reached", False):
|
||||
confidence = consensus_result.get("confidence", 0)
|
||||
if confidence >= config.consensus_threshold:
|
||||
# 达成共识,结束讨论
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"consensus"
|
||||
)
|
||||
else:
|
||||
consecutive_no_speak += 1
|
||||
|
||||
# 连续多轮无人发言,检查共识或结束
|
||||
if consecutive_no_speak >= 3:
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"no_more_discussion"
|
||||
)
|
||||
else:
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
{"consensus_reached": False, "summary": "讨论结束,无明确共识"},
|
||||
"no_more_discussion"
|
||||
)
|
||||
|
||||
# 达到最大轮次
|
||||
if moderator:
|
||||
consensus_result = await ConsensusManager.check_consensus(
|
||||
moderator, context, chatroom
|
||||
)
|
||||
else:
|
||||
consensus_result = {"consensus_reached": False, "summary": "达到最大轮次限制"}
|
||||
|
||||
return await cls._finalize_discussion(
|
||||
context,
|
||||
consensus_result,
|
||||
"max_rounds"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _should_agent_speak(
|
||||
cls,
|
||||
agent: Agent,
|
||||
context: DiscussionContext,
|
||||
chatroom: ChatRoom
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断Agent是否应该发言
|
||||
|
||||
Args:
|
||||
agent: Agent实例
|
||||
context: 讨论上下文
|
||||
chatroom: 聊天室
|
||||
|
||||
Returns:
|
||||
(是否发言, 发言内容)
|
||||
"""
|
||||
# 构建判断提示词
|
||||
recent_messages = context.get_recent_messages(chatroom.get_config().message_history_size)
|
||||
|
||||
history_text = ""
|
||||
for msg in recent_messages:
|
||||
if msg.agent_id:
|
||||
history_text += f"[{msg.agent_id}]: {msg.content}\n\n"
|
||||
else:
|
||||
history_text += f"[系统]: {msg.content}\n\n"
|
||||
|
||||
prompt = f"""你是{agent.name},角色是{agent.role}。
|
||||
|
||||
{agent.system_prompt}
|
||||
|
||||
当前讨论目标:{context.objective}
|
||||
|
||||
对话历史:
|
||||
{history_text if history_text else "(还没有对话)"}
|
||||
|
||||
当前是第{context.current_round}轮讨论。
|
||||
|
||||
请根据你的角色判断:
|
||||
1. 你是否有新的观点或建议要分享?
|
||||
2. 你是否需要回应其他人的观点?
|
||||
3. 当前讨论是否需要你的专业意见?
|
||||
|
||||
如果你认为需要发言,请直接给出你的发言内容。
|
||||
如果你认为暂时不需要发言(例如等待更多信息、当前轮次已有足够讨论、或者你的观点已经充分表达),请只回复"PASS"。
|
||||
|
||||
注意:
|
||||
- 请保持发言简洁有力,每次发言控制在200字以内
|
||||
- 避免重复已经说过的内容
|
||||
- 如果已经达成共识或接近共识,可以选择PASS"""
|
||||
|
||||
try:
|
||||
# 调用AI接口
|
||||
response = await AIProviderService.chat(
|
||||
provider_id=agent.provider_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=agent.temperature,
|
||||
max_tokens=agent.max_tokens
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
logger.warning(f"Agent {agent.agent_id} 响应失败: {response.error}")
|
||||
return False, ""
|
||||
|
||||
content = response.content.strip()
|
||||
|
||||
# 判断是否PASS
|
||||
if content.upper() == "PASS" or content.upper().startswith("PASS"):
|
||||
return False, ""
|
||||
|
||||
return True, content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent.agent_id} 判断发言异常: {e}")
|
||||
return False, ""
|
||||
|
||||
@classmethod
|
||||
async def _finalize_discussion(
|
||||
cls,
|
||||
context: DiscussionContext,
|
||||
consensus_result: Dict[str, Any],
|
||||
end_reason: str
|
||||
) -> DiscussionResult:
|
||||
"""
|
||||
完成讨论,保存结果
|
||||
|
||||
Args:
|
||||
context: 讨论上下文
|
||||
consensus_result: 共识判断结果
|
||||
end_reason: 结束原因
|
||||
|
||||
Returns:
|
||||
讨论结果
|
||||
"""
|
||||
room_id = context.room_id
|
||||
|
||||
# 获取讨论结果记录
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == context.discussion_id
|
||||
)
|
||||
|
||||
if discussion_result:
|
||||
# 更新统计
|
||||
discussion_result.update_stats(
|
||||
total_rounds=context.current_round,
|
||||
total_messages=len(context.messages),
|
||||
agent_contributions=context.agent_speak_counts
|
||||
)
|
||||
|
||||
# 标记完成
|
||||
discussion_result.mark_completed(
|
||||
consensus_reached=consensus_result.get("consensus_reached", False),
|
||||
confidence=consensus_result.get("confidence", 0),
|
||||
summary=consensus_result.get("summary", ""),
|
||||
action_items=consensus_result.get("action_items", []),
|
||||
unresolved_issues=consensus_result.get("unresolved_issues", []),
|
||||
end_reason=end_reason
|
||||
)
|
||||
|
||||
await discussion_result.save()
|
||||
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.COMPLETED)
|
||||
|
||||
# 发送系统消息
|
||||
summary_text = f"""讨论结束
|
||||
|
||||
结果:{"达成共识" if consensus_result.get("consensus_reached") else "未达成明确共识"}
|
||||
置信度:{consensus_result.get("confidence", 0):.0%}
|
||||
|
||||
摘要:{consensus_result.get("summary", "无")}
|
||||
|
||||
行动项:
|
||||
{chr(10).join("- " + item for item in consensus_result.get("action_items", [])) or "无"}
|
||||
|
||||
未解决问题:
|
||||
{chr(10).join("- " + issue for issue in consensus_result.get("unresolved_issues", [])) or "无"}
|
||||
|
||||
共进行 {context.current_round} 轮讨论,产生 {len(context.messages)} 条消息。"""
|
||||
|
||||
await MessageRouter.save_and_broadcast_message(
|
||||
room_id=room_id,
|
||||
discussion_id=context.discussion_id,
|
||||
agent_id=None,
|
||||
content=summary_text,
|
||||
message_type=MessageType.SYSTEM.value,
|
||||
round_num=context.current_round
|
||||
)
|
||||
|
||||
# 广播讨论结束
|
||||
await MessageRouter.broadcast_status(room_id, "discussion_completed", {
|
||||
"discussion_id": context.discussion_id,
|
||||
"consensus_reached": consensus_result.get("consensus_reached", False),
|
||||
"end_reason": end_reason
|
||||
})
|
||||
|
||||
logger.info(f"讨论结束: {room_id}, 原因: {end_reason}")
|
||||
|
||||
return discussion_result
|
||||
|
||||
@classmethod
|
||||
async def _handle_discussion_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""
|
||||
处理讨论错误
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
error: 错误信息
|
||||
"""
|
||||
# 更新聊天室状态
|
||||
await ChatRoomService.update_status(room_id, ChatRoomStatus.ERROR)
|
||||
|
||||
# 更新讨论结果
|
||||
discussion_result = await DiscussionResult.find_one(
|
||||
DiscussionResult.discussion_id == discussion_id
|
||||
)
|
||||
if discussion_result:
|
||||
discussion_result.status = "failed"
|
||||
discussion_result.end_reason = f"error: {error}"
|
||||
discussion_result.updated_at = datetime.utcnow()
|
||||
await discussion_result.save()
|
||||
|
||||
# 广播错误
|
||||
await MessageRouter.broadcast_error(room_id, error)
|
||||
|
||||
@classmethod
|
||||
def get_active_discussion(cls, room_id: str) -> Optional[DiscussionContext]:
|
||||
"""
|
||||
获取活跃的讨论上下文
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
讨论上下文或None
|
||||
"""
|
||||
return cls._active_discussions.get(room_id)
|
||||
|
||||
@classmethod
|
||||
def is_discussion_active(cls, room_id: str) -> bool:
|
||||
"""
|
||||
检查是否有活跃讨论
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
是否活跃
|
||||
"""
|
||||
return room_id in cls._active_discussions
|
||||
252
backend/services/mcp_service.py
Normal file
252
backend/services/mcp_service.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
MCP服务
|
||||
管理MCP工具的集成和调用
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MCPService:
|
||||
"""
|
||||
MCP工具服务
|
||||
集成MCP服务器,提供工具调用能力
|
||||
"""
|
||||
|
||||
# MCP服务器配置目录
|
||||
MCP_CONFIG_DIR = Path(os.getenv("CURSOR_MCP_DIR", "~/.cursor/mcps")).expanduser()
|
||||
|
||||
# 已注册的工具: server_name -> List[tool_info]
|
||||
_registered_tools: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
# Agent工具映射: agent_id -> List[tool_name]
|
||||
_agent_tools: Dict[str, List[str]] = {}
|
||||
|
||||
@classmethod
|
||||
async def initialize(cls) -> None:
|
||||
"""
|
||||
初始化MCP服务
|
||||
扫描并注册可用的MCP工具
|
||||
"""
|
||||
logger.info("初始化MCP服务...")
|
||||
|
||||
if not cls.MCP_CONFIG_DIR.exists():
|
||||
logger.warning(f"MCP配置目录不存在: {cls.MCP_CONFIG_DIR}")
|
||||
return
|
||||
|
||||
# 扫描MCP服务器目录
|
||||
for server_dir in cls.MCP_CONFIG_DIR.iterdir():
|
||||
if server_dir.is_dir():
|
||||
await cls._scan_server(server_dir)
|
||||
|
||||
logger.info(f"MCP服务初始化完成,已注册 {len(cls._registered_tools)} 个服务器")
|
||||
|
||||
@classmethod
|
||||
async def _scan_server(cls, server_dir: Path) -> None:
|
||||
"""
|
||||
扫描MCP服务器目录
|
||||
|
||||
Args:
|
||||
server_dir: 服务器目录
|
||||
"""
|
||||
server_name = server_dir.name
|
||||
tools_dir = server_dir / "tools"
|
||||
|
||||
if not tools_dir.exists():
|
||||
return
|
||||
|
||||
tools = []
|
||||
for tool_file in tools_dir.glob("*.json"):
|
||||
try:
|
||||
with open(tool_file, "r", encoding="utf-8") as f:
|
||||
tool_info = json.load(f)
|
||||
tool_info["_file"] = str(tool_file)
|
||||
tools.append(tool_info)
|
||||
except Exception as e:
|
||||
logger.warning(f"加载MCP工具配置失败: {tool_file} - {e}")
|
||||
|
||||
if tools:
|
||||
cls._registered_tools[server_name] = tools
|
||||
logger.debug(f"注册MCP服务器: {server_name}, 工具数: {len(tools)}")
|
||||
|
||||
@classmethod
|
||||
def list_servers(cls) -> List[str]:
|
||||
"""
|
||||
列出所有可用的MCP服务器
|
||||
|
||||
Returns:
|
||||
服务器名称列表
|
||||
"""
|
||||
return list(cls._registered_tools.keys())
|
||||
|
||||
@classmethod
|
||||
def list_tools(cls, server: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出可用的MCP工具
|
||||
|
||||
Args:
|
||||
server: 服务器名称(可选,不指定则返回所有)
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
if server:
|
||||
return cls._registered_tools.get(server, [])
|
||||
|
||||
# 返回所有工具
|
||||
all_tools = []
|
||||
for server_name, tools in cls._registered_tools.items():
|
||||
for tool in tools:
|
||||
tool_copy = tool.copy()
|
||||
tool_copy["server"] = server_name
|
||||
all_tools.append(tool_copy)
|
||||
|
||||
return all_tools
|
||||
|
||||
@classmethod
|
||||
def get_tool(cls, server: str, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定工具的信息
|
||||
|
||||
Args:
|
||||
server: 服务器名称
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
工具信息或None
|
||||
"""
|
||||
tools = cls._registered_tools.get(server, [])
|
||||
for tool in tools:
|
||||
if tool.get("name") == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def call_tool(
|
||||
cls,
|
||||
server: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用MCP工具
|
||||
|
||||
Args:
|
||||
server: 服务器名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
调用结果
|
||||
"""
|
||||
tool = cls.get_tool(server, tool_name)
|
||||
if not tool:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"工具不存在: {server}/{tool_name}"
|
||||
}
|
||||
|
||||
# TODO: 实际的MCP工具调用逻辑
|
||||
# 这里需要根据MCP协议实现工具调用
|
||||
# 目前返回模拟结果
|
||||
logger.info(f"调用MCP工具: {server}/{tool_name}, 参数: {arguments}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": f"MCP工具调用: {tool_name}",
|
||||
"tool": tool_name,
|
||||
"server": server,
|
||||
"arguments": arguments
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_tool_for_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
tool_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
为Agent注册可用工具
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
tool_name: 工具名称(格式: server/tool_name)
|
||||
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if agent_id not in cls._agent_tools:
|
||||
cls._agent_tools[agent_id] = []
|
||||
|
||||
if tool_name not in cls._agent_tools[agent_id]:
|
||||
cls._agent_tools[agent_id].append(tool_name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def unregister_tool_for_agent(
|
||||
cls,
|
||||
agent_id: str,
|
||||
tool_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
为Agent注销工具
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
是否注销成功
|
||||
"""
|
||||
if agent_id in cls._agent_tools:
|
||||
if tool_name in cls._agent_tools[agent_id]:
|
||||
cls._agent_tools[agent_id].remove(tool_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_agent_tools(cls, agent_id: str) -> List[str]:
|
||||
"""
|
||||
获取Agent可用的工具列表
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
工具名称列表
|
||||
"""
|
||||
return cls._agent_tools.get(agent_id, [])
|
||||
|
||||
@classmethod
|
||||
def get_tools_for_prompt(cls, agent_id: str) -> str:
|
||||
"""
|
||||
获取用于提示词的工具描述
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
工具描述文本
|
||||
"""
|
||||
tool_names = cls.get_agent_tools(agent_id)
|
||||
if not tool_names:
|
||||
return ""
|
||||
|
||||
descriptions = []
|
||||
for full_name in tool_names:
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
server, tool_name = parts
|
||||
tool = cls.get_tool(server, tool_name)
|
||||
if tool:
|
||||
desc = tool.get("description", "无描述")
|
||||
descriptions.append(f"- {tool_name}: {desc}")
|
||||
|
||||
if not descriptions:
|
||||
return ""
|
||||
|
||||
return "你可以使用以下工具:\n" + "\n".join(descriptions)
|
||||
416
backend/services/memory_service.py
Normal file
416
backend/services/memory_service.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
记忆服务
|
||||
管理Agent的记忆存储和检索
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from models.agent_memory import AgentMemory, MemoryType
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
Agent记忆服务
|
||||
提供记忆的存储、检索和管理功能
|
||||
"""
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
_embedding_model = None
|
||||
|
||||
@classmethod
|
||||
def _get_embedding_model(cls):
|
||||
"""
|
||||
获取嵌入模型实例(延迟加载)
|
||||
"""
|
||||
if cls._embedding_model is None:
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
||||
logger.info("嵌入模型加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入模型加载失败: {e}")
|
||||
return None
|
||||
return cls._embedding_model
|
||||
|
||||
@classmethod
|
||||
async def create_memory(
|
||||
cls,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
memory_type: str = MemoryType.SHORT_TERM.value,
|
||||
importance: float = 0.5,
|
||||
source_room_id: Optional[str] = None,
|
||||
source_discussion_id: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
expires_in_hours: Optional[int] = None
|
||||
) -> AgentMemory:
|
||||
"""
|
||||
创建新的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
content: 记忆内容
|
||||
memory_type: 记忆类型
|
||||
importance: 重要性评分
|
||||
source_room_id: 来源聊天室
|
||||
source_discussion_id: 来源讨论
|
||||
tags: 标签
|
||||
expires_in_hours: 过期时间(小时)
|
||||
|
||||
Returns:
|
||||
创建的AgentMemory文档
|
||||
"""
|
||||
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 生成向量嵌入
|
||||
embedding = await cls._generate_embedding(content)
|
||||
|
||||
# 生成摘要
|
||||
summary = content[:100] + "..." if len(content) > 100 else content
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if expires_in_hours:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||
|
||||
memory = AgentMemory(
|
||||
memory_id=memory_id,
|
||||
agent_id=agent_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
summary=summary,
|
||||
embedding=embedding,
|
||||
importance=importance,
|
||||
source_room_id=source_room_id,
|
||||
source_discussion_id=source_discussion_id,
|
||||
tags=tags or [],
|
||||
created_at=datetime.utcnow(),
|
||||
last_accessed=datetime.utcnow(),
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
await memory.insert()
|
||||
|
||||
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
|
||||
"""
|
||||
获取指定记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
AgentMemory文档或None
|
||||
"""
|
||||
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
|
||||
|
||||
@classmethod
|
||||
async def get_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[AgentMemory]:
|
||||
"""
|
||||
获取Agent的记忆列表
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
return await AgentMemory.find(query).sort(
|
||||
"-importance", "-last_accessed"
|
||||
).limit(limit).to_list()
|
||||
|
||||
@classmethod
|
||||
async def search_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
memory_type: Optional[str] = None,
|
||||
min_relevance: float = 0.3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索相关记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型(可选)
|
||||
min_relevance: 最小相关性阈值
|
||||
|
||||
Returns:
|
||||
带相关性分数的记忆列表
|
||||
"""
|
||||
# 生成查询向量
|
||||
query_embedding = await cls._generate_embedding(query)
|
||||
if not query_embedding:
|
||||
# 无法生成向量时,使用文本匹配
|
||||
return await cls._text_search(agent_id, query, limit, memory_type)
|
||||
|
||||
# 获取Agent的所有记忆
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
# 计算相似度
|
||||
results = []
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
if memory.embedding:
|
||||
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
|
||||
relevance = memory.calculate_relevance_score(similarity)
|
||||
|
||||
if relevance >= min_relevance:
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": similarity,
|
||||
"relevance": relevance
|
||||
})
|
||||
|
||||
# 按相关性排序
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
|
||||
# 更新访问记录
|
||||
for item in results[:limit]:
|
||||
memory = item["memory"]
|
||||
memory.access()
|
||||
await memory.save()
|
||||
|
||||
return results[:limit]
|
||||
|
||||
@classmethod
|
||||
async def update_memory(
|
||||
cls,
|
||||
memory_id: str,
|
||||
**kwargs
|
||||
) -> Optional[AgentMemory]:
|
||||
"""
|
||||
更新记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的AgentMemory或None
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return None
|
||||
|
||||
# 如果更新了内容,重新生成嵌入
|
||||
if "content" in kwargs:
|
||||
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
|
||||
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(memory, key):
|
||||
setattr(memory, key, value)
|
||||
|
||||
await memory.save()
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def delete_memory(cls, memory_id: str) -> bool:
|
||||
"""
|
||||
删除记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return False
|
||||
|
||||
await memory.delete()
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def delete_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
删除Agent的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
|
||||
Returns:
|
||||
删除的数量
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
result = await AgentMemory.find(query).delete()
|
||||
return result.deleted_count if result else 0
|
||||
|
||||
@classmethod
|
||||
async def cleanup_expired_memories(cls) -> int:
|
||||
"""
|
||||
清理过期的记忆
|
||||
|
||||
Returns:
|
||||
清理的数量
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
result = await AgentMemory.find(
|
||||
{"expires_at": {"$lt": now}}
|
||||
).delete()
|
||||
|
||||
count = result.deleted_count if result else 0
|
||||
if count > 0:
|
||||
logger.info(f"清理了 {count} 条过期记忆")
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
async def consolidate_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
min_importance: float = 0.7,
|
||||
max_age_days: int = 30
|
||||
) -> None:
|
||||
"""
|
||||
整合记忆(将重要的短期记忆转为长期记忆)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
min_importance: 最小重要性阈值
|
||||
max_age_days: 最大年龄(天)
|
||||
"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
|
||||
|
||||
# 查找符合条件的短期记忆
|
||||
memories = await AgentMemory.find({
|
||||
"agent_id": agent_id,
|
||||
"memory_type": MemoryType.SHORT_TERM.value,
|
||||
"importance": {"$gte": min_importance},
|
||||
"created_at": {"$lt": cutoff_date}
|
||||
}).to_list()
|
||||
|
||||
for memory in memories:
|
||||
memory.memory_type = MemoryType.LONG_TERM.value
|
||||
memory.expires_at = None # 长期记忆不过期
|
||||
await memory.save()
|
||||
|
||||
if memories:
|
||||
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
|
||||
|
||||
@classmethod
|
||||
async def _generate_embedding(cls, text: str) -> List[float]:
|
||||
"""
|
||||
生成文本的向量嵌入
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
向量嵌入列表
|
||||
"""
|
||||
model = cls._get_embedding_model()
|
||||
if model is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
embedding = model.encode(text, convert_to_numpy=True)
|
||||
return embedding.tolist()
|
||||
except Exception as e:
|
||||
logger.warning(f"生成嵌入失败: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
|
||||
"""
|
||||
计算余弦相似度
|
||||
|
||||
Args:
|
||||
vec1: 向量1
|
||||
vec2: 向量2
|
||||
|
||||
Returns:
|
||||
相似度 (0-1)
|
||||
"""
|
||||
if not vec1 or not vec2:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
a = np.array(vec1)
|
||||
b = np.array(vec2)
|
||||
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
return float(max(0, similarity))
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
async def _text_search(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
memory_type: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
文本搜索(后备方案)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
# 简单的文本匹配
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
content_lower = memory.content.lower()
|
||||
if query_lower in content_lower:
|
||||
# 计算简单的匹配分数
|
||||
score = len(query_lower) / len(content_lower)
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": score,
|
||||
"relevance": score * memory.importance
|
||||
})
|
||||
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
return results[:limit]
|
||||
335
backend/services/message_router.py
Normal file
335
backend/services/message_router.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
消息路由服务
|
||||
管理消息的发送和广播
|
||||
"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Callable, Set
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from fastapi import WebSocket
|
||||
|
||||
from models.message import Message, MessageType
|
||||
from models.chatroom import ChatRoom
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""WebSocket连接信息"""
|
||||
websocket: WebSocket
|
||||
room_id: str
|
||||
connected_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""
|
||||
消息路由器
|
||||
管理WebSocket连接和消息广播
|
||||
"""
|
||||
|
||||
# 房间连接映射: room_id -> Set[WebSocket]
|
||||
_room_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
# 消息回调: 用于外部订阅消息
|
||||
_message_callbacks: List[Callable] = []
|
||||
|
||||
@classmethod
|
||||
async def connect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
建立WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if room_id not in cls._room_connections:
|
||||
cls._room_connections[room_id] = set()
|
||||
|
||||
cls._room_connections[room_id].add(websocket)
|
||||
|
||||
logger.info(f"WebSocket连接建立: {room_id}, 当前连接数: {len(cls._room_connections[room_id])}")
|
||||
|
||||
@classmethod
|
||||
async def disconnect(cls, room_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
断开WebSocket连接
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
"""
|
||||
if room_id in cls._room_connections:
|
||||
cls._room_connections[room_id].discard(websocket)
|
||||
|
||||
# 清理空房间
|
||||
if not cls._room_connections[room_id]:
|
||||
del cls._room_connections[room_id]
|
||||
|
||||
logger.info(f"WebSocket连接断开: {room_id}")
|
||||
|
||||
@classmethod
|
||||
async def broadcast_to_room(
|
||||
cls,
|
||||
room_id: str,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向聊天室广播消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
message: 消息内容
|
||||
"""
|
||||
if room_id not in cls._room_connections:
|
||||
return
|
||||
|
||||
# 获取所有连接
|
||||
connections = cls._room_connections[room_id].copy()
|
||||
|
||||
# 并发发送
|
||||
tasks = []
|
||||
for websocket in connections:
|
||||
tasks.append(cls._send_message(room_id, websocket, message))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@classmethod
|
||||
async def _send_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
websocket: WebSocket,
|
||||
message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向单个WebSocket发送消息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
websocket: WebSocket实例
|
||||
message: 消息内容
|
||||
"""
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket发送失败: {e}")
|
||||
# 移除断开的连接
|
||||
await cls.disconnect(room_id, websocket)
|
||||
|
||||
@classmethod
|
||||
async def save_and_broadcast_message(
|
||||
cls,
|
||||
room_id: str,
|
||||
discussion_id: str,
|
||||
agent_id: Optional[str],
|
||||
content: str,
|
||||
message_type: str = MessageType.TEXT.value,
|
||||
round_num: int = 0,
|
||||
attachments: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_results: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
保存消息并广播
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
discussion_id: 讨论ID
|
||||
agent_id: 发送Agent ID
|
||||
content: 消息内容
|
||||
message_type: 消息类型
|
||||
round_num: 轮次号
|
||||
attachments: 附件
|
||||
tool_calls: 工具调用
|
||||
tool_results: 工具结果
|
||||
|
||||
Returns:
|
||||
保存的Message文档
|
||||
"""
|
||||
# 创建消息
|
||||
message = Message(
|
||||
message_id=f"msg-{uuid.uuid4().hex[:12]}",
|
||||
room_id=room_id,
|
||||
discussion_id=discussion_id,
|
||||
agent_id=agent_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
attachments=attachments or [],
|
||||
round=round_num,
|
||||
token_count=len(content) // 4, # 粗略估计
|
||||
tool_calls=tool_calls or [],
|
||||
tool_results=tool_results or [],
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
await message.insert()
|
||||
|
||||
# 构建广播消息
|
||||
broadcast_data = {
|
||||
"type": "message",
|
||||
"data": {
|
||||
"message_id": message.message_id,
|
||||
"room_id": message.room_id,
|
||||
"discussion_id": message.discussion_id,
|
||||
"agent_id": message.agent_id,
|
||||
"content": message.content,
|
||||
"message_type": message.message_type,
|
||||
"round": message.round,
|
||||
"created_at": message.created_at.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# 广播消息
|
||||
await cls.broadcast_to_room(room_id, broadcast_data)
|
||||
|
||||
# 触发回调
|
||||
for callback in cls._message_callbacks:
|
||||
try:
|
||||
await callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"消息回调执行失败: {e}")
|
||||
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
async def broadcast_status(
|
||||
cls,
|
||||
room_id: str,
|
||||
status: str,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播状态更新
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
status: 状态类型
|
||||
data: 附加数据
|
||||
"""
|
||||
message = {
|
||||
"type": "status",
|
||||
"status": status,
|
||||
"data": data or {},
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_typing(
|
||||
cls,
|
||||
room_id: str,
|
||||
agent_id: str,
|
||||
is_typing: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
广播Agent输入状态
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
agent_id: Agent ID
|
||||
is_typing: 是否正在输入
|
||||
"""
|
||||
message = {
|
||||
"type": "typing",
|
||||
"agent_id": agent_id,
|
||||
"is_typing": is_typing,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_round_info(
|
||||
cls,
|
||||
room_id: str,
|
||||
round_num: int,
|
||||
total_rounds: int
|
||||
) -> None:
|
||||
"""
|
||||
广播轮次信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
round_num: 当前轮次
|
||||
total_rounds: 最大轮次
|
||||
"""
|
||||
message = {
|
||||
"type": "round",
|
||||
"round": round_num,
|
||||
"total_rounds": total_rounds,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
async def broadcast_error(
|
||||
cls,
|
||||
room_id: str,
|
||||
error: str,
|
||||
agent_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
广播错误信息
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
error: 错误信息
|
||||
agent_id: 相关Agent ID
|
||||
"""
|
||||
message = {
|
||||
"type": "error",
|
||||
"error": error,
|
||||
"agent_id": agent_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await cls.broadcast_to_room(room_id, message)
|
||||
|
||||
@classmethod
|
||||
def register_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注册消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数,接收Message参数
|
||||
"""
|
||||
cls._message_callbacks.append(callback)
|
||||
|
||||
@classmethod
|
||||
def unregister_callback(cls, callback: Callable) -> None:
|
||||
"""
|
||||
注销消息回调
|
||||
|
||||
Args:
|
||||
callback: 回调函数
|
||||
"""
|
||||
if callback in cls._message_callbacks:
|
||||
cls._message_callbacks.remove(callback)
|
||||
|
||||
@classmethod
|
||||
def get_connection_count(cls, room_id: str) -> int:
|
||||
"""
|
||||
获取房间连接数
|
||||
|
||||
Args:
|
||||
room_id: 聊天室ID
|
||||
|
||||
Returns:
|
||||
连接数
|
||||
"""
|
||||
return len(cls._room_connections.get(room_id, set()))
|
||||
|
||||
@classmethod
|
||||
def get_all_room_ids(cls) -> List[str]:
|
||||
"""
|
||||
获取所有活跃房间ID
|
||||
|
||||
Returns:
|
||||
房间ID列表
|
||||
"""
|
||||
return list(cls._room_connections.keys())
|
||||
Reference in New Issue
Block a user