"""
Memory service.

Manages the storage and retrieval of agent memories.
"""

import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Optional

import numpy as np
from loguru import logger

from models.agent_memory import AgentMemory, MemoryType


class MemoryService:
    """Agent memory service.

    Provides creation, lookup, search, update, deletion and maintenance of
    ``AgentMemory`` documents. Every operation is a classmethod; the only
    state held by the class is a lazily loaded, shared embedding model.
    """

    # Name of the sentence-transformers model used to embed text.
    _EMBEDDING_MODEL_NAME = 'paraphrase-multilingual-MiniLM-L12-v2'

    # Maximum number of content characters copied into a memory summary.
    _SUMMARY_MAX_CHARS = 100

    # Attributes that update_memory() never overwrites.
    # NOTE(review): assumes ``id`` is the document's primary key - confirm
    # against the AgentMemory model.
    _PROTECTED_FIELDS = frozenset({"id"})

    # Embedding model (loaded lazily on first use).
    _embedding_model = None

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as the deprecated
        ``datetime.utcnow()``, so timestamps written by this service keep
        their existing (timezone-naive) representation.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _summarize(cls, content: str) -> str:
        """Return ``content`` truncated to the summary length.

        Content longer than ``_SUMMARY_MAX_CHARS`` characters is cut at that
        length and suffixed with ``"..."``; shorter content is returned
        unchanged.
        """
        max_chars = cls._SUMMARY_MAX_CHARS
        if len(content) > max_chars:
            return content[:max_chars] + "..."
        return content

    @classmethod
    def _get_embedding_model(cls):
        """Return the shared embedding model, loading it on first use.

        Returns:
            The cached ``SentenceTransformer`` instance, or ``None`` when the
            library could not be imported or the model could not be
            constructed. A failure is not cached, so the next call retries.
        """
        if cls._embedding_model is None:
            try:
                # Imported here so that importing this module does not
                # require sentence_transformers to be installed.
                from sentence_transformers import SentenceTransformer

                cls._embedding_model = SentenceTransformer(cls._EMBEDDING_MODEL_NAME)
                logger.info("嵌入模型加载成功")
            except Exception as e:
                # Best effort: callers fall back to an empty embedding.
                logger.warning(f"嵌入模型加载失败: {e}")
                return None
        return cls._embedding_model

    @classmethod
    async def create_memory(
        cls,
        agent_id: str,
        content: str,
        memory_type: str = MemoryType.SHORT_TERM.value,
        importance: float = 0.5,
        source_room_id: Optional[str] = None,
        source_discussion_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
        expires_in_hours: Optional[int] = None
    ) -> AgentMemory:
        """Create and persist a new memory.

        Args:
            agent_id: Agent ID.
            content: Memory content.
            memory_type: Memory type.
            importance: Importance score.
            source_room_id: Chat room the memory originated from.
            source_discussion_id: Discussion the memory originated from.
            tags: Tags; ``None`` is stored as an empty list.
            expires_in_hours: Lifetime in hours. A falsy value (``None`` or
                ``0``) leaves ``expires_at`` unset.

        Returns:
            The inserted ``AgentMemory`` document.
        """
        memory_id = f"mem-{uuid.uuid4().hex[:12]}"

        # Vector embedding (an empty list when no model is available).
        embedding = await cls._generate_embedding(content)

        # One timestamp for the whole record, so created_at, last_accessed
        # and the expiry base all agree.
        now = cls._utcnow()

        expires_at = None
        if expires_in_hours:
            expires_at = now + timedelta(hours=expires_in_hours)

        memory = AgentMemory(
            memory_id=memory_id,
            agent_id=agent_id,
            memory_type=memory_type,
            content=content,
            summary=cls._summarize(content),
            embedding=embedding,
            importance=importance,
            source_room_id=source_room_id,
            source_discussion_id=source_discussion_id,
            tags=tags or [],
            created_at=now,
            last_accessed=now,
            expires_at=expires_at
        )

        await memory.insert()

        logger.debug(f"创建记忆: {memory_id} for Agent {agent_id}")
        return memory

    @classmethod
    async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
        """Fetch a single memory by its ``memory_id``.

        Args:
            memory_id: Memory ID.

        Returns:
            The matching ``AgentMemory`` document, or ``None``.
        """
        return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)

    @classmethod
    async def get_agent_memories(
        cls,
        agent_id: str,
        memory_type: Optional[str] = None,
        limit: int = 50
    ) -> List[AgentMemory]:
        """List an agent's memories.

        Results are ordered by the sort keys ``"-importance"`` and
        ``"-last_accessed"`` (presumably descending, per the ODM's ``-``
        prefix convention - confirm). Expired memories are not filtered out
        here.

        Args:
            agent_id: Agent ID.
            memory_type: Restrict to this memory type (optional).
            limit: Maximum number of memories returned.

        Returns:
            List of memories.
        """
        query = {"agent_id": agent_id}
        if memory_type:
            query["memory_type"] = memory_type

        return await AgentMemory.find(query).sort(
            "-importance", "-last_accessed"
        ).limit(limit).to_list()

    @classmethod
    async def search_memories(
        cls,
        agent_id: str,
        query: str,
        limit: int = 10,
        memory_type: Optional[str] = None,
        min_relevance: float = 0.3
    ) -> List[Dict[str, Any]]:
        """Search an agent's memories for those relevant to ``query``.

        The query is embedded and compared with each stored embedding by
        cosine similarity; the similarity is passed to the memory's
        ``calculate_relevance_score`` to obtain the relevance. When no query
        embedding can be produced, the search falls back to
        ``_text_search`` (which does not apply ``min_relevance`` and does not
        update access records).

        Args:
            agent_id: Agent ID.
            query: Query text.
            limit: Maximum number of results.
            memory_type: Restrict to this memory type (optional).
            min_relevance: Minimum relevance a memory needs to be returned.

        Returns:
            Dicts with keys ``"memory"``, ``"similarity"`` and
            ``"relevance"``, ordered by descending relevance.
        """
        query_embedding = await cls._generate_embedding(query)
        if not query_embedding:
            # No vector available: fall back to substring matching.
            return await cls._text_search(agent_id, query, limit, memory_type)

        # Load all candidate memories of the agent.
        filter_query = {"agent_id": agent_id}
        if memory_type:
            filter_query["memory_type"] = memory_type

        memories = await AgentMemory.find(filter_query).to_list()

        results = []
        for memory in memories:
            # Skip expired memories and those stored without an embedding.
            if memory.is_expired() or not memory.embedding:
                continue

            similarity = cls._cosine_similarity(query_embedding, memory.embedding)
            relevance = memory.calculate_relevance_score(similarity)
            if relevance >= min_relevance:
                results.append({
                    "memory": memory,
                    "similarity": similarity,
                    "relevance": relevance
                })

        results.sort(key=lambda x: x["relevance"], reverse=True)
        top_results = results[:limit]

        # Record the access on every memory that is actually returned.
        for item in top_results:
            memory = item["memory"]
            memory.access()
            await memory.save()

        return top_results

    @classmethod
    async def update_memory(
        cls,
        memory_id: str,
        **kwargs
    ) -> Optional[AgentMemory]:
        """Update fields of an existing memory.

        When ``content`` is updated, the embedding and summary are
        regenerated from the new content (overriding any ``embedding`` or
        ``summary`` passed explicitly). Keys that are not existing data
        attributes of the document - unknown names, private names, methods,
        or protected fields - are ignored and reported in a warning.

        Args:
            memory_id: Memory ID.
            **kwargs: Fields to update.

        Returns:
            The updated ``AgentMemory``, or ``None`` if it does not exist.
        """
        memory = await cls.get_memory(memory_id)
        if not memory:
            return None

        updates = dict(kwargs)

        # Content changed: keep the derived fields in sync with it.
        if "content" in updates:
            new_content = updates["content"]
            updates["embedding"] = await cls._generate_embedding(new_content)
            updates["summary"] = cls._summarize(new_content)

        ignored = []
        for key, value in updates.items():
            updatable = (
                key not in cls._PROTECTED_FIELDS
                and not key.startswith("_")
                and hasattr(memory, key)
                # Never replace a method (e.g. save/access) with a value.
                and not callable(getattr(memory, key))
            )
            if updatable:
                setattr(memory, key, value)
            else:
                ignored.append(key)

        if ignored:
            logger.warning(f"更新记忆 {memory_id} 时忽略了不可更新的字段: {ignored}")

        await memory.save()
        return memory

    @classmethod
    async def delete_memory(cls, memory_id: str) -> bool:
        """Delete a single memory.

        Args:
            memory_id: Memory ID.

        Returns:
            ``True`` if the memory existed and was deleted, ``False`` if no
            memory with that ID was found.
        """
        memory = await cls.get_memory(memory_id)
        if not memory:
            return False

        await memory.delete()
        return True

    @classmethod
    async def delete_agent_memories(
        cls,
        agent_id: str,
        memory_type: Optional[str] = None
    ) -> int:
        """Delete an agent's memories.

        Args:
            agent_id: Agent ID.
            memory_type: Restrict deletion to this memory type (optional).

        Returns:
            Number of deleted memories (``0`` when the delete call returns
            no result object).
        """
        query = {"agent_id": agent_id}
        if memory_type:
            query["memory_type"] = memory_type

        result = await AgentMemory.find(query).delete()
        return result.deleted_count if result else 0

    @classmethod
    async def cleanup_expired_memories(cls) -> int:
        """Delete every memory whose ``expires_at`` lies in the past.

        Returns:
            Number of deleted memories.
        """
        now = cls._utcnow()
        result = await AgentMemory.find(
            {"expires_at": {"$lt": now}}
        ).delete()

        count = result.deleted_count if result else 0
        if count > 0:
            logger.info(f"清理了 {count} 条过期记忆")

        return count

    @classmethod
    async def consolidate_memories(
        cls,
        agent_id: str,
        min_importance: float = 0.7,
        max_age_days: int = 30
    ) -> None:
        """Promote important short-term memories to long-term memories.

        Selects the agent's short-term memories whose importance is at least
        ``min_importance`` and whose ``created_at`` is earlier than
        ``max_age_days`` days ago, switches them to the long-term type and
        clears their expiry.

        Args:
            agent_id: Agent ID.
            min_importance: Minimum importance threshold.
            max_age_days: Age in days; only memories created before
                ``now - max_age_days`` are selected.
        """
        cutoff_date = cls._utcnow() - timedelta(days=max_age_days)

        # Short-term memories that qualify for promotion.
        memories = await AgentMemory.find({
            "agent_id": agent_id,
            "memory_type": MemoryType.SHORT_TERM.value,
            "importance": {"$gte": min_importance},
            "created_at": {"$lt": cutoff_date}
        }).to_list()

        for memory in memories:
            memory.memory_type = MemoryType.LONG_TERM.value
            memory.expires_at = None  # long-term memories do not expire
            await memory.save()

        if memories:
            logger.info(f"整合了 {len(memories)} 条记忆为长期记忆: Agent {agent_id}")

    @classmethod
    async def _generate_embedding(cls, text: str) -> List[float]:
        """Generate the vector embedding of ``text``.

        NOTE(review): ``model.encode`` is called synchronously inside this
        coroutine, so it runs on the event-loop thread for its duration.

        Args:
            text: Text to embed.

        Returns:
            The embedding as a list of floats, or an empty list when the
            model is unavailable or encoding fails.
        """
        model = cls._get_embedding_model()
        if model is None:
            return []

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            # Best effort: an empty embedding triggers the text-search
            # fallback in search_memories().
            logger.warning(f"生成嵌入失败: {e}")
            return []

    @classmethod
    def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
        """Compute the cosine similarity of two vectors, clamped to [0, 1].

        Args:
            vec1: First vector.
            vec2: Second vector.

        Returns:
            Similarity in the range 0-1. Returns ``0.0`` for empty vectors,
            vectors of different length, non-numeric input, zero-length
            (all-zero) vectors, or any non-finite result; negative
            similarities are clamped to ``0.0``.
        """
        if not vec1 or not vec2:
            return 0.0

        try:
            a = np.asarray(vec1, dtype=float)
            b = np.asarray(vec2, dtype=float)
        except (TypeError, ValueError):
            return 0.0

        # Only equally sized 1-D vectors can be compared.
        if a.ndim != 1 or a.shape != b.shape:
            return 0.0

        # Guard the division: a zero vector has no direction.
        denominator = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denominator == 0.0 or not np.isfinite(denominator):
            return 0.0

        similarity = float(np.dot(a, b) / denominator)
        if not np.isfinite(similarity):
            return 0.0

        # Rounding can push the value marginally outside [0, 1].
        return min(1.0, max(0.0, similarity))

    @classmethod
    async def _text_search(
        cls,
        agent_id: str,
        query: str,
        limit: int,
        memory_type: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Substring search over memory content (fallback strategy).

        A memory matches when the lower-cased query occurs in its
        lower-cased content. The score is the ratio of query length to
        content length; the relevance is that score multiplied by the
        memory's importance.

        Args:
            agent_id: Agent ID.
            query: Query text.
            limit: Maximum number of results.
            memory_type: Restrict to this memory type (``None`` for all).

        Returns:
            Dicts with keys ``"memory"``, ``"similarity"`` and
            ``"relevance"``, ordered by descending relevance.
        """
        filter_query = {"agent_id": agent_id}
        if memory_type:
            filter_query["memory_type"] = memory_type

        memories = await AgentMemory.find(filter_query).to_list()

        results = []
        query_lower = query.lower()
        for memory in memories:
            if memory.is_expired():
                continue

            content_lower = memory.content.lower()
            # Empty content cannot be scored (the ratio would divide by
            # zero when the query is empty as well), so it never matches.
            if not content_lower or query_lower not in content_lower:
                continue

            score = len(query_lower) / len(content_lower)
            results.append({
                "memory": memory,
                "similarity": score,
                "relevance": score * memory.importance
            })

        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]