feat: 实现CutThenThink P0阶段核心功能
项目初始化 - 创建完整项目结构(src/, data/, docs/, examples/, tests/) - 配置requirements.txt依赖 - 创建.gitignore P0基础框架 - 数据库模型:Record模型,6种分类类型 - 配置管理:YAML配置,支持AI/OCR/云存储/UI配置 - OCR模块:PaddleOCR本地识别,支持云端扩展 - AI模块:支持OpenAI/Claude/通义/Ollama,6种分类 - 存储模块:完整CRUD,搜索,统计,导入导出 - 主窗口框架:侧边栏导航,米白配色方案 - 图片处理:截图/剪贴板/文件选择/图片预览 - 处理流程整合:OCR→AI→存储串联,Markdown展示,剪贴板复制 - 分类浏览:卡片网格展示,分类筛选,搜索,详情查看 技术栈 - PyQt6 + SQLAlchemy + PaddleOCR + OpenAI/Claude SDK - 共47个Python文件,4000+行代码 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
680
src/core/ai.py
Normal file
680
src/core/ai.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
AI 分类模块
|
||||
|
||||
负责调用不同的 AI 提供商进行文本分类和内容生成
|
||||
支持的提供商:OpenAI, Anthropic (Claude), 通义千问, 本地 Ollama
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field, asdict
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CategoryType(str, Enum):
    """Closed set of categories a piece of text can be classified into.

    Members are also ``str`` instances, so each one compares equal to its
    upper-case value.
    """

    TODO = "TODO"    # to-do items
    NOTE = "NOTE"    # notes
    IDEA = "IDEA"    # ideas / inspiration
    REF = "REF"      # reference material
    FUNNY = "FUNNY"  # humorous copy
    TEXT = "TEXT"    # plain text

    @classmethod
    def all(cls) -> List[str]:
        """Return the value of every category, in declaration order."""
        values = []
        for member in cls:
            values.append(member.value)
        return values

    @classmethod
    def is_valid(cls, category: str) -> bool:
        """Report whether ``category`` equals one of the category values."""
        return any(value == category for value in cls.all())
|
||||
|
||||
|
||||
@dataclass
class ClassificationResult:
    """Outcome of classifying one piece of text with an AI provider."""

    category: CategoryType  # assigned category
    confidence: float  # classification confidence, expected in 0-1
    title: str  # generated title
    content: str  # generated Markdown content
    tags: List[str] = field(default_factory=list)  # extracted tags
    reasoning: str = ""  # model's rationale for the category (optional)
    raw_response: str = ""  # unparsed model output, kept for debugging

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict.

        The category is stored by its string value. ``raw_response`` is not
        included in the output.
        """
        return {
            'category': self.category.value,
            'confidence': self.confidence,
            'title': self.title,
            'content': self.content,
            'tags': self.tags,
            'reasoning': self.reasoning,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ClassificationResult':
        """Build a result from a mapping such as the one ``to_dict`` returns.

        Args:
            data: Mapping with a required ``'category'`` entry; every other
                field falls back to an empty/zero default when absent.

        Returns:
            A new ``ClassificationResult``.

        Raises:
            KeyError: ``data`` has no ``'category'`` entry.
            ValueError: The category is not a ``CategoryType`` value.
        """
        return cls(
            category=CategoryType(data['category']),
            confidence=data.get('confidence', 0.0),
            title=data.get('title', ''),
            content=data.get('content', ''),
            # Copy so the instance never shares a list with the caller's
            # mapping; a missing or null entry becomes an empty list.
            tags=list(data.get('tags') or []),
            reasoning=data.get('reasoning', ''),
            # Restored when present so a serialized debug payload round-trips
            raw_response=data.get('raw_response', ''),
        )
|
||||
|
||||
|
||||
class AIError(Exception):
    """Base class for every error raised while calling an AI provider."""
|
||||
|
||||
|
||||
class AIAPIError(AIError):
    """An AI provider API call failed."""
|
||||
|
||||
|
||||
class AIRateLimitError(AIError):
    """The AI provider rejected the call because of rate limiting."""
|
||||
|
||||
|
||||
class AIAuthenticationError(AIError):
    """The AI provider rejected the supplied credentials."""
|
||||
|
||||
|
||||
class AITimeoutError(AIError):
    """A request to the AI provider timed out."""
|
||||
|
||||
|
||||
# Prompt sent to every provider. Rendered with ``str.format(text=...)``, which
# is why the literal braces of the JSON example are doubled.
CLASSIFICATION_PROMPT_TEMPLATE = """你是一个智能文本分类助手。请分析以下OCR识别的文本,将其分类为以下6种类型之一:

