""" AI 模块测试脚本 测试各个 AI 提供商的文本分类功能 """ import sys import os # 添加项目根目录到 Python 路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.core.ai import ( CategoryType, ClassificationResult, AIClassifier, classify_text, AIError, ) def test_classification_result(): """测试分类结果数据结构""" print("=== 测试分类结果数据结构 ===") result = ClassificationResult( category=CategoryType.TODO, confidence=0.95, title="测试任务", content="- [ ] 完成测试", tags=["测试", "任务"], reasoning="包含待办事项关键词" ) print(f"分类: {result.category}") print(f"置信度: {result.confidence}") print(f"标题: {result.title}") print(f"内容: {result.content}") print(f"标签: {result.tags}") print(f"理由: {result.reasoning}") # 测试转换为字典 result_dict = result.to_dict() print(f"\n转换为字典: {result_dict}") # 测试从字典创建 result2 = ClassificationResult.from_dict(result_dict) print(f"\n从字典恢复: 分类={result2.category}, 标题={result2.title}") print("✅ 分类结果数据结构测试通过\n") def test_category_types(): """测试分类类型枚举""" print("=== 测试分类类型枚举 ===") print("所有分类类型:") for category in CategoryType: print(f" - {category.name}: {category.value}") print(f"\n所有分类值: {CategoryType.all()}") # 测试验证 assert CategoryType.is_valid("TODO") == True assert CategoryType.is_valid("NOTE") == True assert CategoryType.is_valid("INVALID") == False print("✅ 分类类型枚举测试通过\n") def test_ai_classifier_creation(): """测试 AI 分类器创建""" print("=== 测试 AI 分类器创建 ===") providers = ["openai", "anthropic", "qwen", "ollama"] for provider in providers: try: client = AIClassifier.create_client( provider=provider, api_key="test_key", model="test_model" ) print(f"✓ {provider} 客户端创建成功: {type(client).__name__}") except Exception as e: print(f"✗ {provider} 客户端创建失败: {e}") print("✅ AI 分类器创建测试通过\n") def test_mock_classification(): """模拟分类测试(不实际调用 API)""" print("=== 模拟分类测试 ===") # 测试文本样本 test_cases = [ { "text": "今天要完成的任务:\n1. 完成项目文档\n2. 修复 Bug #123\n3. 参加团队会议", "expected": CategoryType.TODO, "description": "待办事项" }, { "text": "Python 中的列表推导式是一种简洁的语法糖。\n\n示例:\n[x * 2 for x in range(10)]", "expected": CategoryType.NOTE, "description": "编程笔记" }, { "text": "突然想到一个产品创意:做一个能自动识别截图分类的工具!\n可以使用 AI + OCR 实现。", "expected": CategoryType.IDEA, "description": "产品灵感" }, { "text": "API 文档参考:\nGET /api/users\n获取用户列表\n\n参数:\n- page: 页码\n- limit: 每页数量", "expected": CategoryType.REF, "description": "参考资料" }, { "text": "程序员最讨厌的四件事:\n1. 写注释\n2. 写文档\n3. 别人不写注释\n4. 别人不写文档", "expected": CategoryType.FUNNY, "description": "搞笑段子" }, ] print("测试用例:") for i, case in enumerate(test_cases, 1): print(f"\n{i}. {case['description']}") print(f" 预期分类: {case['expected'].value}") print(f" 文本预览: {case['text'][:50]}...") print("\n注意:实际分类需要配置 API key 并调用 AI 服务") def test_error_handling(): """测试错误处理""" print("=== 测试错误处理 ===") # 测试不支持的提供商 try: AIClassifier.create_client( provider="unsupported_provider", api_key="test", model="test-model" ) print("✗ 应该抛出异常但没有") except AIError as e: print(f"✓ 正确捕获不支持的提供商错误: {e}") print("✅ 错误处理测试通过\n") def print_usage_examples(): """打印使用示例""" print("=== 使用示例 ===\n") print("1. 使用 OpenAI 进行分类:") print(""" from src.core.ai import AIClassifier client = AIClassifier.create_client( provider="openai", api_key="your-api-key", model="gpt-4o-mini" ) result = client.classify("待分析的文本") print(result.category) # TODO print(result.content) # Markdown 格式内容 """) print("\n2. 使用配置文件进行分类:") print(""" from src.config.settings import get_settings from src.core.ai import classify_text settings = get_settings() result = classify_text("待分析的文本", settings.ai) """) print("\n3. 使用 Claude 进行分类:") print(""" client = AIClassifier.create_client( provider="anthropic", api_key="your-api-key", model="claude-3-5-sonnet-20241022" ) result = client.classify("待分析的文本") """) print("\n4. 使用本地 Ollama:") print(""" client = AIClassifier.create_client( provider="ollama", api_key="", # Ollama 不需要 API key model="llama3.2" ) result = client.classify("待分析的文本") """) def main(): """主测试函数""" print("=" * 60) print("AI 模块测试") print("=" * 60) print() # 运行所有测试 test_classification_result() test_category_types() test_ai_classifier_creation() test_mock_classification() test_error_handling() # 打印使用示例 print_usage_examples() print("\n" + "=" * 60) print("所有测试完成!") print("=" * 60) if __name__ == "__main__": main()