AIChatRoom/backend/services/memory_service.py

"""
记忆服务
管理Agent的记忆存储和检索
"""
import uuid
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
import numpy as np
from loguru import logger
from models.agent_memory import AgentMemory, MemoryType


class MemoryService:
    """
    Agent memory service.
    Provides storage, retrieval, and management of memories.
    """

    # Embedding model (lazily loaded)
    _embedding_model = None

    @classmethod
    def _get_embedding_model(cls):
        """
        Return the embedding model instance, loading it lazily on first use.
        """
        if cls._embedding_model is None:
            try:
                from sentence_transformers import SentenceTransformer
                cls._embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
                logger.info("Embedding model loaded successfully")
            except Exception as e:
                logger.warning(f"Failed to load embedding model: {e}")
                return None
        return cls._embedding_model

    @classmethod
    async def create_memory(
        cls,
        agent_id: str,
        content: str,
        memory_type: str = MemoryType.SHORT_TERM.value,
        importance: float = 0.5,
        source_room_id: Optional[str] = None,
        source_discussion_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
        expires_in_hours: Optional[int] = None
    ) -> AgentMemory:
        """
        Create a new memory.

        Args:
            agent_id: Agent ID
            content: Memory content
            memory_type: Memory type
            importance: Importance score
            source_room_id: Source chat room
            source_discussion_id: Source discussion
            tags: Tags
            expires_in_hours: Expiry time in hours

        Returns:
            The created AgentMemory document
        """
        memory_id = f"mem-{uuid.uuid4().hex[:12]}"

        # Generate the vector embedding
        embedding = await cls._generate_embedding(content)

        # Generate a short summary
        summary = content[:100] + "..." if len(content) > 100 else content

        # Compute the expiry time
        expires_at = None
        if expires_in_hours:
            expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)

        memory = AgentMemory(
            memory_id=memory_id,
            agent_id=agent_id,
            memory_type=memory_type,
            content=content,
            summary=summary,
            embedding=embedding,
            importance=importance,
            source_room_id=source_room_id,
            source_discussion_id=source_discussion_id,
            tags=tags or [],
            created_at=datetime.utcnow(),
            last_accessed=datetime.utcnow(),
            expires_at=expires_at
        )
        await memory.insert()
        logger.debug(f"Created memory: {memory_id} for Agent {agent_id}")
        return memory
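
    # Usage sketch (hypothetical example; assumes Beanie/MongoDB has been
    # initialised with the AgentMemory document and that an agent with id
    # "agent-001" exists -- the id and field values are illustrative only):
    #
    #   memory = await MemoryService.create_memory(
    #       agent_id="agent-001",
    #       content="The user prefers short, source-backed answers.",
    #       memory_type=MemoryType.LONG_TERM.value,
    #       importance=0.8,
    #       tags=["preference"],
    #   )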

    @classmethod
    async def get_memory(cls, memory_id: str) -> Optional[AgentMemory]:
        """
        Fetch a specific memory.

        Args:
            memory_id: Memory ID

        Returns:
            The AgentMemory document, or None if not found
        """
        return await AgentMemory.find_one(AgentMemory.memory_id == memory_id)

    @classmethod
    async def get_agent_memories(
        cls,
        agent_id: str,
        memory_type: Optional[str] = None,
        limit: int = 50
    ) -> List[AgentMemory]:
        """
        List an Agent's memories.

        Args:
            agent_id: Agent ID
            memory_type: Memory type (optional)
            limit: Maximum number of results

        Returns:
            List of memories, most important and most recently accessed first
        """
        query = {"agent_id": agent_id}
        if memory_type:
            query["memory_type"] = memory_type
        return await AgentMemory.find(query).sort(
            "-importance", "-last_accessed"
        ).limit(limit).to_list()

    @classmethod
    async def search_memories(
        cls,
        agent_id: str,
        query: str,
        limit: int = 10,
        memory_type: Optional[str] = None,
        min_relevance: float = 0.3
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant memories.

        Args:
            agent_id: Agent ID
            query: Query text
            limit: Maximum number of results
            memory_type: Memory type (optional)
            min_relevance: Minimum relevance threshold

        Returns:
            List of memories with relevance scores
        """
        # Generate the query embedding
        query_embedding = await cls._generate_embedding(query)
        if not query_embedding:
            # Fall back to plain text matching when no embedding is available
            return await cls._text_search(agent_id, query, limit, memory_type)

        # Fetch all of the Agent's memories
        filter_query = {"agent_id": agent_id}
        if memory_type:
            filter_query["memory_type"] = memory_type
        memories = await AgentMemory.find(filter_query).to_list()

        # Compute similarity scores
        results = []
        for memory in memories:
            if memory.is_expired():
                continue
            if memory.embedding:
                similarity = cls._cosine_similarity(query_embedding, memory.embedding)
                relevance = memory.calculate_relevance_score(similarity)
                if relevance >= min_relevance:
                    results.append({
                        "memory": memory,
                        "similarity": similarity,
                        "relevance": relevance
                    })

        # Sort by relevance
        results.sort(key=lambda x: x["relevance"], reverse=True)

        # Record the access on the returned memories
        for item in results[:limit]:
            memory = item["memory"]
            memory.access()
            await memory.save()

        return results[:limit]
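
    # Usage sketch (hypothetical example; the agent id and query are
    # illustrative only):
    #
    #   results = await MemoryService.search_memories(
    #       agent_id="agent-001",
    #       query="what did we decide about pricing?",
    #       limit=5,
    #   )
    #   for item in results:
    #       print(item["relevance"], item["memory"].summary)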

    @classmethod
    async def update_memory(
        cls,
        memory_id: str,
        **kwargs
    ) -> Optional[AgentMemory]:
        """
        Update a memory.

        Args:
            memory_id: Memory ID
            **kwargs: Fields to update

        Returns:
            The updated AgentMemory, or None if not found
        """
        memory = await cls.get_memory(memory_id)
        if not memory:
            return None

        # If the content changed, regenerate the embedding and summary
        if "content" in kwargs:
            kwargs["embedding"] = await cls._generate_embedding(kwargs["content"])
            kwargs["summary"] = kwargs["content"][:100] + "..." if len(kwargs["content"]) > 100 else kwargs["content"]

        for key, value in kwargs.items():
            if hasattr(memory, key):
                setattr(memory, key, value)
        await memory.save()
        return memory

    @classmethod
    async def delete_memory(cls, memory_id: str) -> bool:
        """
        Delete a memory.

        Args:
            memory_id: Memory ID

        Returns:
            True if the memory was deleted, False if it did not exist
        """
        memory = await cls.get_memory(memory_id)
        if not memory:
            return False
        await memory.delete()
        return True

    @classmethod
    async def delete_agent_memories(
        cls,
        agent_id: str,
        memory_type: Optional[str] = None
    ) -> int:
        """
        Delete an Agent's memories.

        Args:
            agent_id: Agent ID
            memory_type: Memory type (optional)

        Returns:
            Number of memories deleted
        """
        query = {"agent_id": agent_id}
        if memory_type:
            query["memory_type"] = memory_type
        result = await AgentMemory.find(query).delete()
        return result.deleted_count if result else 0

    @classmethod
    async def cleanup_expired_memories(cls) -> int:
        """
        Remove expired memories.

        Returns:
            Number of memories removed
        """
        now = datetime.utcnow()
        result = await AgentMemory.find(
            {"expires_at": {"$lt": now}}
        ).delete()
        count = result.deleted_count if result else 0
        if count > 0:
            logger.info(f"Cleaned up {count} expired memories")
        return count

    @classmethod
    async def consolidate_memories(
        cls,
        agent_id: str,
        min_importance: float = 0.7,
        max_age_days: int = 30
    ) -> None:
        """
        Consolidate memories: promote important short-term memories to long-term.

        Args:
            agent_id: Agent ID
            min_importance: Minimum importance threshold
            max_age_days: Age threshold in days; short-term memories older than this are promoted
        """
        cutoff_date = datetime.utcnow() - timedelta(days=max_age_days)

        # Find short-term memories that qualify for promotion
        memories = await AgentMemory.find({
            "agent_id": agent_id,
            "memory_type": MemoryType.SHORT_TERM.value,
            "importance": {"$gte": min_importance},
            "created_at": {"$lt": cutoff_date}
        }).to_list()

        for memory in memories:
            memory.memory_type = MemoryType.LONG_TERM.value
            memory.expires_at = None  # Long-term memories never expire
            await memory.save()

        if memories:
            logger.info(f"Consolidated {len(memories)} memories into long-term memory: Agent {agent_id}")

    @classmethod
    async def _generate_embedding(cls, text: str) -> List[float]:
        """
        Generate a vector embedding for the given text.

        Args:
            text: Text content

        Returns:
            The embedding as a list of floats (empty list on failure)
        """
        model = cls._get_embedding_model()
        if model is None:
            return []
        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.warning(f"Failed to generate embedding: {e}")
            return []

    @classmethod
    def _cosine_similarity(cls, vec1: List[float], vec2: List[float]) -> float:
        """
        Compute cosine similarity between two vectors.

        Args:
            vec1: First vector
            vec2: Second vector

        Returns:
            Similarity in the range 0-1
        """
        if not vec1 or not vec2:
            return 0.0
        try:
            a = np.array(vec1)
            b = np.array(vec2)
            # Guard against zero-length vectors, which would otherwise
            # produce a NaN from division by zero
            norm_product = np.linalg.norm(a) * np.linalg.norm(b)
            if norm_product == 0:
                return 0.0
            similarity = np.dot(a, b) / norm_product
            return float(max(0, similarity))
        except Exception:
            return 0.0
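
    # Worked example: for vec1 = [1, 0] and vec2 = [1, 1],
    # dot(a, b) = 1, |a| = 1, |b| = sqrt(2), so similarity = 1/sqrt(2) ≈ 0.707.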

    @classmethod
    async def _text_search(
        cls,
        agent_id: str,
        query: str,
        limit: int,
        memory_type: Optional[str]
    ) -> List[Dict[str, Any]]:
        """
        Fallback text search.

        Args:
            agent_id: Agent ID
            query: Query text
            limit: Maximum number of results
            memory_type: Memory type

        Returns:
            List of memories with match scores
        """
        filter_query = {"agent_id": agent_id}
        if memory_type:
            filter_query["memory_type"] = memory_type

        # Simple substring matching
        memories = await AgentMemory.find(filter_query).to_list()
        results = []
        query_lower = query.lower()
        for memory in memories:
            if memory.is_expired():
                continue
            content_lower = memory.content.lower()
            # Skip empty queries so they do not match everything (and to
            # avoid division by zero on empty content below)
            if query_lower and query_lower in content_lower:
                # Crude match score: query length relative to content length
                score = len(query_lower) / len(content_lower)
                results.append({
                    "memory": memory,
                    "similarity": score,
                    "relevance": score * memory.importance
                })
        results.sort(key=lambda x: x["relevance"], reverse=True)
        return results[:limit]