## 分类类型说明

1. **TODO (待办事项)**:包含任务、待办清单、行动项、计划等内容
- 特征:包含"待办"、"任务"、"完成"、"截止日期"等关键词
- 例如:工作计划、购物清单、行动项列表

2. **NOTE (笔记)**:学习笔记、会议记录、知识整理、信息摘录
- 特征:知识性、信息性内容,通常是学习或工作的记录
- 例如:课程笔记、会议纪要、知识点总结

3. **IDEA (灵感)**:创新想法、产品思路、创意点子、灵感记录
- 特征:创造性、前瞻性、头脑风暴相关
- 例如:产品创意、写作灵感、改进建议

4. **REF (参考资料)**:需要保存的参考资料、文档片段、教程链接
- 特征:信息密度高,作为后续参考使用
- 例如:API文档、配置示例、技术教程

5. **FUNNY (搞笑文案)**:幽默段子、搞笑图片文字、娱乐内容
- 特征:娱乐性、搞笑、轻松的内容
- 例如:段子、表情包配文、搞笑对话

6. **TEXT (纯文本)**:不适合归入以上类别的普通文本
- 特征:信息量较低或难以明确分类的内容
- 例如:广告、通知、普通对话

## 任务要求

请分析以下文本,并以 JSON 格式返回分类结果:

```json
{{
"category": "分类类型(TODO/NOTE/IDEA/REF/FUNNY/TEXT之一)",
"confidence": 0.95,
"title": "生成的简短标题(不超过20字)",
"content": "根据文本内容整理成 Markdown 格式的结构化内容",
"tags": ["标签1", "标签2", "标签3"],
"reasoning": "选择该分类的理由(简短说明)"
}}
```

## 注意事项
- content 字段要生成格式化的 Markdown 内容,使用列表、标题等结构化元素
- 对于 TODO 类型,请用任务列表格式:- [ ] 任务1
- 对于 NOTE 类型,请用清晰的标题和分段
- 对于 IDEA 类型,突出创新点
- 对于 REF 类型,保留关键信息和结构
- 对于 FUNNY 类型,保留原文的趣味性
- confidence 为 0-1 之间的浮点数,表示分类的置信度
- 提取 3-5 个最相关的标签

## 待分析的文本

```
{text}
```

