feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
25
backend/models/__init__.py
Normal file
25
backend/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
数据模型模块
|
||||
"""
|
||||
from .ai_provider import AIProvider, ProxyConfig, RateLimit
|
||||
from .agent import Agent, AgentCapabilities, AgentBehavior
|
||||
from .chatroom import ChatRoom, ChatRoomConfig
|
||||
from .message import Message, MessageType
|
||||
from .discussion_result import DiscussionResult
|
||||
from .agent_memory import AgentMemory, MemoryType
|
||||
|
||||
__all__ = [
|
||||
"AIProvider",
|
||||
"ProxyConfig",
|
||||
"RateLimit",
|
||||
"Agent",
|
||||
"AgentCapabilities",
|
||||
"AgentBehavior",
|
||||
"ChatRoom",
|
||||
"ChatRoomConfig",
|
||||
"Message",
|
||||
"MessageType",
|
||||
"DiscussionResult",
|
||||
"AgentMemory",
|
||||
"MemoryType",
|
||||
]
|
||||
168
backend/models/agent.py
Normal file
168
backend/models/agent.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Agent数据模型
|
||||
定义AI聊天代理的配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class AgentCapabilities:
|
||||
"""Agent能力配置"""
|
||||
memory_enabled: bool = False # 是否启用记忆
|
||||
mcp_tools: List[str] = [] # 可用的MCP工具
|
||||
skills: List[str] = [] # 可用的技能
|
||||
multimodal: bool = False # 是否支持多模态
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_enabled: bool = False,
|
||||
mcp_tools: Optional[List[str]] = None,
|
||||
skills: Optional[List[str]] = None,
|
||||
multimodal: bool = False
|
||||
):
|
||||
self.memory_enabled = memory_enabled
|
||||
self.mcp_tools = mcp_tools or []
|
||||
self.skills = skills or []
|
||||
self.multimodal = multimodal
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"memory_enabled": self.memory_enabled,
|
||||
"mcp_tools": self.mcp_tools,
|
||||
"skills": self.skills,
|
||||
"multimodal": self.multimodal
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentCapabilities":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
memory_enabled=data.get("memory_enabled", False),
|
||||
mcp_tools=data.get("mcp_tools", []),
|
||||
skills=data.get("skills", []),
|
||||
multimodal=data.get("multimodal", False)
|
||||
)
|
||||
|
||||
|
||||
class AgentBehavior:
|
||||
"""Agent行为配置"""
|
||||
speak_threshold: float = 0.5 # 发言阈值(判断是否需要发言)
|
||||
max_speak_per_round: int = 2 # 每轮最多发言次数
|
||||
speak_style: str = "balanced" # 发言风格: concise, balanced, detailed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
speak_threshold: float = 0.5,
|
||||
max_speak_per_round: int = 2,
|
||||
speak_style: str = "balanced"
|
||||
):
|
||||
self.speak_threshold = speak_threshold
|
||||
self.max_speak_per_round = max_speak_per_round
|
||||
self.speak_style = speak_style
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"speak_threshold": self.speak_threshold,
|
||||
"max_speak_per_round": self.max_speak_per_round,
|
||||
"speak_style": self.speak_style
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentBehavior":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
speak_threshold=data.get("speak_threshold", 0.5),
|
||||
max_speak_per_round=data.get("max_speak_per_round", 2),
|
||||
speak_style=data.get("speak_style", "balanced")
|
||||
)
|
||||
|
||||
|
||||
class Agent(Document):
|
||||
"""
|
||||
Agent文档模型
|
||||
存储AI代理的配置信息
|
||||
"""
|
||||
agent_id: str = Field(..., description="唯一标识")
|
||||
name: str = Field(..., description="Agent名称")
|
||||
role: str = Field(..., description="角色定义")
|
||||
system_prompt: str = Field(..., description="系统提示词")
|
||||
provider_id: str = Field(..., description="使用的AI接口ID")
|
||||
|
||||
# 模型参数
|
||||
temperature: float = Field(default=0.7, ge=0, le=2, description="温度参数")
|
||||
max_tokens: int = Field(default=2000, gt=0, description="最大token数")
|
||||
|
||||
# 能力配置
|
||||
capabilities: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"memory_enabled": False,
|
||||
"mcp_tools": [],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
},
|
||||
description="能力配置"
|
||||
)
|
||||
|
||||
# 行为配置
|
||||
behavior: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
},
|
||||
description="行为配置"
|
||||
)
|
||||
|
||||
# 外观配置
|
||||
avatar: Optional[str] = Field(default=None, description="头像URL")
|
||||
color: str = Field(default="#1890ff", description="代表颜色")
|
||||
|
||||
# 元数据
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "agents"
|
||||
|
||||
def get_capabilities(self) -> AgentCapabilities:
|
||||
"""获取能力配置对象"""
|
||||
return AgentCapabilities.from_dict(self.capabilities)
|
||||
|
||||
def get_behavior(self) -> AgentBehavior:
|
||||
"""获取行为配置对象"""
|
||||
return AgentBehavior.from_dict(self.behavior)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"agent_id": "product-manager",
|
||||
"name": "产品经理",
|
||||
"role": "产品规划和需求分析专家",
|
||||
"system_prompt": "你是一位经验丰富的产品经理,擅长分析用户需求...",
|
||||
"provider_id": "openrouter-gpt4",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"capabilities": {
|
||||
"memory_enabled": True,
|
||||
"mcp_tools": ["web_search"],
|
||||
"skills": [],
|
||||
"multimodal": False
|
||||
},
|
||||
"behavior": {
|
||||
"speak_threshold": 0.5,
|
||||
"max_speak_per_round": 2,
|
||||
"speak_style": "balanced"
|
||||
},
|
||||
"avatar": "https://example.com/avatar.png",
|
||||
"color": "#1890ff"
|
||||
}
|
||||
}
|
||||
123
backend/models/agent_memory.py
Normal file
123
backend/models/agent_memory.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Agent记忆数据模型
|
||||
定义Agent的记忆存储结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""记忆类型枚举"""
|
||||
SHORT_TERM = "short_term" # 短期记忆(会话内)
|
||||
LONG_TERM = "long_term" # 长期记忆(跨会话)
|
||||
EPISODIC = "episodic" # 情景记忆(特定事件)
|
||||
SEMANTIC = "semantic" # 语义记忆(知识性)
|
||||
|
||||
|
||||
class AgentMemory(Document):
|
||||
"""
|
||||
Agent记忆文档模型
|
||||
存储Agent的记忆内容
|
||||
"""
|
||||
memory_id: str = Field(..., description="唯一标识")
|
||||
agent_id: str = Field(..., description="Agent ID")
|
||||
|
||||
# 记忆内容
|
||||
memory_type: str = Field(
|
||||
default=MemoryType.SHORT_TERM.value,
|
||||
description="记忆类型"
|
||||
)
|
||||
content: str = Field(..., description="记忆内容")
|
||||
summary: str = Field(default="", description="内容摘要")
|
||||
|
||||
# 向量嵌入(用于相似度检索)
|
||||
embedding: List[float] = Field(default_factory=list, description="向量嵌入")
|
||||
|
||||
# 元数据
|
||||
importance: float = Field(default=0.5, ge=0, le=1, description="重要性评分")
|
||||
access_count: int = Field(default=0, description="访问次数")
|
||||
|
||||
# 关联信息
|
||||
source_room_id: Optional[str] = Field(default=None, description="来源聊天室ID")
|
||||
source_discussion_id: Optional[str] = Field(default=None, description="来源讨论ID")
|
||||
related_agents: List[str] = Field(default_factory=list, description="相关Agent列表")
|
||||
tags: List[str] = Field(default_factory=list, description="标签")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_accessed: datetime = Field(default_factory=datetime.utcnow)
|
||||
expires_at: Optional[datetime] = Field(default=None, description="过期时间")
|
||||
|
||||
class Settings:
|
||||
name = "agent_memories"
|
||||
indexes = [
|
||||
[("agent_id", 1)],
|
||||
[("memory_type", 1)],
|
||||
[("importance", -1)],
|
||||
[("last_accessed", -1)],
|
||||
]
|
||||
|
||||
def access(self) -> None:
|
||||
"""
|
||||
记录访问,更新访问计数和时间
|
||||
"""
|
||||
self.access_count += 1
|
||||
self.last_accessed = datetime.utcnow()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
检查记忆是否已过期
|
||||
|
||||
Returns:
|
||||
是否过期
|
||||
"""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def calculate_relevance_score(
|
||||
self,
|
||||
similarity: float,
|
||||
time_decay_factor: float = 0.1
|
||||
) -> float:
|
||||
"""
|
||||
计算综合相关性分数
|
||||
结合向量相似度、重要性和时间衰减
|
||||
|
||||
Args:
|
||||
similarity: 向量相似度 (0-1)
|
||||
time_decay_factor: 时间衰减因子
|
||||
|
||||
Returns:
|
||||
综合相关性分数
|
||||
"""
|
||||
# 计算时间衰减
|
||||
hours_since_access = (datetime.utcnow() - self.last_accessed).total_seconds() / 3600
|
||||
time_decay = 1.0 / (1.0 + time_decay_factor * hours_since_access)
|
||||
|
||||
# 综合评分
|
||||
score = (
|
||||
0.5 * similarity +
|
||||
0.3 * self.importance +
|
||||
0.2 * time_decay
|
||||
)
|
||||
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"memory_id": "mem-001",
|
||||
"agent_id": "product-manager",
|
||||
"memory_type": "long_term",
|
||||
"content": "在登录系统设计讨论中,团队决定采用OAuth2.0方案",
|
||||
"summary": "登录系统采用OAuth2.0",
|
||||
"importance": 0.8,
|
||||
"access_count": 5,
|
||||
"source_room_id": "product-design-room",
|
||||
"tags": ["登录", "OAuth", "认证"]
|
||||
}
|
||||
}
|
||||
149
backend/models/ai_provider.py
Normal file
149
backend/models/ai_provider.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
AI接口提供商数据模型
|
||||
定义AI服务配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
"""AI提供商类型枚举"""
|
||||
MINIMAX = "minimax"
|
||||
ZHIPU = "zhipu"
|
||||
OPENROUTER = "openrouter"
|
||||
KIMI = "kimi"
|
||||
DEEPSEEK = "deepseek"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
LLMSTUDIO = "llmstudio"
|
||||
|
||||
|
||||
class ProxyConfig:
|
||||
"""代理配置"""
|
||||
http_proxy: Optional[str] = None # HTTP代理地址
|
||||
https_proxy: Optional[str] = None # HTTPS代理地址
|
||||
no_proxy: List[str] = [] # 不使用代理的域名列表
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http_proxy: Optional[str] = None,
|
||||
https_proxy: Optional[str] = None,
|
||||
no_proxy: Optional[List[str]] = None
|
||||
):
|
||||
self.http_proxy = http_proxy
|
||||
self.https_proxy = https_proxy
|
||||
self.no_proxy = no_proxy or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"http_proxy": self.http_proxy,
|
||||
"https_proxy": self.https_proxy,
|
||||
"no_proxy": self.no_proxy
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ProxyConfig":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
http_proxy=data.get("http_proxy"),
|
||||
https_proxy=data.get("https_proxy"),
|
||||
no_proxy=data.get("no_proxy", [])
|
||||
)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
"""速率限制配置"""
|
||||
requests_per_minute: int = 60 # 每分钟请求数
|
||||
tokens_per_minute: int = 100000 # 每分钟token数
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requests_per_minute: int = 60,
|
||||
tokens_per_minute: int = 100000
|
||||
):
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.tokens_per_minute = tokens_per_minute
|
||||
|
||||
def to_dict(self) -> Dict[str, int]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"requests_per_minute": self.requests_per_minute,
|
||||
"tokens_per_minute": self.tokens_per_minute
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, int]) -> "RateLimit":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
requests_per_minute=data.get("requests_per_minute", 60),
|
||||
tokens_per_minute=data.get("tokens_per_minute", 100000)
|
||||
)
|
||||
|
||||
|
||||
class AIProvider(Document):
|
||||
"""
|
||||
AI接口提供商文档模型
|
||||
存储各AI服务的配置信息
|
||||
"""
|
||||
provider_id: str = Field(..., description="唯一标识")
|
||||
provider_type: str = Field(..., description="提供商类型: minimax, zhipu等")
|
||||
name: str = Field(..., description="自定义名称")
|
||||
api_key: str = Field(default="", description="API密钥(加密存储)")
|
||||
base_url: str = Field(default="", description="API基础URL")
|
||||
model: str = Field(..., description="使用的模型名称")
|
||||
|
||||
# 代理配置
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理")
|
||||
proxy_config: Dict[str, Any] = Field(default_factory=dict, description="代理配置")
|
||||
|
||||
# 速率限制
|
||||
rate_limit: Dict[str, int] = Field(
|
||||
default_factory=lambda: {"requests_per_minute": 60, "tokens_per_minute": 100000},
|
||||
description="速率限制配置"
|
||||
)
|
||||
|
||||
# 其他配置
|
||||
timeout: int = Field(default=60, description="超时时间(秒)")
|
||||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数")
|
||||
|
||||
# 元数据
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "ai_providers"
|
||||
|
||||
def get_proxy_config(self) -> ProxyConfig:
|
||||
"""获取代理配置对象"""
|
||||
return ProxyConfig.from_dict(self.proxy_config)
|
||||
|
||||
def get_rate_limit(self) -> RateLimit:
|
||||
"""获取速率限制配置对象"""
|
||||
return RateLimit.from_dict(self.rate_limit)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"provider_id": "openrouter-gpt4",
|
||||
"provider_type": "openrouter",
|
||||
"name": "OpenRouter GPT-4",
|
||||
"api_key": "sk-xxx",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"model": "openai/gpt-4-turbo",
|
||||
"use_proxy": True,
|
||||
"proxy_config": {
|
||||
"http_proxy": "http://127.0.0.1:7890",
|
||||
"https_proxy": "http://127.0.0.1:7890"
|
||||
},
|
||||
"timeout": 60
|
||||
}
|
||||
}
|
||||
131
backend/models/chatroom.py
Normal file
131
backend/models/chatroom.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
聊天室数据模型
|
||||
定义讨论聊天室的配置结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class ChatRoomStatus(str, Enum):
|
||||
"""聊天室状态枚举"""
|
||||
IDLE = "idle" # 空闲,等待开始
|
||||
ACTIVE = "active" # 讨论进行中
|
||||
PAUSED = "paused" # 暂停
|
||||
COMPLETED = "completed" # 已完成
|
||||
ERROR = "error" # 出错
|
||||
|
||||
|
||||
class ChatRoomConfig:
|
||||
"""聊天室配置"""
|
||||
max_rounds: int = 50 # 最大轮数(备用终止条件)
|
||||
message_history_size: int = 20 # 上下文消息数
|
||||
consensus_threshold: float = 0.8 # 共识阈值
|
||||
round_interval: float = 1.0 # 轮次间隔(秒)
|
||||
allow_user_interrupt: bool = True # 允许用户中断
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_rounds: int = 50,
|
||||
message_history_size: int = 20,
|
||||
consensus_threshold: float = 0.8,
|
||||
round_interval: float = 1.0,
|
||||
allow_user_interrupt: bool = True
|
||||
):
|
||||
self.max_rounds = max_rounds
|
||||
self.message_history_size = message_history_size
|
||||
self.consensus_threshold = consensus_threshold
|
||||
self.round_interval = round_interval
|
||||
self.allow_user_interrupt = allow_user_interrupt
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"max_rounds": self.max_rounds,
|
||||
"message_history_size": self.message_history_size,
|
||||
"consensus_threshold": self.consensus_threshold,
|
||||
"round_interval": self.round_interval,
|
||||
"allow_user_interrupt": self.allow_user_interrupt
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ChatRoomConfig":
|
||||
"""从字典创建"""
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
max_rounds=data.get("max_rounds", 50),
|
||||
message_history_size=data.get("message_history_size", 20),
|
||||
consensus_threshold=data.get("consensus_threshold", 0.8),
|
||||
round_interval=data.get("round_interval", 1.0),
|
||||
allow_user_interrupt=data.get("allow_user_interrupt", True)
|
||||
)
|
||||
|
||||
|
||||
class ChatRoom(Document):
|
||||
"""
|
||||
聊天室文档模型
|
||||
存储讨论聊天室的配置信息
|
||||
"""
|
||||
room_id: str = Field(..., description="唯一标识")
|
||||
name: str = Field(..., description="聊天室名称")
|
||||
description: str = Field(default="", description="描述")
|
||||
objective: str = Field(default="", description="当前讨论目标")
|
||||
|
||||
# 参与者
|
||||
agents: List[str] = Field(default_factory=list, description="Agent ID列表")
|
||||
moderator_agent_id: Optional[str] = Field(default=None, description="共识判断Agent ID")
|
||||
|
||||
# 配置
|
||||
config: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8,
|
||||
"round_interval": 1.0,
|
||||
"allow_user_interrupt": True
|
||||
},
|
||||
description="聊天室配置"
|
||||
)
|
||||
|
||||
# 状态
|
||||
status: str = Field(default=ChatRoomStatus.IDLE.value, description="当前状态")
|
||||
current_round: int = Field(default=0, description="当前轮次")
|
||||
current_discussion_id: Optional[str] = Field(default=None, description="当前讨论ID")
|
||||
|
||||
# 元数据
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||
|
||||
class Settings:
|
||||
name = "chatrooms"
|
||||
|
||||
def get_config(self) -> ChatRoomConfig:
|
||||
"""获取配置对象"""
|
||||
return ChatRoomConfig.from_dict(self.config)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""检查聊天室是否处于活动状态"""
|
||||
return self.status == ChatRoomStatus.ACTIVE.value
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"room_id": "product-design-room",
|
||||
"name": "产品设计讨论室",
|
||||
"description": "用于讨论新产品功能设计",
|
||||
"objective": "设计一个用户友好的登录系统",
|
||||
"agents": ["product-manager", "designer", "developer"],
|
||||
"moderator_agent_id": "moderator",
|
||||
"config": {
|
||||
"max_rounds": 50,
|
||||
"message_history_size": 20,
|
||||
"consensus_threshold": 0.8
|
||||
},
|
||||
"status": "idle",
|
||||
"current_round": 0
|
||||
}
|
||||
}
|
||||
126
backend/models/discussion_result.py
Normal file
126
backend/models/discussion_result.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
讨论结果数据模型
|
||||
定义讨论结果的结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class DiscussionResult(Document):
|
||||
"""
|
||||
讨论结果文档模型
|
||||
存储讨论的最终结果
|
||||
"""
|
||||
discussion_id: str = Field(..., description="讨论唯一标识")
|
||||
room_id: str = Field(..., description="聊天室ID")
|
||||
objective: str = Field(..., description="讨论目标")
|
||||
|
||||
# 共识结果
|
||||
consensus_reached: bool = Field(default=False, description="是否达成共识")
|
||||
confidence: float = Field(default=0.0, ge=0, le=1, description="共识置信度")
|
||||
|
||||
# 结果摘要
|
||||
summary: str = Field(default="", description="讨论结果摘要")
|
||||
action_items: List[str] = Field(default_factory=list, description="行动项列表")
|
||||
unresolved_issues: List[str] = Field(default_factory=list, description="未解决的问题")
|
||||
key_decisions: List[str] = Field(default_factory=list, description="关键决策")
|
||||
|
||||
# 统计信息
|
||||
total_rounds: int = Field(default=0, description="总轮数")
|
||||
total_messages: int = Field(default=0, description="总消息数")
|
||||
participating_agents: List[str] = Field(default_factory=list, description="参与的Agent列表")
|
||||
agent_contributions: Dict[str, int] = Field(
|
||||
default_factory=dict,
|
||||
description="各Agent发言次数统计"
|
||||
)
|
||||
|
||||
# 状态
|
||||
status: str = Field(default="in_progress", description="状态: in_progress, completed, failed")
|
||||
end_reason: str = Field(default="", description="结束原因")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
completed_at: Optional[datetime] = Field(default=None, description="完成时间")
|
||||
|
||||
class Settings:
|
||||
name = "discussions"
|
||||
indexes = [
|
||||
[("room_id", 1)],
|
||||
[("created_at", -1)],
|
||||
]
|
||||
|
||||
def mark_completed(
|
||||
self,
|
||||
consensus_reached: bool,
|
||||
confidence: float,
|
||||
summary: str,
|
||||
action_items: List[str] = None,
|
||||
unresolved_issues: List[str] = None,
|
||||
end_reason: str = "consensus"
|
||||
) -> None:
|
||||
"""
|
||||
标记讨论为已完成
|
||||
|
||||
Args:
|
||||
consensus_reached: 是否达成共识
|
||||
confidence: 置信度
|
||||
summary: 结果摘要
|
||||
action_items: 行动项
|
||||
unresolved_issues: 未解决问题
|
||||
end_reason: 结束原因
|
||||
"""
|
||||
self.consensus_reached = consensus_reached
|
||||
self.confidence = confidence
|
||||
self.summary = summary
|
||||
self.action_items = action_items or []
|
||||
self.unresolved_issues = unresolved_issues or []
|
||||
self.status = "completed"
|
||||
self.end_reason = end_reason
|
||||
self.completed_at = datetime.utcnow()
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def update_stats(
|
||||
self,
|
||||
total_rounds: int,
|
||||
total_messages: int,
|
||||
agent_contributions: Dict[str, int]
|
||||
) -> None:
|
||||
"""
|
||||
更新统计信息
|
||||
|
||||
Args:
|
||||
total_rounds: 总轮数
|
||||
total_messages: 总消息数
|
||||
agent_contributions: Agent贡献统计
|
||||
"""
|
||||
self.total_rounds = total_rounds
|
||||
self.total_messages = total_messages
|
||||
self.agent_contributions = agent_contributions
|
||||
self.participating_agents = list(agent_contributions.keys())
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"discussion_id": "disc-001",
|
||||
"room_id": "product-design-room",
|
||||
"objective": "设计用户登录系统",
|
||||
"consensus_reached": True,
|
||||
"confidence": 0.85,
|
||||
"summary": "团队一致同意采用OAuth2.0 + 手机验证码的混合认证方案...",
|
||||
"action_items": [
|
||||
"设计OAuth2.0集成方案",
|
||||
"开发短信验证服务",
|
||||
"编写安全测试用例"
|
||||
],
|
||||
"unresolved_issues": [
|
||||
"第三方登录的优先级排序"
|
||||
],
|
||||
"total_rounds": 15,
|
||||
"total_messages": 45,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
123
backend/models/message.py
Normal file
123
backend/models/message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
消息数据模型
|
||||
定义聊天消息的结构
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from beanie import Document
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""消息类型枚举"""
|
||||
TEXT = "text" # 纯文本
|
||||
IMAGE = "image" # 图片
|
||||
FILE = "file" # 文件
|
||||
SYSTEM = "system" # 系统消息
|
||||
ACTION = "action" # 动作消息(如调用工具)
|
||||
|
||||
|
||||
class MessageAttachment:
|
||||
"""消息附件"""
|
||||
attachment_type: str # 附件类型: image, file
|
||||
url: str # 资源URL
|
||||
name: str # 文件名
|
||||
size: int = 0 # 文件大小(字节)
|
||||
mime_type: str = "" # MIME类型
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attachment_type: str,
|
||||
url: str,
|
||||
name: str,
|
||||
size: int = 0,
|
||||
mime_type: str = ""
|
||||
):
|
||||
self.attachment_type = attachment_type
|
||||
self.url = url
|
||||
self.name = name
|
||||
self.size = size
|
||||
self.mime_type = mime_type
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"attachment_type": self.attachment_type,
|
||||
"url": self.url,
|
||||
"name": self.name,
|
||||
"size": self.size,
|
||||
"mime_type": self.mime_type
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MessageAttachment":
|
||||
"""从字典创建"""
|
||||
return cls(
|
||||
attachment_type=data.get("attachment_type", ""),
|
||||
url=data.get("url", ""),
|
||||
name=data.get("name", ""),
|
||||
size=data.get("size", 0),
|
||||
mime_type=data.get("mime_type", "")
|
||||
)
|
||||
|
||||
|
||||
class Message(Document):
|
||||
"""
|
||||
消息文档模型
|
||||
存储聊天消息
|
||||
"""
|
||||
message_id: str = Field(..., description="唯一标识")
|
||||
room_id: str = Field(..., description="聊天室ID")
|
||||
discussion_id: str = Field(..., description="讨论ID")
|
||||
agent_id: Optional[str] = Field(default=None, description="发送Agent ID(系统消息为空)")
|
||||
|
||||
# 消息内容
|
||||
content: str = Field(..., description="消息内容")
|
||||
message_type: str = Field(default=MessageType.TEXT.value, description="消息类型")
|
||||
attachments: List[Dict[str, Any]] = Field(default_factory=list, description="附件列表")
|
||||
|
||||
# 元数据
|
||||
round: int = Field(default=0, description="所属轮次")
|
||||
token_count: int = Field(default=0, description="token数量")
|
||||
|
||||
# 工具调用相关
|
||||
tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用记录")
|
||||
tool_results: List[Dict[str, Any]] = Field(default_factory=list, description="工具调用结果")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Settings:
|
||||
name = "messages"
|
||||
indexes = [
|
||||
[("room_id", 1), ("created_at", 1)],
|
||||
[("discussion_id", 1)],
|
||||
[("agent_id", 1)],
|
||||
]
|
||||
|
||||
def get_attachments(self) -> List[MessageAttachment]:
|
||||
"""获取附件对象列表"""
|
||||
return [MessageAttachment.from_dict(a) for a in self.attachments]
|
||||
|
||||
def is_from_agent(self, agent_id: str) -> bool:
|
||||
"""检查消息是否来自指定Agent"""
|
||||
return self.agent_id == agent_id
|
||||
|
||||
def is_system_message(self) -> bool:
|
||||
"""检查是否为系统消息"""
|
||||
return self.message_type == MessageType.SYSTEM.value
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"message_id": "msg-001",
|
||||
"room_id": "product-design-room",
|
||||
"discussion_id": "disc-001",
|
||||
"agent_id": "product-manager",
|
||||
"content": "我认为登录系统应该支持多种认证方式...",
|
||||
"message_type": "text",
|
||||
"round": 1,
|
||||
"token_count": 150
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user