124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
|
|
"""
|
|||
|
|
Agent记忆数据模型
|
|||
|
|
定义Agent的记忆存储结构
|
|||
|
|
"""
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Optional, List
|
|||
|
|
from enum import Enum
|
|||
|
|
from pydantic import Field
|
|||
|
|
from beanie import Document
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MemoryType(str, Enum):
|
|||
|
|
"""记忆类型枚举"""
|
|||
|
|
SHORT_TERM = "short_term" # 短期记忆(会话内)
|
|||
|
|
LONG_TERM = "long_term" # 长期记忆(跨会话)
|
|||
|
|
EPISODIC = "episodic" # 情景记忆(特定事件)
|
|||
|
|
SEMANTIC = "semantic" # 语义记忆(知识性)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentMemory(Document):
|
|||
|
|
"""
|
|||
|
|
Agent记忆文档模型
|
|||
|
|
存储Agent的记忆内容
|
|||
|
|
"""
|
|||
|
|
memory_id: str = Field(..., description="唯一标识")
|
|||
|
|
agent_id: str = Field(..., description="Agent ID")
|
|||
|
|
|
|||
|
|
# 记忆内容
|
|||
|
|
memory_type: str = Field(
|
|||
|
|
default=MemoryType.SHORT_TERM.value,
|
|||
|
|
description="记忆类型"
|
|||
|
|
)
|
|||
|
|
content: str = Field(..., description="记忆内容")
|
|||
|
|
summary: str = Field(default="", description="内容摘要")
|
|||
|
|
|
|||
|
|
# 向量嵌入(用于相似度检索)
|
|||
|
|
embedding: List[float] = Field(default_factory=list, description="向量嵌入")
|
|||
|
|
|
|||
|
|
# 元数据
|
|||
|
|
importance: float = Field(default=0.5, ge=0, le=1, description="重要性评分")
|
|||
|
|
access_count: int = Field(default=0, description="访问次数")
|
|||
|
|
|
|||
|
|
# 关联信息
|
|||
|
|
source_room_id: Optional[str] = Field(default=None, description="来源聊天室ID")
|
|||
|
|
source_discussion_id: Optional[str] = Field(default=None, description="来源讨论ID")
|
|||
|
|
related_agents: List[str] = Field(default_factory=list, description="相关Agent列表")
|
|||
|
|
tags: List[str] = Field(default_factory=list, description="标签")
|
|||
|
|
|
|||
|
|
# 时间戳
|
|||
|
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
|||
|
|
last_accessed: datetime = Field(default_factory=datetime.utcnow)
|
|||
|
|
expires_at: Optional[datetime] = Field(default=None, description="过期时间")
|
|||
|
|
|
|||
|
|
class Settings:
|
|||
|
|
name = "agent_memories"
|
|||
|
|
indexes = [
|
|||
|
|
[("agent_id", 1)],
|
|||
|
|
[("memory_type", 1)],
|
|||
|
|
[("importance", -1)],
|
|||
|
|
[("last_accessed", -1)],
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
def access(self) -> None:
|
|||
|
|
"""
|
|||
|
|
记录访问,更新访问计数和时间
|
|||
|
|
"""
|
|||
|
|
self.access_count += 1
|
|||
|
|
self.last_accessed = datetime.utcnow()
|
|||
|
|
|
|||
|
|
def is_expired(self) -> bool:
|
|||
|
|
"""
|
|||
|
|
检查记忆是否已过期
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否过期
|
|||
|
|
"""
|
|||
|
|
if self.expires_at is None:
|
|||
|
|
return False
|
|||
|
|
return datetime.utcnow() > self.expires_at
|
|||
|
|
|
|||
|
|
def calculate_relevance_score(
|
|||
|
|
self,
|
|||
|
|
similarity: float,
|
|||
|
|
time_decay_factor: float = 0.1
|
|||
|
|
) -> float:
|
|||
|
|
"""
|
|||
|
|
计算综合相关性分数
|
|||
|
|
结合向量相似度、重要性和时间衰减
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
similarity: 向量相似度 (0-1)
|
|||
|
|
time_decay_factor: 时间衰减因子
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
综合相关性分数
|
|||
|
|
"""
|
|||
|
|
# 计算时间衰减
|
|||
|
|
hours_since_access = (datetime.utcnow() - self.last_accessed).total_seconds() / 3600
|
|||
|
|
time_decay = 1.0 / (1.0 + time_decay_factor * hours_since_access)
|
|||
|
|
|
|||
|
|
# 综合评分
|
|||
|
|
score = (
|
|||
|
|
0.5 * similarity +
|
|||
|
|
0.3 * self.importance +
|
|||
|
|
0.2 * time_decay
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return min(1.0, max(0.0, score))
|
|||
|
|
|
|||
|
|
class Config:
|
|||
|
|
json_schema_extra = {
|
|||
|
|
"example": {
|
|||
|
|
"memory_id": "mem-001",
|
|||
|
|
"agent_id": "product-manager",
|
|||
|
|
"memory_type": "long_term",
|
|||
|
|
"content": "在登录系统设计讨论中,团队决定采用OAuth2.0方案",
|
|||
|
|
"summary": "登录系统采用OAuth2.0",
|
|||
|
|
"importance": 0.8,
|
|||
|
|
"access_count": 5,
|
|||
|
|
"source_room_id": "product-design-room",
|
|||
|
|
"tags": ["登录", "OAuth", "认证"]
|
|||
|
|
}
|
|||
|
|
}
|