Files
cutThenThink/src/models/database.py

197 lines
5.3 KiB
Python
Raw Normal View History

"""
数据库模型定义
使用 SQLAlchemy ORM 定义 Record 模型
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
Base = declarative_base()
class RecordCategory:
"""记录分类常量"""
TODO = "TODO" # 待办事项
NOTE = "NOTE" # 笔记
IDEA = "IDEA" # 灵感
REF = "REF" # 参考资料
FUNNY = "FUNNY" # 搞笑文案
TEXT = "TEXT" # 纯文本
@classmethod
def all(cls):
"""获取所有分类类型"""
return [cls.TODO, cls.NOTE, cls.IDEA, cls.REF, cls.FUNNY, cls.TEXT]
@classmethod
def is_valid(cls, category: str) -> bool:
"""验证分类是否有效"""
return category in cls.all()
class Record(Base):
"""记录模型 - 存储图片识别和AI处理结果"""
__tablename__ = 'records'
# 主键
id = Column(Integer, primary_key=True, autoincrement=True, comment='记录ID')
# 图片路径
image_path = Column(String(512), nullable=False, unique=True, index=True, comment='图片存储路径')
# OCR识别结果
ocr_text = Column(Text, nullable=True, comment='OCR识别的文本内容')
# 分类类型
category = Column(
String(20),
nullable=False,
default=RecordCategory.NOTE,
index=True,
comment='记录分类'
)
# AI生成的Markdown内容
ai_result = Column(Text, nullable=True, comment='AI处理生成的Markdown内容')
# 标签JSON格式存储
tags = Column(JSON, nullable=True, comment='标签列表')
# 用户备注
notes = Column(Text, nullable=True, comment='用户手动添加的备注')
# 时间戳
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, comment='创建时间')
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False, comment='更新时间')
def __repr__(self):
return f"<Record(id={self.id}, category='{self.category}', image_path='{self.image_path}')>"
def to_dict(self):
"""转换为字典格式"""
return {
'id': self.id,
'image_path': self.image_path,
'ocr_text': self.ocr_text,
'category': self.category,
'ai_result': self.ai_result,
'tags': self.tags or [],
'notes': self.notes,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
}
def update_tags(self, tags: list):
"""更新标签"""
self.tags = tags
def add_tag(self, tag: str):
"""添加单个标签"""
if self.tags is None:
self.tags = []
if tag not in self.tags:
self.tags.append(tag)
# 数据库连接管理
class DatabaseManager:
"""数据库管理器 - 负责连接和会话管理"""
def __init__(self, db_path: str = "sqlite:///cutnthink.db"):
"""
初始化数据库管理器
Args:
db_path: 数据库连接路径默认使用SQLite
"""
self.db_path = db_path
self.engine = None
self.SessionLocal = None
def init_db(self, db_path: Optional[str] = None):
"""
初始化数据库连接和表结构
Args:
db_path: 可选的数据库路径如果提供则覆盖初始化时的路径
"""
if db_path:
self.db_path = db_path
# 创建数据库引擎
self.engine = create_engine(
self.db_path,
echo=False, # 不输出SQL日志
connect_args={"check_same_thread": False} # SQLite特定配置
)
# 创建会话工厂
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=self.engine
)
# 创建所有表
Base.metadata.create_all(bind=self.engine)
def get_session(self):
"""
获取数据库会话
Returns:
SQLAlchemy Session对象
"""
if self.SessionLocal is None:
raise RuntimeError("数据库未初始化,请先调用 init_db() 方法")
return self.SessionLocal()
def close(self):
"""关闭数据库连接"""
if self.engine:
self.engine.dispose()
self.engine = None
self.SessionLocal = None
# 全局数据库管理器实例
db_manager = DatabaseManager()
def init_database(db_path: str = "sqlite:////home/congsh/CodeSpace/ClaudeSpace/CutThenThink/data/cutnthink.db"):
"""
初始化数据库的便捷函数
Args:
db_path: 数据库文件路径
Returns:
DatabaseManager实例
"""
db_manager.init_db(db_path)
return db_manager
def get_db():
"""
获取数据库会话的便捷函数
Returns:
SQLAlchemy Session对象
Example:
>>> session = get_db()
>>> try:
... # 使用session进行数据库操作
... records = session.query(Record).all()
... finally:
... session.close()
"""
return db_manager.get_session()