项目初始化 - 创建完整项目结构(src/, data/, docs/, examples/, tests/) - 配置requirements.txt依赖 - 创建.gitignore P0基础框架 - 数据库模型:Record模型,6种分类类型 - 配置管理:YAML配置,支持AI/OCR/云存储/UI配置 - OCR模块:PaddleOCR本地识别,支持云端扩展 - AI模块:支持OpenAI/Claude/通义/Ollama,6种分类 - 存储模块:完整CRUD,搜索,统计,导入导出 - 主窗口框架:侧边栏导航,米白配色方案 - 图片处理:截图/剪贴板/文件选择/图片预览 - 处理流程整合:OCR→AI→存储串联,Markdown展示,剪贴板复制 - 分类浏览:卡片网格展示,分类筛选,搜索,详情查看 技术栈 - PyQt6 + SQLAlchemy + PaddleOCR + OpenAI/Claude SDK - 共47个Python文件,4000+行代码 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
226 lines
6.2 KiB
Python
226 lines
6.2 KiB
Python
"""
|
||
AI 模块测试脚本
|
||
|
||
测试各个 AI 提供商的文本分类功能
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目根目录到 Python 路径
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from src.core.ai import (
|
||
CategoryType,
|
||
ClassificationResult,
|
||
AIClassifier,
|
||
classify_text,
|
||
AIError,
|
||
)
|
||
|
||
|
||
def test_classification_result():
|
||
"""测试分类结果数据结构"""
|
||
print("=== 测试分类结果数据结构 ===")
|
||
|
||
result = ClassificationResult(
|
||
category=CategoryType.TODO,
|
||
confidence=0.95,
|
||
title="测试任务",
|
||
content="- [ ] 完成测试",
|
||
tags=["测试", "任务"],
|
||
reasoning="包含待办事项关键词"
|
||
)
|
||
|
||
print(f"分类: {result.category}")
|
||
print(f"置信度: {result.confidence}")
|
||
print(f"标题: {result.title}")
|
||
print(f"内容: {result.content}")
|
||
print(f"标签: {result.tags}")
|
||
print(f"理由: {result.reasoning}")
|
||
|
||
# 测试转换为字典
|
||
result_dict = result.to_dict()
|
||
print(f"\n转换为字典: {result_dict}")
|
||
|
||
# 测试从字典创建
|
||
result2 = ClassificationResult.from_dict(result_dict)
|
||
print(f"\n从字典恢复: 分类={result2.category}, 标题={result2.title}")
|
||
|
||
print("✅ 分类结果数据结构测试通过\n")
|
||
|
||
|
||
def test_category_types():
|
||
"""测试分类类型枚举"""
|
||
print("=== 测试分类类型枚举 ===")
|
||
|
||
print("所有分类类型:")
|
||
for category in CategoryType:
|
||
print(f" - {category.name}: {category.value}")
|
||
|
||
print(f"\n所有分类值: {CategoryType.all()}")
|
||
|
||
# 测试验证
|
||
assert CategoryType.is_valid("TODO") == True
|
||
assert CategoryType.is_valid("NOTE") == True
|
||
assert CategoryType.is_valid("INVALID") == False
|
||
|
||
print("✅ 分类类型枚举测试通过\n")
|
||
|
||
|
||
def test_ai_classifier_creation():
|
||
"""测试 AI 分类器创建"""
|
||
print("=== 测试 AI 分类器创建 ===")
|
||
|
||
providers = ["openai", "anthropic", "qwen", "ollama"]
|
||
|
||
for provider in providers:
|
||
try:
|
||
client = AIClassifier.create_client(
|
||
provider=provider,
|
||
api_key="test_key",
|
||
model="test_model"
|
||
)
|
||
print(f"✓ {provider} 客户端创建成功: {type(client).__name__}")
|
||
except Exception as e:
|
||
print(f"✗ {provider} 客户端创建失败: {e}")
|
||
|
||
print("✅ AI 分类器创建测试通过\n")
|
||
|
||
|
||
def test_mock_classification():
|
||
"""模拟分类测试(不实际调用 API)"""
|
||
print("=== 模拟分类测试 ===")
|
||
|
||
# 测试文本样本
|
||
test_cases = [
|
||
{
|
||
"text": "今天要完成的任务:\n1. 完成项目文档\n2. 修复 Bug #123\n3. 参加团队会议",
|
||
"expected": CategoryType.TODO,
|
||
"description": "待办事项"
|
||
},
|
||
{
|
||
"text": "Python 中的列表推导式是一种简洁的语法糖。\n\n示例:\n[x * 2 for x in range(10)]",
|
||
"expected": CategoryType.NOTE,
|
||
"description": "编程笔记"
|
||
},
|
||
{
|
||
"text": "突然想到一个产品创意:做一个能自动识别截图分类的工具!\n可以使用 AI + OCR 实现。",
|
||
"expected": CategoryType.IDEA,
|
||
"description": "产品灵感"
|
||
},
|
||
{
|
||
"text": "API 文档参考:\nGET /api/users\n获取用户列表\n\n参数:\n- page: 页码\n- limit: 每页数量",
|
||
"expected": CategoryType.REF,
|
||
"description": "参考资料"
|
||
},
|
||
{
|
||
"text": "程序员最讨厌的四件事:\n1. 写注释\n2. 写文档\n3. 别人不写注释\n4. 别人不写文档",
|
||
"expected": CategoryType.FUNNY,
|
||
"description": "搞笑段子"
|
||
},
|
||
]
|
||
|
||
print("测试用例:")
|
||
for i, case in enumerate(test_cases, 1):
|
||
print(f"\n{i}. {case['description']}")
|
||
print(f" 预期分类: {case['expected'].value}")
|
||
print(f" 文本预览: {case['text'][:50]}...")
|
||
|
||
print("\n注意:实际分类需要配置 API key 并调用 AI 服务")
|
||
|
||
|
||
def test_error_handling():
|
||
"""测试错误处理"""
|
||
print("=== 测试错误处理 ===")
|
||
|
||
# 测试不支持的提供商
|
||
try:
|
||
AIClassifier.create_client(
|
||
provider="unsupported_provider",
|
||
api_key="test",
|
||
model="test-model"
|
||
)
|
||
print("✗ 应该抛出异常但没有")
|
||
except AIError as e:
|
||
print(f"✓ 正确捕获不支持的提供商错误: {e}")
|
||
|
||
print("✅ 错误处理测试通过\n")
|
||
|
||
|
||
def print_usage_examples():
|
||
"""打印使用示例"""
|
||
print("=== 使用示例 ===\n")
|
||
|
||
print("1. 使用 OpenAI 进行分类:")
|
||
print("""
|
||
from src.core.ai import AIClassifier
|
||
|
||
client = AIClassifier.create_client(
|
||
provider="openai",
|
||
api_key="your-api-key",
|
||
model="gpt-4o-mini"
|
||
)
|
||
|
||
result = client.classify("待分析的文本")
|
||
print(result.category) # TODO
|
||
print(result.content) # Markdown 格式内容
|
||
""")
|
||
|
||
print("\n2. 使用配置文件进行分类:")
|
||
print("""
|
||
from src.config.settings import get_settings
|
||
from src.core.ai import classify_text
|
||
|
||
settings = get_settings()
|
||
result = classify_text("待分析的文本", settings.ai)
|
||
""")
|
||
|
||
print("\n3. 使用 Claude 进行分类:")
|
||
print("""
|
||
client = AIClassifier.create_client(
|
||
provider="anthropic",
|
||
api_key="your-api-key",
|
||
model="claude-3-5-sonnet-20241022"
|
||
)
|
||
|
||
result = client.classify("待分析的文本")
|
||
""")
|
||
|
||
print("\n4. 使用本地 Ollama:")
|
||
print("""
|
||
client = AIClassifier.create_client(
|
||
provider="ollama",
|
||
api_key="", # Ollama 不需要 API key
|
||
model="llama3.2"
|
||
)
|
||
|
||
result = client.classify("待分析的文本")
|
||
""")
|
||
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
print("=" * 60)
|
||
print("AI 模块测试")
|
||
print("=" * 60)
|
||
print()
|
||
|
||
# 运行所有测试
|
||
test_classification_result()
|
||
test_category_types()
|
||
test_ai_classifier_creation()
|
||
test_mock_classification()
|
||
test_error_handling()
|
||
|
||
# 打印使用示例
|
||
print_usage_examples()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("所有测试完成!")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|