请仅返回 JSON 格式,不要包含其他说明文字。
"""
|
||||
|
||||
|
||||
class AIClientBase:
    """Settings and helpers shared by the provider-specific clients."""

    def __init__(
        self,
        api_key: str,
        model: str,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        timeout: int = 60,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Store the connection, sampling and retry settings.

        Args:
            api_key: API key for the provider.
            model: Model name.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate.
            timeout: Request timeout in seconds.
            max_retries: Total number of attempts made by
                ``_retry_on_failure``.
            retry_delay: Delay in seconds before the first retry; it doubles
                after every further failure.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    def classify(self, text: str) -> "ClassificationResult":
        """Classify ``text``; must be overridden by subclasses.

        Args:
            text: Text to classify.

        Returns:
            The classification result.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("子类必须实现此方法")

    def _parse_json_response(self, response_text: str) -> Dict[str, Any]:
        """Extract the JSON object contained in a model response.

        Tries, in order: the whole text, the body of a ```` ```json ```` fence,
        the body of a plain ```` ``` ```` fence, and finally the JSON value
        that starts at the first ``{``. Only JSON objects are accepted,
        because callers read the result with ``dict.get``.

        Args:
            response_text: Text returned by the model.

        Returns:
            The decoded JSON object.

        Raises:
            AIError: No JSON object could be decoded from the text.
        """
        candidates = [response_text.strip()]
        for fence in ("```json", "```"):
            start = response_text.find(fence)
            if start == -1:
                continue
            start += len(fence)
            end = response_text.find("```", start)
            if end != -1:
                candidates.append(response_text[start:end].strip())

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        # raw_decode stops at the end of the first complete value, so trailing
        # commentary (even commentary containing braces) is ignored.
        start = response_text.find("{")
        if start != -1:
            try:
                parsed, _ = json.JSONDecoder().raw_decode(response_text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(parsed, dict):
                    return parsed

        raise AIError(f"无法解析 AI 响应为 JSON: {response_text[:200]}...")

    def _retry_on_failure(self, func, *args, **kwargs):
        """Call ``func`` and retry with exponential backoff when it raises.

        Args:
            func: Callable to run.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns on its first successful call.

        Raises:
            AIAuthenticationError: Re-raised immediately, without retrying.
            AIError: Every attempt failed; the last failure is chained as the
                cause.
        """
        last_error = None
        # Always try at least once, even when max_retries is 0 or negative
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                return func(*args, **kwargs)
            except AIAuthenticationError:
                # Rejected credentials will not start working on a retry
                raise
            except Exception as e:
                last_error = e
                logger.warning(f"AI 调用失败(尝试 {attempt + 1}/{attempts}): {e}")

                # No point sleeping after the final attempt
                if attempt < attempts - 1:
                    delay = self.retry_delay * (2 ** attempt)  # exponential backoff
                    logger.info(f"等待 {delay:.1f} 秒后重试...")
                    time.sleep(delay)

        raise AIError(f"AI 调用失败,已重试 {attempts} 次: {last_error}") from last_error
|
||||
|
||||
|
||||
class OpenAIClient(AIClientBase):
    """Client that classifies text through the OpenAI chat-completions API."""

    def __init__(self, **kwargs):
        """Configure the base client and create the OpenAI SDK client.

        Args:
            **kwargs: Forwarded unchanged to ``AIClientBase.__init__``.

        Raises:
            AIError: The ``openai`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import openai
        except ImportError as exc:
            raise AIError("OpenAI 库未安装,请运行: pip install openai") from exc
        self.openai = openai
        self.client = openai.OpenAI(api_key=self.api_key, timeout=self.timeout)

    def _build_result(self, payload: Any, text: str, raw: str) -> ClassificationResult:
        """Coerce the decoded model payload into a ``ClassificationResult``.

        Args:
            payload: Decoded JSON returned by the model.
            text: Original input text, used when the payload has no content.
            raw: Unparsed model output, stored on the result.

        Raises:
            AIError: ``payload`` is not a JSON object.
        """
        if not isinstance(payload, dict):
            raise AIError(f"无法解析 AI 响应为 JSON: {raw[:200]}...")

        # Tolerate lower-case or padded category names from the model
        category = str(payload.get('category', 'TEXT')).strip().upper()
        if not CategoryType.is_valid(category):
            category = 'TEXT'

        try:
            confidence = float(payload.get('confidence', 0.8))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.8
        confidence = min(1.0, max(0.0, confidence))  # keep within 0-1

        tags = payload.get('tags') or []
        if not isinstance(tags, (list, tuple)):
            tags = [tags]  # a lone tag must not be split into characters

        return ClassificationResult(
            category=CategoryType(category),
            confidence=confidence,
            title=str(payload.get('title') or '未命名')[:50],
            content=str(payload.get('content') or text),
            tags=[str(tag) for tag in tags][:5],
            reasoning=str(payload.get('reasoning') or ''),
            raw_response=raw,
        )

    def classify(self, text: str) -> ClassificationResult:
        """Classify ``text`` with the OpenAI API.

        Only the first 4000 characters of ``text`` are sent. The call runs
        through ``_retry_on_failure``.

        Raises:
            AIError: The call or the response parsing failed.
        """

        def _do_classify():
            prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000])  # cap prompt size

            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "你是一个专业的文本分类助手。"},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
            except self.openai.AuthenticationError as e:
                raise AIAuthenticationError(f"OpenAI 认证失败: {e}") from e
            except self.openai.RateLimitError as e:
                raise AIRateLimitError(f"OpenAI API 速率限制: {e}") from e
            except self.openai.APITimeoutError as e:
                raise AITimeoutError(f"OpenAI API 请求超时: {e}") from e
            except self.openai.APIError as e:
                raise AIAPIError(f"OpenAI API 错误: {e}") from e

            # The SDK types message.content as optional, so guard against None
            result_text = (response.choices[0].message.content or "").strip()
            return self._build_result(
                self._parse_json_response(result_text), text, result_text
            )

        return self._retry_on_failure(_do_classify)
|
||||
|
||||
|
||||
class AnthropicClient(AIClientBase):
    """Client that classifies text through the Anthropic (Claude) Messages API."""

    def __init__(self, **kwargs):
        """Configure the base client and create the Anthropic SDK client.

        Args:
            **kwargs: Forwarded unchanged to ``AIClientBase.__init__``.

        Raises:
            AIError: The ``anthropic`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import anthropic
        except ImportError as exc:
            raise AIError("Anthropic 库未安装,请运行: pip install anthropic") from exc
        self.anthropic = anthropic
        self.client = anthropic.Anthropic(api_key=self.api_key, timeout=self.timeout)

    def _build_result(self, payload: Any, text: str, raw: str) -> ClassificationResult:
        """Coerce the decoded model payload into a ``ClassificationResult``.

        Args:
            payload: Decoded JSON returned by the model.
            text: Original input text, used when the payload has no content.
            raw: Unparsed model output, stored on the result.

        Raises:
            AIError: ``payload`` is not a JSON object.
        """
        if not isinstance(payload, dict):
            raise AIError(f"无法解析 AI 响应为 JSON: {raw[:200]}...")

        # Tolerate lower-case or padded category names from the model
        category = str(payload.get('category', 'TEXT')).strip().upper()
        if not CategoryType.is_valid(category):
            category = 'TEXT'

        try:
            confidence = float(payload.get('confidence', 0.8))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.8
        confidence = min(1.0, max(0.0, confidence))  # keep within 0-1

        tags = payload.get('tags') or []
        if not isinstance(tags, (list, tuple)):
            tags = [tags]  # a lone tag must not be split into characters

        return ClassificationResult(
            category=CategoryType(category),
            confidence=confidence,
            title=str(payload.get('title') or '未命名')[:50],
            content=str(payload.get('content') or text),
            tags=[str(tag) for tag in tags][:5],
            reasoning=str(payload.get('reasoning') or ''),
            raw_response=raw,
        )

    def classify(self, text: str) -> ClassificationResult:
        """Classify ``text`` with the Claude API.

        Only the first 4000 characters of ``text`` are sent. The call runs
        through ``_retry_on_failure``.

        Raises:
            AIError: The call or the response parsing failed.
        """

        def _do_classify():
            prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000])

            try:
                response = self.client.messages.create(
                    model=self.model,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    messages=[
                        {"role": "user", "content": prompt}
                    ]
                )
            except self.anthropic.AuthenticationError as e:
                raise AIAuthenticationError(f"Claude 认证失败: {e}") from e
            except self.anthropic.RateLimitError as e:
                raise AIRateLimitError(f"Claude API 速率限制: {e}") from e
            except self.anthropic.APITimeoutError as e:
                raise AITimeoutError(f"Claude API 请求超时: {e}") from e
            except self.anthropic.APIError as e:
                raise AIAPIError(f"Claude API 错误: {e}") from e

            # Join every block that carries text instead of indexing [0], so
            # an empty content list or a non-text first block cannot crash.
            result_text = "".join(
                getattr(block, "text", "") or "" for block in response.content
            ).strip()
            return self._build_result(
                self._parse_json_response(result_text), text, result_text
            )

        return self._retry_on_failure(_do_classify)
|
||||
|
||||
|
||||
class QwenClient(AIClientBase):
    """Client for Tongyi Qianwen (Qwen) via its OpenAI-compatible endpoint."""

    def __init__(self, base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs):
        """Configure the base client and create an OpenAI SDK client for Qwen.

        Args:
            base_url: Base URL of the OpenAI-compatible endpoint.
            **kwargs: Forwarded unchanged to ``AIClientBase.__init__``.

        Raises:
            AIError: The ``openai`` package is not installed.
        """
        super().__init__(**kwargs)
        self.base_url = base_url
        try:
            import openai
        except ImportError as exc:
            raise AIError("OpenAI 库未安装,请运行: pip install openai") from exc
        self.openai = openai
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=self.timeout
        )

    def _translate_error(self, error: Exception) -> AIError:
        """Map an exception from the API call to the matching ``AIError``.

        The typed SDK exceptions are checked first; the message sniffing of
        the original implementation is kept as a fallback.
        """
        message = str(error).lower()
        if isinstance(error, self.openai.AuthenticationError) or "authentication" in message:
            return AIAuthenticationError(f"通义千问认证失败: {error}")
        if isinstance(error, self.openai.RateLimitError) or "rate limit" in message:
            return AIRateLimitError(f"通义千问 API 速率限制: {error}")
        if isinstance(error, self.openai.APITimeoutError) or "timeout" in message:
            return AITimeoutError(f"通义千问 API 请求超时: {error}")
        return AIAPIError(f"通义千问 API 错误: {error}")

    def _build_result(self, payload: Any, text: str, raw: str) -> ClassificationResult:
        """Coerce the decoded model payload into a ``ClassificationResult``.

        Args:
            payload: Decoded JSON returned by the model.
            text: Original input text, used when the payload has no content.
            raw: Unparsed model output, stored on the result.

        Raises:
            AIError: ``payload`` is not a JSON object.
        """
        if not isinstance(payload, dict):
            raise AIError(f"无法解析 AI 响应为 JSON: {raw[:200]}...")

        # Tolerate lower-case or padded category names from the model
        category = str(payload.get('category', 'TEXT')).strip().upper()
        if not CategoryType.is_valid(category):
            category = 'TEXT'

        try:
            confidence = float(payload.get('confidence', 0.8))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.8
        confidence = min(1.0, max(0.0, confidence))  # keep within 0-1

        tags = payload.get('tags') or []
        if not isinstance(tags, (list, tuple)):
            tags = [tags]  # a lone tag must not be split into characters

        return ClassificationResult(
            category=CategoryType(category),
            confidence=confidence,
            title=str(payload.get('title') or '未命名')[:50],
            content=str(payload.get('content') or text),
            tags=[str(tag) for tag in tags][:5],
            reasoning=str(payload.get('reasoning') or ''),
            raw_response=raw,
        )

    def classify(self, text: str) -> ClassificationResult:
        """Classify ``text`` with the Qwen API.

        Only the first 4000 characters of ``text`` are sent. The call runs
        through ``_retry_on_failure``.

        Raises:
            AIError: The call or the response parsing failed.
        """

        def _do_classify():
            prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000])

            # Only the network call is guarded, so a JSON-parsing AIError is
            # no longer re-labelled as an API error.
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "你是一个专业的文本分类助手。"},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
            except Exception as e:
                raise self._translate_error(e) from e

            # The SDK types message.content as optional, so guard against None
            result_text = (response.choices[0].message.content or "").strip()
            return self._build_result(
                self._parse_json_response(result_text), text, result_text
            )

        return self._retry_on_failure(_do_classify)
|
||||
|
||||
|
||||
class OllamaClient(AIClientBase):
    """Client for a local Ollama server via its OpenAI-compatible endpoint."""

    def __init__(self, base_url: str = "http://localhost:11434/v1", **kwargs):
        """Configure the base client and create an OpenAI SDK client for Ollama.

        Args:
            base_url: Base URL of the Ollama OpenAI-compatible endpoint.
            **kwargs: Forwarded unchanged to ``AIClientBase.__init__``.

        Raises:
            AIError: The ``openai`` package is not installed.
        """
        super().__init__(**kwargs)
        self.base_url = base_url
        try:
            import openai
        except ImportError as exc:
            raise AIError("OpenAI 库未安装,请运行: pip install openai") from exc
        self.openai = openai
        # Ollama usually needs no API key, so any placeholder value will do
        self.client = openai.OpenAI(
            api_key=self.api_key or "ollama",
            base_url=self.base_url,
            timeout=self.timeout
        )

    def _build_result(self, payload: Any, text: str, raw: str) -> ClassificationResult:
        """Coerce the decoded model payload into a ``ClassificationResult``.

        Args:
            payload: Decoded JSON returned by the model.
            text: Original input text, used when the payload has no content.
            raw: Unparsed model output, stored on the result.

        Raises:
            AIError: ``payload`` is not a JSON object.
        """
        if not isinstance(payload, dict):
            raise AIError(f"无法解析 AI 响应为 JSON: {raw[:200]}...")

        # Tolerate lower-case or padded category names from the model
        category = str(payload.get('category', 'TEXT')).strip().upper()
        if not CategoryType.is_valid(category):
            category = 'TEXT'

        try:
            confidence = float(payload.get('confidence', 0.8))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.8
        confidence = min(1.0, max(0.0, confidence))  # keep within 0-1

        tags = payload.get('tags') or []
        if not isinstance(tags, (list, tuple)):
            tags = [tags]  # a lone tag must not be split into characters

        return ClassificationResult(
            category=CategoryType(category),
            confidence=confidence,
            title=str(payload.get('title') or '未命名')[:50],
            content=str(payload.get('content') or text),
            tags=[str(tag) for tag in tags][:5],
            reasoning=str(payload.get('reasoning') or ''),
            raw_response=raw,
        )

    def classify(self, text: str) -> ClassificationResult:
        """Classify ``text`` with a local Ollama model.

        Only the first 4000 characters of ``text`` are sent. The call runs
        through ``_retry_on_failure``.

        Raises:
            AIError: The call or the response parsing failed.
        """

        def _do_classify():
            prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000])

            # Only the network call is guarded, so a JSON-parsing AIError is
            # no longer re-labelled as an Ollama API error.
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "你是一个专业的文本分类助手。"},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
            except Exception as e:
                if "connection" in str(e).lower():
                    raise AIError(f"无法连接到 Ollama 服务 ({self.base_url}): {e}") from e
                raise AIAPIError(f"Ollama API 错误: {e}") from e

            # The SDK types message.content as optional, so guard against None
            result_text = (response.choices[0].message.content or "").strip()
            return self._build_result(
                self._parse_json_response(result_text), text, result_text
            )

        return self._retry_on_failure(_do_classify)
|
||||
|
||||
|
||||
class AIClassifier:
    """Entry point that picks the client class matching a provider name."""

    # Provider name -> client class
    CLIENTS = {
        "openai": OpenAIClient,
        "anthropic": AnthropicClient,
        "qwen": QwenClient,
        "ollama": OllamaClient,
    }

    @classmethod
    def create_client(
        cls,
        provider: str,
        api_key: str,
        model: str,
        **kwargs
    ) -> AIClientBase:
        """Instantiate the client registered for ``provider``.

        Args:
            provider: Provider name (openai, anthropic, qwen, ollama),
                matched case-insensitively.
            api_key: API key passed to the client.
            model: Model name; when empty, a per-provider default is used.
            **kwargs: Extra keyword arguments passed to the client class.

        Returns:
            The new client instance.

        Raises:
            AIError: ``provider`` is not a key of ``CLIENTS``.
        """
        key = provider.lower()

        client_class = cls.CLIENTS.get(key)
        if client_class is None:
            supported = ', '.join(cls.CLIENTS.keys())
            raise AIError(
                f"不支持的 AI 提供商: {provider}. "
                f"支持的提供商: {supported}"
            )

        # Per-provider model used when the caller did not name one
        fallback_models = {
            "openai": "gpt-4o-mini",
            "anthropic": "claude-3-5-sonnet-20241022",
            "qwen": "qwen-turbo",
            "ollama": "llama3.2",
        }
        chosen_model = model or fallback_models.get(key, "default")

        return client_class(api_key=api_key, model=chosen_model, **kwargs)

    @classmethod
    def classify(
        cls,
        text: str,
        provider: str,
        api_key: str,
        model: str = "",
        **kwargs
    ) -> ClassificationResult:
        """Create a client for ``provider`` and classify ``text`` with it.

        Args:
            text: Text to classify.
            provider: Provider name.
            api_key: API key.
            model: Model name; empty selects the provider default.
            **kwargs: Extra keyword arguments passed to the client class.

        Returns:
            The classification result.
        """
        return cls.create_client(provider, api_key, model, **kwargs).classify(text)
|
||||
|
||||
|
||||
def create_classifier_from_config(ai_config) -> AIClientBase:
    """Create the AI client described by a configuration object.

    Args:
        ai_config: Object exposing ``provider``, ``api_key``, ``model``,
            ``temperature``, ``max_tokens`` and ``timeout`` attributes
            (presumably ``config.settings.AIConfig`` — confirm against the
            settings module). ``provider`` may be an enum member with a
            ``value`` attribute or a plain string.

    Returns:
        The configured client instance returned by
        ``AIClassifier.create_client``.

    Example:
        >>> from src.config.settings import get_settings
        >>> settings = get_settings()
        >>> client = create_classifier_from_config(settings.ai)
        >>> result = client.classify("待分析的文本")
    """
    provider = ai_config.provider
    return AIClassifier.create_client(
        # Enum members expose the name via .value; plain strings pass through
        provider=getattr(provider, "value", provider),
        api_key=ai_config.api_key,
        model=ai_config.model,
        temperature=ai_config.temperature,
        max_tokens=ai_config.max_tokens,
        timeout=ai_config.timeout,
    )
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def classify_text(text: str, ai_config) -> ClassificationResult:
    """Classify ``text`` with the AI service described by ``ai_config``.

    Args:
        text: Text to classify.
        ai_config: AI configuration object, handed to
            ``create_classifier_from_config``.

    Returns:
        The classification result.

    Raises:
        AIError: Classification failed.
    """
    return create_classifier_from_config(ai_config).classify(text)
|
||||
Reference in New Issue
Block a user