115 lines
3.3 KiB
Python
115 lines
3.3 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
数据库模型测试脚本
|
||
|
|
|
||
|
|
用于验证数据库模型的创建和基本功能
|
||
|
|
"""
|
||
|
|
|
||
|
|
import sys
|
||
|
|
import os
|
||
|
|
|
||
|
|
# 添加项目根目录到Python路径
|
||
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
from src.models import Record, RecordCategory, init_database, get_db
|
||
|
|
|
||
|
|
|
||
|
|
def test_database():
|
||
|
|
"""测试数据库基本功能"""
|
||
|
|
print("=" * 60)
|
||
|
|
print("数据库模型测试")
|
||
|
|
print("=" * 60)
|
||
|
|
|
||
|
|
# 1. 初始化数据库
|
||
|
|
print("\n[1] 初始化数据库...")
|
||
|
|
db_path = "sqlite:////home/congsh/CodeSpace/ClaudeSpace/CutThenThink/data/test.db"
|
||
|
|
db_manager = init_database(db_path)
|
||
|
|
print(f"✓ 数据库初始化成功: {db_path}")
|
||
|
|
|
||
|
|
# 2. 创建测试数据
|
||
|
|
print("\n[2] 创建测试记录...")
|
||
|
|
session = get_db()
|
||
|
|
|
||
|
|
test_record = Record(
|
||
|
|
image_path="/test/image1.png",
|
||
|
|
ocr_text="这是OCR识别的测试文本",
|
||
|
|
category=RecordCategory.NOTE,
|
||
|
|
ai_result="# AI生成的内容\n\n这是一条测试笔记。",
|
||
|
|
tags=["测试", "示例"],
|
||
|
|
notes="这是一条手动备注"
|
||
|
|
)
|
||
|
|
|
||
|
|
session.add(test_record)
|
||
|
|
session.commit()
|
||
|
|
print(f"✓ 记录创建成功: ID={test_record.id}")
|
||
|
|
|
||
|
|
# 3. 测试标签功能
|
||
|
|
print("\n[3] 测试标签功能...")
|
||
|
|
test_record.add_tag("新标签")
|
||
|
|
session.commit()
|
||
|
|
print(f"✓ 标签添加成功: {test_record.tags}")
|
||
|
|
|
||
|
|
# 4. 测试分类常量
|
||
|
|
print("\n[4] 测试分类常量...")
|
||
|
|
print(f"所有分类: {RecordCategory.all()}")
|
||
|
|
print(f"验证分类 'NOTE': {RecordCategory.is_valid('NOTE')}")
|
||
|
|
print(f"验证分类 'INVALID': {RecordCategory.is_valid('INVALID')}")
|
||
|
|
print("✓ 分类常量测试完成")
|
||
|
|
|
||
|
|
# 5. 测试查询功能
|
||
|
|
print("\n[5] 测试查询功能...")
|
||
|
|
records = session.query(Record).all()
|
||
|
|
print(f"✓ 查询到 {len(records)} 条记录")
|
||
|
|
|
||
|
|
for record in records:
|
||
|
|
print(f" - ID: {record.id}, 分类: {record.category}, 路径: {record.image_path}")
|
||
|
|
print(f" 字典格式: {record.to_dict()}")
|
||
|
|
|
||
|
|
# 6. 测试时间戳
|
||
|
|
print("\n[6] 测试时间戳...")
|
||
|
|
print(f"创建时间: {test_record.created_at}")
|
||
|
|
print(f"更新时间: {test_record.updated_at}")
|
||
|
|
|
||
|
|
# 清理
|
||
|
|
session.close()
|
||
|
|
print("\n" + "=" * 60)
|
||
|
|
print("✓ 所有测试完成")
|
||
|
|
print("=" * 60)
|
||
|
|
|
||
|
|
|
||
|
|
def test_all_categories():
|
||
|
|
"""测试所有分类类型"""
|
||
|
|
print("\n[额外测试] 创建不同分类的记录...")
|
||
|
|
session = get_db()
|
||
|
|
|
||
|
|
test_data = [
|
||
|
|
(RecordCategory.TODO, "/test/todo.png", "待办事项图片"),
|
||
|
|
(RecordCategory.IDEA, "/test/idea.png", "灵感记录"),
|
||
|
|
(RecordCategory.REF, "/test/ref.png", "参考资料"),
|
||
|
|
(RecordCategory.FUNNY, "/test/funny.png", "搞笑图片"),
|
||
|
|
(RecordCategory.TEXT, "/test/text.png", "纯文本"),
|
||
|
|
]
|
||
|
|
|
||
|
|
for category, path, notes in test_data:
|
||
|
|
record = Record(
|
||
|
|
image_path=path,
|
||
|
|
category=category,
|
||
|
|
notes=notes
|
||
|
|
)
|
||
|
|
session.add(record)
|
||
|
|
|
||
|
|
session.commit()
|
||
|
|
print(f"✓ 成功创建 {len(test_data)} 条不同分类的记录")
|
||
|
|
|
||
|
|
# 按分类查询
|
||
|
|
for category in RecordCategory.all():
|
||
|
|
count = session.query(Record).filter_by(category=category).count()
|
||
|
|
print(f" {category}: {count} 条")
|
||
|
|
|
||
|
|
session.close()
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
test_database()
|
||
|
|
test_all_categories()
|