Files
cutThenThink/tests/test_ai.py

226 lines
6.2 KiB
Python
Raw Normal View History

"""
AI 模块测试脚本
测试各个 AI 提供商的文本分类功能
"""
import sys
import os
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.core.ai import (
CategoryType,
ClassificationResult,
AIClassifier,
classify_text,
AIError,
)
def test_classification_result():
"""测试分类结果数据结构"""
print("=== 测试分类结果数据结构 ===")
result = ClassificationResult(
category=CategoryType.TODO,
confidence=0.95,
title="测试任务",
content="- [ ] 完成测试",
tags=["测试", "任务"],
reasoning="包含待办事项关键词"
)
print(f"分类: {result.category}")
print(f"置信度: {result.confidence}")
print(f"标题: {result.title}")
print(f"内容: {result.content}")
print(f"标签: {result.tags}")
print(f"理由: {result.reasoning}")
# 测试转换为字典
result_dict = result.to_dict()
print(f"\n转换为字典: {result_dict}")
# 测试从字典创建
result2 = ClassificationResult.from_dict(result_dict)
print(f"\n从字典恢复: 分类={result2.category}, 标题={result2.title}")
print("✅ 分类结果数据结构测试通过\n")
def test_category_types():
"""测试分类类型枚举"""
print("=== 测试分类类型枚举 ===")
print("所有分类类型:")
for category in CategoryType:
print(f" - {category.name}: {category.value}")
print(f"\n所有分类值: {CategoryType.all()}")
# 测试验证
assert CategoryType.is_valid("TODO") == True
assert CategoryType.is_valid("NOTE") == True
assert CategoryType.is_valid("INVALID") == False
print("✅ 分类类型枚举测试通过\n")
def test_ai_classifier_creation():
"""测试 AI 分类器创建"""
print("=== 测试 AI 分类器创建 ===")
providers = ["openai", "anthropic", "qwen", "ollama"]
for provider in providers:
try:
client = AIClassifier.create_client(
provider=provider,
api_key="test_key",
model="test_model"
)
print(f"{provider} 客户端创建成功: {type(client).__name__}")
except Exception as e:
print(f"{provider} 客户端创建失败: {e}")
print("✅ AI 分类器创建测试通过\n")
def test_mock_classification():
"""模拟分类测试(不实际调用 API"""
print("=== 模拟分类测试 ===")
# 测试文本样本
test_cases = [
{
"text": "今天要完成的任务:\n1. 完成项目文档\n2. 修复 Bug #123\n3. 参加团队会议",
"expected": CategoryType.TODO,
"description": "待办事项"
},
{
"text": "Python 中的列表推导式是一种简洁的语法糖。\n\n示例:\n[x * 2 for x in range(10)]",
"expected": CategoryType.NOTE,
"description": "编程笔记"
},
{
"text": "突然想到一个产品创意:做一个能自动识别截图分类的工具!\n可以使用 AI + OCR 实现。",
"expected": CategoryType.IDEA,
"description": "产品灵感"
},
{
"text": "API 文档参考:\nGET /api/users\n获取用户列表\n\n参数:\n- page: 页码\n- limit: 每页数量",
"expected": CategoryType.REF,
"description": "参考资料"
},
{
"text": "程序员最讨厌的四件事:\n1. 写注释\n2. 写文档\n3. 别人不写注释\n4. 别人不写文档",
"expected": CategoryType.FUNNY,
"description": "搞笑段子"
},
]
print("测试用例:")
for i, case in enumerate(test_cases, 1):
print(f"\n{i}. {case['description']}")
print(f" 预期分类: {case['expected'].value}")
print(f" 文本预览: {case['text'][:50]}...")
print("\n注意:实际分类需要配置 API key 并调用 AI 服务")
def test_error_handling():
"""测试错误处理"""
print("=== 测试错误处理 ===")
# 测试不支持的提供商
try:
AIClassifier.create_client(
provider="unsupported_provider",
api_key="test",
model="test-model"
)
print("✗ 应该抛出异常但没有")
except AIError as e:
print(f"✓ 正确捕获不支持的提供商错误: {e}")
print("✅ 错误处理测试通过\n")
def print_usage_examples():
"""打印使用示例"""
print("=== 使用示例 ===\n")
print("1. 使用 OpenAI 进行分类:")
print("""
from src.core.ai import AIClassifier
client = AIClassifier.create_client(
provider="openai",
api_key="your-api-key",
model="gpt-4o-mini"
)
result = client.classify("待分析的文本")
print(result.category) # TODO
print(result.content) # Markdown 格式内容
""")
print("\n2. 使用配置文件进行分类:")
print("""
from src.config.settings import get_settings
from src.core.ai import classify_text
settings = get_settings()
result = classify_text("待分析的文本", settings.ai)
""")
print("\n3. 使用 Claude 进行分类:")
print("""
client = AIClassifier.create_client(
provider="anthropic",
api_key="your-api-key",
model="claude-3-5-sonnet-20241022"
)
result = client.classify("待分析的文本")
""")
print("\n4. 使用本地 Ollama:")
print("""
client = AIClassifier.create_client(
provider="ollama",
api_key="", # Ollama 不需要 API key
model="llama3.2"
)
result = client.classify("待分析的文本")
""")
def main():
"""主测试函数"""
print("=" * 60)
print("AI 模块测试")
print("=" * 60)
print()
# 运行所有测试
test_classification_result()
test_category_types()
test_ai_classifier_creation()
test_mock_classification()
test_error_handling()
# 打印使用示例
print_usage_examples()
print("\n" + "=" * 60)
print("所有测试完成!")
print("=" * 60)
if __name__ == "__main__":
main()