#!/usr/bin/env python3 """ 数据库模型测试脚本 用于验证数据库模型的创建和基本功能 """ import sys import os # 添加项目根目录到Python路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.models import Record, RecordCategory, init_database, get_db def test_database(): """测试数据库基本功能""" print("=" * 60) print("数据库模型测试") print("=" * 60) # 1. 初始化数据库 print("\n[1] 初始化数据库...") db_path = "sqlite:////home/congsh/CodeSpace/ClaudeSpace/CutThenThink/data/test.db" db_manager = init_database(db_path) print(f"✓ 数据库初始化成功: {db_path}") # 2. 创建测试数据 print("\n[2] 创建测试记录...") session = get_db() test_record = Record( image_path="/test/image1.png", ocr_text="这是OCR识别的测试文本", category=RecordCategory.NOTE, ai_result="# AI生成的内容\n\n这是一条测试笔记。", tags=["测试", "示例"], notes="这是一条手动备注" ) session.add(test_record) session.commit() print(f"✓ 记录创建成功: ID={test_record.id}") # 3. 测试标签功能 print("\n[3] 测试标签功能...") test_record.add_tag("新标签") session.commit() print(f"✓ 标签添加成功: {test_record.tags}") # 4. 测试分类常量 print("\n[4] 测试分类常量...") print(f"所有分类: {RecordCategory.all()}") print(f"验证分类 'NOTE': {RecordCategory.is_valid('NOTE')}") print(f"验证分类 'INVALID': {RecordCategory.is_valid('INVALID')}") print("✓ 分类常量测试完成") # 5. 测试查询功能 print("\n[5] 测试查询功能...") records = session.query(Record).all() print(f"✓ 查询到 {len(records)} 条记录") for record in records: print(f" - ID: {record.id}, 分类: {record.category}, 路径: {record.image_path}") print(f" 字典格式: {record.to_dict()}") # 6. 测试时间戳 print("\n[6] 测试时间戳...") print(f"创建时间: {test_record.created_at}") print(f"更新时间: {test_record.updated_at}") # 清理 session.close() print("\n" + "=" * 60) print("✓ 所有测试完成") print("=" * 60) def test_all_categories(): """测试所有分类类型""" print("\n[额外测试] 创建不同分类的记录...") session = get_db() test_data = [ (RecordCategory.TODO, "/test/todo.png", "待办事项图片"), (RecordCategory.IDEA, "/test/idea.png", "灵感记录"), (RecordCategory.REF, "/test/ref.png", "参考资料"), (RecordCategory.FUNNY, "/test/funny.png", "搞笑图片"), (RecordCategory.TEXT, "/test/text.png", "纯文本"), ] for category, path, notes in test_data: record = Record( image_path=path, category=category, notes=notes ) session.add(record) session.commit() print(f"✓ 成功创建 {len(test_data)} 条不同分类的记录") # 按分类查询 for category in RecordCategory.all(): count = session.query(Record).filter_by(category=category).count() print(f" {category}: {count} 条") session.close() if __name__ == "__main__": test_database() test_all_categories()