197 lines
5.3 KiB
Python
197 lines
5.3 KiB
Python
|
|
"""
|
|||
|
|
数据库模型定义
|
|||
|
|
|
|||
|
|
使用 SQLAlchemy ORM 定义 Record 模型
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Optional
|
|||
|
|
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON
|
|||
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|||
|
|
from sqlalchemy import create_engine
|
|||
|
|
from sqlalchemy.orm import sessionmaker
|
|||
|
|
|
|||
|
|
Base = declarative_base()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RecordCategory:
|
|||
|
|
"""记录分类常量"""
|
|||
|
|
TODO = "TODO" # 待办事项
|
|||
|
|
NOTE = "NOTE" # 笔记
|
|||
|
|
IDEA = "IDEA" # 灵感
|
|||
|
|
REF = "REF" # 参考资料
|
|||
|
|
FUNNY = "FUNNY" # 搞笑文案
|
|||
|
|
TEXT = "TEXT" # 纯文本
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def all(cls):
|
|||
|
|
"""获取所有分类类型"""
|
|||
|
|
return [cls.TODO, cls.NOTE, cls.IDEA, cls.REF, cls.FUNNY, cls.TEXT]
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def is_valid(cls, category: str) -> bool:
|
|||
|
|
"""验证分类是否有效"""
|
|||
|
|
return category in cls.all()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Record(Base):
|
|||
|
|
"""记录模型 - 存储图片识别和AI处理结果"""
|
|||
|
|
|
|||
|
|
__tablename__ = 'records'
|
|||
|
|
|
|||
|
|
# 主键
|
|||
|
|
id = Column(Integer, primary_key=True, autoincrement=True, comment='记录ID')
|
|||
|
|
|
|||
|
|
# 图片路径
|
|||
|
|
image_path = Column(String(512), nullable=False, unique=True, index=True, comment='图片存储路径')
|
|||
|
|
|
|||
|
|
# OCR识别结果
|
|||
|
|
ocr_text = Column(Text, nullable=True, comment='OCR识别的文本内容')
|
|||
|
|
|
|||
|
|
# 分类类型
|
|||
|
|
category = Column(
|
|||
|
|
String(20),
|
|||
|
|
nullable=False,
|
|||
|
|
default=RecordCategory.NOTE,
|
|||
|
|
index=True,
|
|||
|
|
comment='记录分类'
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# AI生成的Markdown内容
|
|||
|
|
ai_result = Column(Text, nullable=True, comment='AI处理生成的Markdown内容')
|
|||
|
|
|
|||
|
|
# 标签(JSON格式存储)
|
|||
|
|
tags = Column(JSON, nullable=True, comment='标签列表')
|
|||
|
|
|
|||
|
|
# 用户备注
|
|||
|
|
notes = Column(Text, nullable=True, comment='用户手动添加的备注')
|
|||
|
|
|
|||
|
|
# 时间戳
|
|||
|
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, comment='创建时间')
|
|||
|
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False, comment='更新时间')
|
|||
|
|
|
|||
|
|
def __repr__(self):
|
|||
|
|
return f"<Record(id={self.id}, category='{self.category}', image_path='{self.image_path}')>"
|
|||
|
|
|
|||
|
|
def to_dict(self):
|
|||
|
|
"""转换为字典格式"""
|
|||
|
|
return {
|
|||
|
|
'id': self.id,
|
|||
|
|
'image_path': self.image_path,
|
|||
|
|
'ocr_text': self.ocr_text,
|
|||
|
|
'category': self.category,
|
|||
|
|
'ai_result': self.ai_result,
|
|||
|
|
'tags': self.tags or [],
|
|||
|
|
'notes': self.notes,
|
|||
|
|
'created_at': self.created_at.isoformat() if self.created_at else None,
|
|||
|
|
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def update_tags(self, tags: list):
|
|||
|
|
"""更新标签"""
|
|||
|
|
self.tags = tags
|
|||
|
|
|
|||
|
|
def add_tag(self, tag: str):
|
|||
|
|
"""添加单个标签"""
|
|||
|
|
if self.tags is None:
|
|||
|
|
self.tags = []
|
|||
|
|
if tag not in self.tags:
|
|||
|
|
self.tags.append(tag)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 数据库连接管理
|
|||
|
|
class DatabaseManager:
|
|||
|
|
"""数据库管理器 - 负责连接和会话管理"""
|
|||
|
|
|
|||
|
|
def __init__(self, db_path: str = "sqlite:///cutnthink.db"):
|
|||
|
|
"""
|
|||
|
|
初始化数据库管理器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
db_path: 数据库连接路径,默认使用SQLite
|
|||
|
|
"""
|
|||
|
|
self.db_path = db_path
|
|||
|
|
self.engine = None
|
|||
|
|
self.SessionLocal = None
|
|||
|
|
|
|||
|
|
def init_db(self, db_path: Optional[str] = None):
|
|||
|
|
"""
|
|||
|
|
初始化数据库连接和表结构
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
db_path: 可选的数据库路径,如果提供则覆盖初始化时的路径
|
|||
|
|
"""
|
|||
|
|
if db_path:
|
|||
|
|
self.db_path = db_path
|
|||
|
|
|
|||
|
|
# 创建数据库引擎
|
|||
|
|
self.engine = create_engine(
|
|||
|
|
self.db_path,
|
|||
|
|
echo=False, # 不输出SQL日志
|
|||
|
|
connect_args={"check_same_thread": False} # SQLite特定配置
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 创建会话工厂
|
|||
|
|
self.SessionLocal = sessionmaker(
|
|||
|
|
autocommit=False,
|
|||
|
|
autoflush=False,
|
|||
|
|
bind=self.engine
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 创建所有表
|
|||
|
|
Base.metadata.create_all(bind=self.engine)
|
|||
|
|
|
|||
|
|
def get_session(self):
|
|||
|
|
"""
|
|||
|
|
获取数据库会话
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
SQLAlchemy Session对象
|
|||
|
|
"""
|
|||
|
|
if self.SessionLocal is None:
|
|||
|
|
raise RuntimeError("数据库未初始化,请先调用 init_db() 方法")
|
|||
|
|
return self.SessionLocal()
|
|||
|
|
|
|||
|
|
def close(self):
|
|||
|
|
"""关闭数据库连接"""
|
|||
|
|
if self.engine:
|
|||
|
|
self.engine.dispose()
|
|||
|
|
self.engine = None
|
|||
|
|
self.SessionLocal = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局数据库管理器实例
|
|||
|
|
db_manager = DatabaseManager()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def init_database(db_path: str = "sqlite:////home/congsh/CodeSpace/ClaudeSpace/CutThenThink/data/cutnthink.db"):
|
|||
|
|
"""
|
|||
|
|
初始化数据库的便捷函数
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
db_path: 数据库文件路径
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
DatabaseManager实例
|
|||
|
|
"""
|
|||
|
|
db_manager.init_db(db_path)
|
|||
|
|
return db_manager
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_db():
|
|||
|
|
"""
|
|||
|
|
获取数据库会话的便捷函数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
SQLAlchemy Session对象
|
|||
|
|
|
|||
|
|
Example:
|
|||
|
|
>>> session = get_db()
|
|||
|
|
>>> try:
|
|||
|
|
... # 使用session进行数据库操作
|
|||
|
|
... records = session.query(Record).all()
|
|||
|
|
... finally:
|
|||
|
|
... session.close()
|
|||
|
|
"""
|
|||
|
|
return db_manager.get_session()
|