358 lines
9.6 KiB
Python
358 lines
9.6 KiB
Python
|
|
"""
|
|||
|
|
聊天室服务
|
|||
|
|
管理聊天室的创建和状态
|
|||
|
|
"""
|
|||
|
|
import uuid
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from models.chatroom import ChatRoom, ChatRoomStatus
|
|||
|
|
from models.message import Message
|
|||
|
|
from services.agent_service import AgentService
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatRoomService:
    """
    Chat room service.

    Provides CRUD operations for chat rooms, plus helpers for room membership,
    discussion objective, status, round counting and message history. Every
    public method is a classmethod; the class keeps no per-instance state.
    """

    # Fields update_chatroom() never overwrites: the document identity and the
    # creation timestamp. Changing them through a generic "update fields" call
    # would detach the stored room from its own identity/creation record.
    _PROTECTED_FIELDS = frozenset({"id", "room_id", "created_at"})

    @staticmethod
    def _unique_agent_ids(agent_ids: Optional[List[str]]) -> List[str]:
        """Return agent_ids with duplicates removed (first occurrence kept); None -> []."""
        return list(dict.fromkeys(agent_ids or []))

    @classmethod
    async def _validate_agents(cls, agent_ids: List[str]) -> None:
        """
        Ensure every ID in agent_ids is known to AgentService.

        Raises:
            ValueError: listing every ID that AgentService.get_agents_by_ids
                did not return.
        """
        if not agent_ids:
            return
        existing_agents = await AgentService.get_agents_by_ids(agent_ids)
        existing_ids = {a.agent_id for a in existing_agents}
        missing_ids = set(agent_ids) - existing_ids
        if missing_ids:
            # Sorted so the error message is deterministic (set order is not).
            raise ValueError(f"Agent不存在: {', '.join(sorted(missing_ids))}")

    @classmethod
    async def _validate_moderator(cls, moderator_agent_id: Optional[str]) -> None:
        """
        Ensure a (truthy) moderator agent ID refers to an existing agent.

        A falsy ID (None / "") means "no moderator" and is accepted as-is.

        Raises:
            ValueError: if AgentService.get_agent returns nothing for the ID.
        """
        if not moderator_agent_id:
            return
        moderator = await AgentService.get_agent(moderator_agent_id)
        if not moderator:
            raise ValueError(f"主持人Agent不存在: {moderator_agent_id}")

    @classmethod
    async def create_chatroom(
        cls,
        name: str,
        description: str = "",
        agents: Optional[List[str]] = None,
        moderator_agent_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> ChatRoom:
        """
        Create and persist a new chat room.

        Args:
            name: Chat room name.
            description: Free-text description.
            agents: Agent IDs to place in the room. Duplicates are dropped
                (first occurrence kept), matching add_agent's behaviour.
            moderator_agent_id: Optional moderator agent ID.
            config: Overrides merged on top of the default room config.

        Returns:
            The inserted ChatRoom document, with status IDLE, round 0 and an
            empty objective.

        Raises:
            ValueError: if any agent or the moderator does not exist.
        """
        agent_ids = cls._unique_agent_ids(agents)
        await cls._validate_agents(agent_ids)
        await cls._validate_moderator(moderator_agent_id)

        # NOTE(review): only 32 random bits (8 hex chars); nothing here checks
        # for a collision with an existing room_id.
        room_id = f"room-{uuid.uuid4().hex[:8]}"

        # Built fresh on every call so rooms never share one config dict;
        # caller-supplied keys override these defaults.
        room_config: Dict[str, Any] = {
            "max_rounds": 50,
            "message_history_size": 20,
            "consensus_threshold": 0.8,
            "round_interval": 1.0,
            "allow_user_interrupt": True
        }
        if config:
            room_config.update(config)

        # A single timestamp so created_at == updated_at on a fresh room.
        now = datetime.utcnow()
        chatroom = ChatRoom(
            room_id=room_id,
            name=name,
            description=description,
            objective="",
            agents=agent_ids,
            moderator_agent_id=moderator_agent_id,
            config=room_config,
            status=ChatRoomStatus.IDLE.value,
            current_round=0,
            created_at=now,
            updated_at=now
        )

        await chatroom.insert()

        logger.info(f"创建聊天室: {room_id} ({name})")
        return chatroom

    @classmethod
    async def get_chatroom(cls, room_id: str) -> Optional[ChatRoom]:
        """
        Look up a chat room by its room_id.

        Args:
            room_id: Chat room ID.

        Returns:
            The ChatRoom document, or None if no room has that ID.
        """
        return await ChatRoom.find_one(ChatRoom.room_id == room_id)

    @classmethod
    async def get_all_chatrooms(cls) -> List[ChatRoom]:
        """
        Return every chat room.

        Returns:
            List of all ChatRoom documents (unfiltered, unpaginated).
        """
        return await ChatRoom.find_all().to_list()

    @classmethod
    async def update_chatroom(
        cls,
        room_id: str,
        **kwargs
    ) -> Optional[ChatRoom]:
        """
        Update arbitrary fields of a chat room and save it.

        Keys that are not attributes of the room are ignored. Protected keys
        (id, room_id, created_at) are ignored with a warning. updated_at is
        always set to the current UTC time, overriding any supplied value.

        Args:
            room_id: Chat room ID.
            **kwargs: Field names and new values. If "agents" is present it
                is normalised like in create_chatroom (None -> [], duplicates
                dropped) and validated.

        Returns:
            The updated ChatRoom, or None if the room does not exist.

        Raises:
            ValueError: if a supplied agent or moderator does not exist.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return None

        if "agents" in kwargs:
            kwargs["agents"] = cls._unique_agent_ids(kwargs["agents"])
            await cls._validate_agents(kwargs["agents"])

        # A falsy moderator_agent_id (e.g. None) clears the moderator and is
        # not validated.
        await cls._validate_moderator(kwargs.get("moderator_agent_id"))

        kwargs["updated_at"] = datetime.utcnow()

        for key, value in kwargs.items():
            if key in cls._PROTECTED_FIELDS:
                logger.warning(f"忽略受保护字段: {key} (聊天室 {room_id})")
                continue
            if hasattr(chatroom, key):
                setattr(chatroom, key, value)

        await chatroom.save()

        logger.info(f"更新聊天室: {room_id}")
        return chatroom

    @classmethod
    async def delete_chatroom(cls, room_id: str) -> bool:
        """
        Delete a chat room together with all of its messages.

        Args:
            room_id: Chat room ID.

        Returns:
            True if the room existed and was deleted, False if it was not found.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return False

        # Remove every Message belonging to this room before the room itself.
        await Message.find(Message.room_id == room_id).delete()

        await chatroom.delete()

        logger.info(f"删除聊天室: {room_id}")
        return True

    @classmethod
    async def add_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
        """
        Add an agent to a chat room (no-op if it is already a member).

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.

        Returns:
            The (possibly unchanged) ChatRoom, or None if the room does not exist.

        Raises:
            ValueError: if the agent does not exist.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return None

        agent = await AgentService.get_agent(agent_id)
        if not agent:
            raise ValueError(f"Agent不存在: {agent_id}")

        # Only write when membership actually changes.
        if agent_id not in chatroom.agents:
            chatroom.agents.append(agent_id)
            chatroom.updated_at = datetime.utcnow()
            await chatroom.save()

        return chatroom

    @classmethod
    async def remove_agent(cls, room_id: str, agent_id: str) -> Optional[ChatRoom]:
        """
        Remove an agent from a chat room (no-op if it is not a member).

        Args:
            room_id: Chat room ID.
            agent_id: Agent ID.

        Returns:
            The (possibly unchanged) ChatRoom, or None if the room does not exist.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return None

        # Only write when membership actually changes.
        if agent_id in chatroom.agents:
            chatroom.agents.remove(agent_id)
            chatroom.updated_at = datetime.utcnow()
            await chatroom.save()

        return chatroom

    @classmethod
    async def set_objective(
        cls,
        room_id: str,
        objective: str
    ) -> Optional[ChatRoom]:
        """
        Set the room's discussion objective.

        Args:
            room_id: Chat room ID.
            objective: Discussion objective text.

        Returns:
            The updated ChatRoom, or None if the room does not exist.
        """
        return await cls.update_chatroom(room_id, objective=objective)

    @classmethod
    async def update_status(
        cls,
        room_id: str,
        status: ChatRoomStatus
    ) -> Optional[ChatRoom]:
        """
        Change the room's status; entering COMPLETED also stamps completed_at.

        Args:
            room_id: Chat room ID.
            status: New status (its .value is stored on the document).

        Returns:
            The updated ChatRoom, or None if the room does not exist.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return None

        # One timestamp so updated_at and completed_at agree exactly.
        now = datetime.utcnow()
        chatroom.status = status.value
        chatroom.updated_at = now

        if status == ChatRoomStatus.COMPLETED:
            chatroom.completed_at = now

        await chatroom.save()

        logger.info(f"聊天室状态更新: {room_id} -> {status.value}")
        return chatroom

    @classmethod
    async def increment_round(cls, room_id: str) -> Optional[ChatRoom]:
        """
        Increment the room's current_round by one and save it.

        Args:
            room_id: Chat room ID.

        Returns:
            The updated ChatRoom, or None if the room does not exist.
        """
        chatroom = await cls.get_chatroom(room_id)
        if not chatroom:
            return None

        # NOTE(review): read-modify-write (load, +1, save). Two concurrent calls
        # for the same room could both read the same value; an atomic $inc may
        # be preferable if concurrent increments are possible.
        chatroom.current_round += 1
        chatroom.updated_at = datetime.utcnow()
        await chatroom.save()

        return chatroom

    @classmethod
    async def get_messages(
        cls,
        room_id: str,
        limit: int = 50,
        skip: int = 0,
        discussion_id: Optional[str] = None
    ) -> List[Message]:
        """
        Fetch a page of a room's message history, newest first.

        Args:
            room_id: Chat room ID.
            limit: Maximum number of messages to return.
            skip: Number of newest messages to skip (pagination offset).
            discussion_id: If given, restrict to messages with this discussion_id.

        Returns:
            Messages sorted by created_at, descending.
        """
        # NOTE(review): MongoDB treats limit(0) as "no limit"; presumably Beanie
        # forwards 0 unchanged, so limit=0 here returns the whole history.
        query: Dict[str, Any] = {"room_id": room_id}
        if discussion_id:
            query["discussion_id"] = discussion_id

        return await Message.find(query).sort(
            "-created_at"
        ).skip(skip).limit(limit).to_list()

    @classmethod
    async def get_recent_messages(
        cls,
        room_id: str,
        count: int = 20,
        discussion_id: Optional[str] = None
    ) -> List[Message]:
        """
        Fetch the latest `count` messages of a room, oldest first.

        Args:
            room_id: Chat room ID.
            count: Number of most recent messages; <= 0 returns an empty list.
            discussion_id: If given, restrict to messages with this discussion_id.

        Returns:
            Up to `count` messages in chronological (ascending) order.
        """
        # Without this guard count=0 would reach limit(0), which MongoDB
        # interprets as "no limit" and would return the entire history.
        if count <= 0:
            return []

        messages = await cls.get_messages(
            room_id,
            limit=count,
            discussion_id=discussion_id
        )
        # get_messages returns newest first; flip to chronological order.
        return list(reversed(messages))
|