feat: AI聊天室多Agent协作讨论平台
- 实现Agent管理,支持AI辅助生成系统提示词 - 支持多个AI提供商(OpenRouter、智谱、MiniMax等) - 实现聊天室和讨论引擎 - WebSocket实时消息推送 - 前端使用React + Ant Design - 后端使用FastAPI + MongoDB Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
416
backend/services/memory_service.py
Normal file
416
backend/services/memory_service.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
记忆服务
|
||||
管理Agent的记忆存储和检索
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from models.agent_memory import AgentMemory, MemoryType
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
Agent记忆服务
|
||||
提供记忆的存储、检索和管理功能
|
||||
"""
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
_embedding_model = None
|
||||
|
||||
@classmethod
|
||||
def _get_embedding_model(cls):
|
||||
"""
|
||||
获取嵌入模型实例(延迟加载)
|
||||
"""
|
||||
if cls._embedding_model is None:
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
||||
logger.info("嵌入模型加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入模型加载失败: {e}")
|
||||
return None
|
||||
return cls._embedding_model
|
||||
|
||||
@classmethod
|
||||
async def create_memory(
|
||||
cls,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
memory_type: str = MemoryType.SHORT_TERM.value,
|
||||
importance: float = 0.5,
|
||||
source_room_id: Optional[str] = None,
|
||||
source_discussion_id: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
expires_in_hours: Optional[int] = None
|
||||
) -> AgentMemory:
|
||||
"""
|
||||
创建新的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
content: 记忆内容
|
||||
memory_type: 记忆类型
|
||||
importance: 重要性评分
|
||||
source_room_id: 来源聊天室
|
||||
source_discussion_id: 来源讨论
|
||||
tags: 标签
|
||||
expires_in_hours: 过期时间(小时)
|
||||
|
||||
Returns:
|
||||
创建的AgentMemory文档
|
||||
"""
|
||||
memory_id = f"mem-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 生成向量嵌入
|
||||
embedding = await cls._generate_embedding(content)
|
||||
|
||||
# 生成摘要
|
||||
summary = content[:100] + "..." if len(content) > 100 else content
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if expires_in_hours:
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||
|
||||
memory = AgentMemory(
|
||||
memory_id=memory_id,
|
||||
agent_id=agent_id,
|
||||
memory_type=memory_type,
|
||||
content=content,
|
||||
summary=summary,
|
||||
embedding=embedding,
|
||||
importance=importance,
|
||||
source_room_id=source_room_id,
|
||||
source_discussion_id=source_discussion_id,
|
||||
tags=tags or [],
|
||||
created_at=datetime.utcnow(),
|
||||
last_accessed=datetime.utcnow(),
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
await memory.insert()
|
||||
|
||||
logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
|
||||
"""
|
||||
获取指定记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
AgentMemory文档或None
|
||||
"""
|
||||
return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)
|
||||
|
||||
@classmethod
|
||||
async def get_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[AgentMemory]:
|
||||
"""
|
||||
获取Agent的记忆列表
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
return await AgentMemory.find(query).sort(
|
||||
"-importance", "-last_accessed"
|
||||
).limit(limit).to_list()
|
||||
|
||||
@classmethod
|
||||
async def search_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
memory_type: Optional[str] = None,
|
||||
min_relevance: float = 0.3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索相关记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型(可选)
|
||||
min_relevance: 最小相关性阈值
|
||||
|
||||
Returns:
|
||||
带相关性分数的记忆列表
|
||||
"""
|
||||
# 生成查询向量
|
||||
query_embedding = await cls._generate_embedding(query)
|
||||
if not query_embedding:
|
||||
# 无法生成向量时,使用文本匹配
|
||||
return await cls._text_search(agent_id, query, limit, memory_type)
|
||||
|
||||
# 获取Agent的所有记忆
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
# 计算相似度
|
||||
results = []
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
if memory.embedding:
|
||||
similarity = cls._cosine_similarity(query_embedding, memory.embedding)
|
||||
relevance = memory.calculate_relevance_score(similarity)
|
||||
|
||||
if relevance >= min_relevance:
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": similarity,
|
||||
"relevance": relevance
|
||||
})
|
||||
|
||||
# 按相关性排序
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
|
||||
# 更新访问记录
|
||||
for item in results[:limit]:
|
||||
memory = item["memory"]
|
||||
memory.access()
|
||||
await memory.save()
|
||||
|
||||
return results[:limit]
|
||||
|
||||
@classmethod
|
||||
async def update_memory(
|
||||
cls,
|
||||
memory_id: str,
|
||||
**kwargs
|
||||
) -> Optional[AgentMemory]:
|
||||
"""
|
||||
更新记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
**kwargs: 要更新的字段
|
||||
|
||||
Returns:
|
||||
更新后的AgentMemory或None
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return None
|
||||
|
||||
# 如果更新了内容,重新生成嵌入
|
||||
if "content" in kwargs:
|
||||
kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
|
||||
kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(memory, key):
|
||||
setattr(memory, key, value)
|
||||
|
||||
await memory.save()
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
async def delete_memory(cls, memory_id: str) -> bool:
|
||||
"""
|
||||
删除记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
memory = await cls.get_memory(memory_id)
|
||||
if not memory:
|
||||
return False
|
||||
|
||||
await memory.delete()
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def delete_agent_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
memory_type: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
删除Agent的记忆
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
memory_type: 记忆类型(可选)
|
||||
|
||||
Returns:
|
||||
删除的数量
|
||||
"""
|
||||
query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
query["memory_type"] = memory_type
|
||||
|
||||
result = await AgentMemory.find(query).delete()
|
||||
return result.deleted_count if result else 0
|
||||
|
||||
@classmethod
|
||||
async def cleanup_expired_memories(cls) -> int:
|
||||
"""
|
||||
清理过期的记忆
|
||||
|
||||
Returns:
|
||||
清理的数量
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
result = await AgentMemory.find(
|
||||
{"expires_at": {"$lt": now}}
|
||||
).delete()
|
||||
|
||||
count = result.deleted_count if result else 0
|
||||
if count > 0:
|
||||
logger.info(f"清理了 {count} 条过期记忆")
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
async def consolidate_memories(
|
||||
cls,
|
||||
agent_id: str,
|
||||
min_importance: float = 0.7,
|
||||
max_age_days: int = 30
|
||||
) -> None:
|
||||
"""
|
||||
整合记忆(将重要的短期记忆转为长期记忆)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
min_importance: 最小重要性阈值
|
||||
max_age_days: 最大年龄(天)
|
||||
"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)
|
||||
|
||||
# 查找符合条件的短期记忆
|
||||
memories = await AgentMemory.find({
|
||||
"agent_id": agent_id,
|
||||
"memory_type": MemoryType.SHORT_TERM.value,
|
||||
"importance": {"$gte": min_importance},
|
||||
"created_at": {"$lt": cutoff_date}
|
||||
}).to_list()
|
||||
|
||||
for memory in memories:
|
||||
memory.memory_type = MemoryType.LONG_TERM.value
|
||||
memory.expires_at = None # 长期记忆不过期
|
||||
await memory.save()
|
||||
|
||||
if memories:
|
||||
logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")
|
||||
|
||||
@classmethod
|
||||
async def _generate_embedding(cls, text: str) -> List[float]:
|
||||
"""
|
||||
生成文本的向量嵌入
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
向量嵌入列表
|
||||
"""
|
||||
model = cls._get_embedding_model()
|
||||
if model is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
embedding = model.encode(text, convert_to_numpy=True)
|
||||
return embedding.tolist()
|
||||
except Exception as e:
|
||||
logger.warning(f"生成嵌入失败: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
|
||||
"""
|
||||
计算余弦相似度
|
||||
|
||||
Args:
|
||||
vec1: 向量1
|
||||
vec2: 向量2
|
||||
|
||||
Returns:
|
||||
相似度 (0-1)
|
||||
"""
|
||||
if not vec1 or not vec2:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
a = np.array(vec1)
|
||||
b = np.array(vec2)
|
||||
similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
return float(max(0, similarity))
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
async def _text_search(
|
||||
cls,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
memory_type: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
文本搜索(后备方案)
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
query: 查询文本
|
||||
limit: 返回数量
|
||||
memory_type: 记忆类型
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
filter_query = {"agent_id": agent_id}
|
||||
if memory_type:
|
||||
filter_query["memory_type"] = memory_type
|
||||
|
||||
# 简单的文本匹配
|
||||
memories = await AgentMemory.find(filter_query).to_list()
|
||||
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
for memory in memories:
|
||||
if memory.is_expired():
|
||||
continue
|
||||
|
||||
content_lower = memory.content.lower()
|
||||
if query_lower in content_lower:
|
||||
# 计算简单的匹配分数
|
||||
score = len(query_lower) / len(content_lower)
|
||||
results.append({
|
||||
"memory": memory,
|
||||
"similarity": score,
|
||||
"relevance": score * memory.importance
|
||||
})
|
||||
|
||||
results.sort(key=lambda x: x["relevance"], reverse=True)
|
||||
return results[:limit]
|
||||
Reference in New Issue
Block a user