226 lines
6.2 KiB
Python
226 lines
6.2 KiB
Python
|
|
"""
|
|||
|
|
AI 模块测试脚本
|
|||
|
|
|
|||
|
|
测试各个 AI 提供商的文本分类功能
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# 添加项目根目录到 Python 路径
|
|||
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
|
|||
|
|
from src.core.ai import (
|
|||
|
|
CategoryType,
|
|||
|
|
ClassificationResult,
|
|||
|
|
AIClassifier,
|
|||
|
|
classify_text,
|
|||
|
|
AIError,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_classification_result():
|
|||
|
|
"""测试分类结果数据结构"""
|
|||
|
|
print("=== 测试分类结果数据结构 ===")
|
|||
|
|
|
|||
|
|
result = ClassificationResult(
|
|||
|
|
category=CategoryType.TODO,
|
|||
|
|
confidence=0.95,
|
|||
|
|
title="测试任务",
|
|||
|
|
content="- [ ] 完成测试",
|
|||
|
|
tags=["测试", "任务"],
|
|||
|
|
reasoning="包含待办事项关键词"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print(f"分类: {result.category}")
|
|||
|
|
print(f"置信度: {result.confidence}")
|
|||
|
|
print(f"标题: {result.title}")
|
|||
|
|
print(f"内容: {result.content}")
|
|||
|
|
print(f"标签: {result.tags}")
|
|||
|
|
print(f"理由: {result.reasoning}")
|
|||
|
|
|
|||
|
|
# 测试转换为字典
|
|||
|
|
result_dict = result.to_dict()
|
|||
|
|
print(f"\n转换为字典: {result_dict}")
|
|||
|
|
|
|||
|
|
# 测试从字典创建
|
|||
|
|
result2 = ClassificationResult.from_dict(result_dict)
|
|||
|
|
print(f"\n从字典恢复: 分类={result2.category}, 标题={result2.title}")
|
|||
|
|
|
|||
|
|
print("✅ 分类结果数据结构测试通过\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_category_types():
|
|||
|
|
"""测试分类类型枚举"""
|
|||
|
|
print("=== 测试分类类型枚举 ===")
|
|||
|
|
|
|||
|
|
print("所有分类类型:")
|
|||
|
|
for category in CategoryType:
|
|||
|
|
print(f" - {category.name}: {category.value}")
|
|||
|
|
|
|||
|
|
print(f"\n所有分类值: {CategoryType.all()}")
|
|||
|
|
|
|||
|
|
# 测试验证
|
|||
|
|
assert CategoryType.is_valid("TODO") == True
|
|||
|
|
assert CategoryType.is_valid("NOTE") == True
|
|||
|
|
assert CategoryType.is_valid("INVALID") == False
|
|||
|
|
|
|||
|
|
print("✅ 分类类型枚举测试通过\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ai_classifier_creation():
|
|||
|
|
"""测试 AI 分类器创建"""
|
|||
|
|
print("=== 测试 AI 分类器创建 ===")
|
|||
|
|
|
|||
|
|
providers = ["openai", "anthropic", "qwen", "ollama"]
|
|||
|
|
|
|||
|
|
for provider in providers:
|
|||
|
|
try:
|
|||
|
|
client = AIClassifier.create_client(
|
|||
|
|
provider=provider,
|
|||
|
|
api_key="test_key",
|
|||
|
|
model="test_model"
|
|||
|
|
)
|
|||
|
|
print(f"✓ {provider} 客户端创建成功: {type(client).__name__}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"✗ {provider} 客户端创建失败: {e}")
|
|||
|
|
|
|||
|
|
print("✅ AI 分类器创建测试通过\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_mock_classification():
|
|||
|
|
"""模拟分类测试(不实际调用 API)"""
|
|||
|
|
print("=== 模拟分类测试 ===")
|
|||
|
|
|
|||
|
|
# 测试文本样本
|
|||
|
|
test_cases = [
|
|||
|
|
{
|
|||
|
|
"text": "今天要完成的任务:\n1. 完成项目文档\n2. 修复 Bug #123\n3. 参加团队会议",
|
|||
|
|
"expected": CategoryType.TODO,
|
|||
|
|
"description": "待办事项"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"text": "Python 中的列表推导式是一种简洁的语法糖。\n\n示例:\n[x * 2 for x in range(10)]",
|
|||
|
|
"expected": CategoryType.NOTE,
|
|||
|
|
"description": "编程笔记"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"text": "突然想到一个产品创意:做一个能自动识别截图分类的工具!\n可以使用 AI + OCR 实现。",
|
|||
|
|
"expected": CategoryType.IDEA,
|
|||
|
|
"description": "产品灵感"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"text": "API 文档参考:\nGET /api/users\n获取用户列表\n\n参数:\n- page: 页码\n- limit: 每页数量",
|
|||
|
|
"expected": CategoryType.REF,
|
|||
|
|
"description": "参考资料"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"text": "程序员最讨厌的四件事:\n1. 写注释\n2. 写文档\n3. 别人不写注释\n4. 别人不写文档",
|
|||
|
|
"expected": CategoryType.FUNNY,
|
|||
|
|
"description": "搞笑段子"
|
|||
|
|
},
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
print("测试用例:")
|
|||
|
|
for i, case in enumerate(test_cases, 1):
|
|||
|
|
print(f"\n{i}. {case['description']}")
|
|||
|
|
print(f" 预期分类: {case['expected'].value}")
|
|||
|
|
print(f" 文本预览: {case['text'][:50]}...")
|
|||
|
|
|
|||
|
|
print("\n注意:实际分类需要配置 API key 并调用 AI 服务")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_error_handling():
|
|||
|
|
"""测试错误处理"""
|
|||
|
|
print("=== 测试错误处理 ===")
|
|||
|
|
|
|||
|
|
# 测试不支持的提供商
|
|||
|
|
try:
|
|||
|
|
AIClassifier.create_client(
|
|||
|
|
provider="unsupported_provider",
|
|||
|
|
api_key="test",
|
|||
|
|
model="test-model"
|
|||
|
|
)
|
|||
|
|
print("✗ 应该抛出异常但没有")
|
|||
|
|
except AIError as e:
|
|||
|
|
print(f"✓ 正确捕获不支持的提供商错误: {e}")
|
|||
|
|
|
|||
|
|
print("✅ 错误处理测试通过\n")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_usage_examples():
|
|||
|
|
"""打印使用示例"""
|
|||
|
|
print("=== 使用示例 ===\n")
|
|||
|
|
|
|||
|
|
print("1. 使用 OpenAI 进行分类:")
|
|||
|
|
print("""
|
|||
|
|
from src.core.ai import AIClassifier
|
|||
|
|
|
|||
|
|
client = AIClassifier.create_client(
|
|||
|
|
provider="openai",
|
|||
|
|
api_key="your-api-key",
|
|||
|
|
model="gpt-4o-mini"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result = client.classify("待分析的文本")
|
|||
|
|
print(result.category) # TODO
|
|||
|
|
print(result.content) # Markdown 格式内容
|
|||
|
|
""")
|
|||
|
|
|
|||
|
|
print("\n2. 使用配置文件进行分类:")
|
|||
|
|
print("""
|
|||
|
|
from src.config.settings import get_settings
|
|||
|
|
from src.core.ai import classify_text
|
|||
|
|
|
|||
|
|
settings = get_settings()
|
|||
|
|
result = classify_text("待分析的文本", settings.ai)
|
|||
|
|
""")
|
|||
|
|
|
|||
|
|
print("\n3. 使用 Claude 进行分类:")
|
|||
|
|
print("""
|
|||
|
|
client = AIClassifier.create_client(
|
|||
|
|
provider="anthropic",
|
|||
|
|
api_key="your-api-key",
|
|||
|
|
model="claude-3-5-sonnet-20241022"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result = client.classify("待分析的文本")
|
|||
|
|
""")
|
|||
|
|
|
|||
|
|
print("\n4. 使用本地 Ollama:")
|
|||
|
|
print("""
|
|||
|
|
client = AIClassifier.create_client(
|
|||
|
|
provider="ollama",
|
|||
|
|
api_key="", # Ollama 不需要 API key
|
|||
|
|
model="llama3.2"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result = client.classify("待分析的文本")
|
|||
|
|
""")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
"""主测试函数"""
|
|||
|
|
print("=" * 60)
|
|||
|
|
print("AI 模块测试")
|
|||
|
|
print("=" * 60)
|
|||
|
|
print()
|
|||
|
|
|
|||
|
|
# 运行所有测试
|
|||
|
|
test_classification_result()
|
|||
|
|
test_category_types()
|
|||
|
|
test_ai_classifier_creation()
|
|||
|
|
test_mock_classification()
|
|||
|
|
test_error_handling()
|
|||
|
|
|
|||
|
|
# 打印使用示例
|
|||
|
|
print_usage_examples()
|
|||
|
|
|
|||
|
|
print("\n" + "=" * 60)
|
|||
|
|
print("所有测试完成!")
|
|||
|
|
print("=" * 60)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|