""" AI 分类模块 负责调用不同的 AI 提供商进行文本分类和内容生成 支持的提供商:OpenAI, Anthropic (Claude), 通义千问, 本地 Ollama """ import json import time from typing import Optional, Dict, Any, List from enum import Enum from dataclasses import dataclass, field, asdict import logging logger = logging.getLogger(__name__) class CategoryType(str, Enum): """文本分类类型枚举""" TODO = "TODO" # 待办事项 NOTE = "NOTE" # 笔记 IDEA = "IDEA" # 灵感 REF = "REF" # 参考资料 FUNNY = "FUNNY" # 搞笑文案 TEXT = "TEXT" # 纯文本 @classmethod def all(cls) -> List[str]: """获取所有分类类型""" return [c.value for c in cls] @classmethod def is_valid(cls, category: str) -> bool: """验证分类是否有效""" return category in cls.all() @dataclass class ClassificationResult: """AI 分类结果数据结构""" category: CategoryType # 分类类型 confidence: float # 置信度 (0-1) title: str # 生成的标题 content: str # 生成的 Markdown 内容 tags: List[str] = field(default_factory=list) # 提取的标签 reasoning: str = "" # AI 的分类理由(可选) raw_response: str = "" # 原始响应(用于调试) def to_dict(self) -> Dict[str, Any]: """转换为字典""" return { 'category': self.category.value, 'confidence': self.confidence, 'title': self.title, 'content': self.content, 'tags': self.tags, 'reasoning': self.reasoning, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'ClassificationResult': """从字典创建实例""" return cls( category=CategoryType(data['category']), confidence=data.get('confidence', 0.0), title=data.get('title', ''), content=data.get('content', ''), tags=data.get('tags', []), reasoning=data.get('reasoning', ''), ) class AIError(Exception): """AI 调用错误基类""" pass class AIAPIError(AIError): """AI API 调用错误""" pass class AIRateLimitError(AIError): """AI API 速率限制错误""" pass class AIAuthenticationError(AIError): """AI 认证错误""" pass class AITimeoutError(AIError): """AI 请求超时错误""" pass # 分类提示词模板 CLASSIFICATION_PROMPT_TEMPLATE = """你是一个智能文本分类助手。请分析以下OCR识别的文本,将其分类为以下6种类型之一: ## 分类类型说明 1. **TODO (待办事项)**:包含任务、待办清单、行动项、计划等内容 - 特征:包含"待办"、"任务"、"完成"、"截止日期"等关键词 - 例如:工作计划、购物清单、行动项列表 2. **NOTE (笔记)**:学习笔记、会议记录、知识整理、信息摘录 - 特征:知识性、信息性内容,通常是学习或工作的记录 - 例如:课程笔记、会议纪要、知识点总结 3. **IDEA (灵感)**:创新想法、产品思路、创意点子、灵感记录 - 特征:创造性、前瞻性、头脑风暴相关 - 例如:产品创意、写作灵感、改进建议 4. **REF (参考资料)**:需要保存的参考资料、文档片段、教程链接 - 特征:信息密度高,作为后续参考使用 - 例如:API文档、配置示例、技术教程 5. **FUNNY (搞笑文案)**:幽默段子、搞笑图片文字、娱乐内容 - 特征:娱乐性、搞笑、轻松的内容 - 例如:段子、表情包配文、搞笑对话 6. **TEXT (纯文本)**:不适合归入以上类别的普通文本 - 特征:信息量较低或难以明确分类的内容 - 例如:广告、通知、普通对话 ## 任务要求 请分析以下文本,并以 JSON 格式返回分类结果: ```json {{ "category": "分类类型(TODO/NOTE/IDEA/REF/FUNNY/TEXT之一)", "confidence": 0.95, "title": "生成的简短标题(不超过20字)", "content": "根据文本内容整理成 Markdown 格式的结构化内容", "tags": ["标签1", "标签2", "标签3"], "reasoning": "选择该分类的理由(简短说明)" }} ``` ## 注意事项 - content 字段要生成格式化的 Markdown 内容,使用列表、标题等结构化元素 - 对于 TODO 类型,请用任务列表格式:- [ ] 任务1 - 对于 NOTE 类型,请用清晰的标题和分段 - 对于 IDEA 类型,突出创新点 - 对于 REF 类型,保留关键信息和结构 - 对于 FUNNY 类型,保留原文的趣味性 - confidence 为 0-1 之间的浮点数,表示分类的置信度 - 提取 3-5 个最相关的标签 ## 待分析的文本 ``` {text} ``` 请仅返回 JSON 格式,不要包含其他说明文字。 """ class AIClientBase: """AI 客户端基类""" def __init__( self, api_key: str, model: str, temperature: float = 0.7, max_tokens: int = 4096, timeout: int = 60, max_retries: int = 3, retry_delay: float = 1.0, ): """ 初始化 AI 客户端 Args: api_key: API 密钥 model: 模型名称 temperature: 温度参数 (0-2) max_tokens: 最大生成长度 timeout: 请求超时时间(秒) max_retries: 最大重试次数 retry_delay: 重试延迟(秒) """ self.api_key = api_key self.model = model self.temperature = temperature self.max_tokens = max_tokens self.timeout = timeout self.max_retries = max_retries self.retry_delay = retry_delay def classify(self, text: str) -> ClassificationResult: """ 对文本进行分类 Args: text: 待分类的文本 Returns: ClassificationResult: 分类结果 Raises: AIError: 分类失败 """ raise NotImplementedError("子类必须实现此方法") def _parse_json_response(self, response_text: str) -> Dict[str, Any]: """ 解析 JSON 响应 Args: response_text: AI 返回的文本 Returns: 解析后的字典 Raises: AIError: JSON 解析失败 """ # 尝试直接解析 try: return json.loads(response_text.strip()) except json.JSONDecodeError: pass # 尝试提取 JSON 代码块 if "```json" in response_text: start = response_text.find("```json") + 7 end = response_text.find("```", start) if end != -1: try: json_str = response_text[start:end].strip() return json.loads(json_str) except json.JSONDecodeError: pass # 尝试提取普通代码块 if "```" in response_text: start = response_text.find("```") + 3 end = response_text.find("```", start) if end != -1: try: json_str = response_text[start:end].strip() return json.loads(json_str) except json.JSONDecodeError: pass # 尝试查找 { } 包围的 JSON start = response_text.find("{") end = response_text.rfind("}") if start != -1 and end != -1 and end > start: try: json_str = response_text[start:end+1] return json.loads(json_str) except json.JSONDecodeError: pass raise AIError(f"无法解析 AI 响应为 JSON: {response_text[:200]}...") def _retry_on_failure(self, func, *args, **kwargs): """ 在失败时重试 Args: func: 要执行的函数 *args: 位置参数 **kwargs: 关键字参数 Returns: 函数执行结果 Raises: AIError: 重试次数用尽后仍然失败 """ last_error = None for attempt in range(self.max_retries): try: return func(*args, **kwargs) except Exception as e: last_error = e logger.warning(f"AI 调用失败(尝试 {attempt + 1}/{self.max_retries}): {e}") # 最后一次不等待 if attempt < self.max_retries - 1: delay = self.retry_delay * (2 ** attempt) # 指数退避 logger.info(f"等待 {delay:.1f} 秒后重试...") time.sleep(delay) raise AIError(f"AI 调用失败,已重试 {self.max_retries} 次: {last_error}") class OpenAIClient(AIClientBase): """OpenAI 客户端""" def __init__(self, **kwargs): super().__init__(**kwargs) try: import openai self.openai = openai self.client = openai.OpenAI(api_key=self.api_key, timeout=self.timeout) except ImportError: raise AIError("OpenAI 库未安装,请运行: pip install openai") def classify(self, text: str) -> ClassificationResult: """使用 OpenAI API 进行分类""" def _do_classify(): prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000]) # 限制长度 try: response = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": "你是一个专业的文本分类助手。"}, {"role": "user", "content": prompt} ], temperature=self.temperature, max_tokens=self.max_tokens, ) result_text = response.choices[0].message.content.strip() # 解析 JSON 响应 result_dict = self._parse_json_response(result_text) # 验证分类 category = result_dict.get('category', 'TEXT') if not CategoryType.is_valid(category): category = 'TEXT' return ClassificationResult( category=CategoryType(category), confidence=float(result_dict.get('confidence', 0.8)), title=str(result_dict.get('title', '未命名'))[:50], content=str(result_dict.get('content', text)), tags=list(result_dict.get('tags', []))[:5], reasoning=str(result_dict.get('reasoning', '')), raw_response=result_text, ) except self.openai.AuthenticationError as e: raise AIAuthenticationError(f"OpenAI 认证失败: {e}") except self.openai.RateLimitError as e: raise AIRateLimitError(f"OpenAI API 速率限制: {e}") except self.openai.APITimeoutError as e: raise AITimeoutError(f"OpenAI API 请求超时: {e}") except self.openai.APIError as e: raise AIAPIError(f"OpenAI API 错误: {e}") return self._retry_on_failure(_do_classify) class AnthropicClient(AIClientBase): """Anthropic (Claude) 客户端""" def __init__(self, **kwargs): super().__init__(**kwargs) try: import anthropic self.anthropic = anthropic self.client = anthropic.Anthropic(api_key=self.api_key, timeout=self.timeout) except ImportError: raise AIError("Anthropic 库未安装,请运行: pip install anthropic") def classify(self, text: str) -> ClassificationResult: """使用 Claude API 进行分类""" def _do_classify(): prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000]) try: response = self.client.messages.create( model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, messages=[ {"role": "user", "content": prompt} ] ) result_text = response.content[0].text.strip() # 解析 JSON 响应 result_dict = self._parse_json_response(result_text) # 验证分类 category = result_dict.get('category', 'TEXT') if not CategoryType.is_valid(category): category = 'TEXT' return ClassificationResult( category=CategoryType(category), confidence=float(result_dict.get('confidence', 0.8)), title=str(result_dict.get('title', '未命名'))[:50], content=str(result_dict.get('content', text)), tags=list(result_dict.get('tags', []))[:5], reasoning=str(result_dict.get('reasoning', '')), raw_response=result_text, ) except self.anthropic.AuthenticationError as e: raise AIAuthenticationError(f"Claude 认证失败: {e}") except self.anthropic.RateLimitError as e: raise AIRateLimitError(f"Claude API 速率限制: {e}") except self.anthropic.APITimeoutError as e: raise AITimeoutError(f"Claude API 请求超时: {e}") except self.anthropic.APIError as e: raise AIAPIError(f"Claude API 错误: {e}") return self._retry_on_failure(_do_classify) class QwenClient(AIClientBase): """通义千问客户端 (兼容 OpenAI API)""" def __init__(self, base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs): super().__init__(**kwargs) self.base_url = base_url try: import openai self.openai = openai self.client = openai.OpenAI( api_key=self.api_key, base_url=self.base_url, timeout=self.timeout ) except ImportError: raise AIError("OpenAI 库未安装,请运行: pip install openai") def classify(self, text: str) -> ClassificationResult: """使用通义千问 API 进行分类""" def _do_classify(): prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000]) try: response = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": "你是一个专业的文本分类助手。"}, {"role": "user", "content": prompt} ], temperature=self.temperature, max_tokens=self.max_tokens, ) result_text = response.choices[0].message.content.strip() # 解析 JSON 响应 result_dict = self._parse_json_response(result_text) # 验证分类 category = result_dict.get('category', 'TEXT') if not CategoryType.is_valid(category): category = 'TEXT' return ClassificationResult( category=CategoryType(category), confidence=float(result_dict.get('confidence', 0.8)), title=str(result_dict.get('title', '未命名'))[:50], content=str(result_dict.get('content', text)), tags=list(result_dict.get('tags', []))[:5], reasoning=str(result_dict.get('reasoning', '')), raw_response=result_text, ) except Exception as e: if "authentication" in str(e).lower(): raise AIAuthenticationError(f"通义千问认证失败: {e}") elif "rate limit" in str(e).lower(): raise AIRateLimitError(f"通义千问 API 速率限制: {e}") elif "timeout" in str(e).lower(): raise AITimeoutError(f"通义千问 API 请求超时: {e}") else: raise AIAPIError(f"通义千问 API 错误: {e}") return self._retry_on_failure(_do_classify) class OllamaClient(AIClientBase): """Ollama 本地模型客户端 (兼容 OpenAI API)""" def __init__(self, base_url: str = "http://localhost:11434/v1", **kwargs): super().__init__(**kwargs) self.base_url = base_url try: import openai self.openai = openai # Ollama 通常不需要 API key,使用任意值 self.client = openai.OpenAI( api_key=self.api_key or "ollama", base_url=self.base_url, timeout=self.timeout ) except ImportError: raise AIError("OpenAI 库未安装,请运行: pip install openai") def classify(self, text: str) -> ClassificationResult: """使用 Ollama 本地模型进行分类""" def _do_classify(): prompt = CLASSIFICATION_PROMPT_TEMPLATE.format(text=text[:4000]) try: response = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": "你是一个专业的文本分类助手。"}, {"role": "user", "content": prompt} ], temperature=self.temperature, max_tokens=self.max_tokens, ) result_text = response.choices[0].message.content.strip() # 解析 JSON 响应 result_dict = self._parse_json_response(result_text) # 验证分类 category = result_dict.get('category', 'TEXT') if not CategoryType.is_valid(category): category = 'TEXT' return ClassificationResult( category=CategoryType(category), confidence=float(result_dict.get('confidence', 0.8)), title=str(result_dict.get('title', '未命名'))[:50], content=str(result_dict.get('content', text)), tags=list(result_dict.get('tags', []))[:5], reasoning=str(result_dict.get('reasoning', '')), raw_response=result_text, ) except Exception as e: if "connection" in str(e).lower(): raise AIError(f"无法连接到 Ollama 服务 ({self.base_url}): {e}") else: raise AIAPIError(f"Ollama API 错误: {e}") return self._retry_on_failure(_do_classify) class AIClassifier: """ AI 分类器主类 根据配置自动选择合适的 AI 客户端进行文本分类 """ # 支持的提供商映射 CLIENTS = { "openai": OpenAIClient, "anthropic": AnthropicClient, "qwen": QwenClient, "ollama": OllamaClient, } @classmethod def create_client( cls, provider: str, api_key: str, model: str, **kwargs ) -> AIClientBase: """ 创建 AI 客户端 Args: provider: 提供商名称 (openai, anthropic, qwen, ollama) api_key: API 密钥 model: 模型名称 **kwargs: 其他参数 Returns: AI 客户端实例 Raises: AIError: 不支持的提供商 """ provider_lower = provider.lower() if provider_lower not in cls.CLIENTS: raise AIError( f"不支持的 AI 提供商: {provider}. " f"支持的提供商: {', '.join(cls.CLIENTS.keys())}" ) client_class = cls.CLIENTS[provider_lower] # 根据不同提供商设置默认模型 if not model: default_models = { "openai": "gpt-4o-mini", "anthropic": "claude-3-5-sonnet-20241022", "qwen": "qwen-turbo", "ollama": "llama3.2", } model = default_models.get(provider_lower, "default") return client_class( api_key=api_key, model=model, **kwargs ) @classmethod def classify( cls, text: str, provider: str, api_key: str, model: str = "", **kwargs ) -> ClassificationResult: """ 对文本进行分类 Args: text: 待分类的文本 provider: 提供商名称 api_key: API 密钥 model: 模型名称 **kwargs: 其他参数 Returns: ClassificationResult: 分类结果 """ client = cls.create_client(provider, api_key, model, **kwargs) return client.classify(text) def create_classifier_from_config(ai_config) -> AIClassifier: """ 从配置对象创建 AI 分类器 Args: ai_config: AI 配置对象 (来自 config.settings.AIConfig) Returns: 配置好的 AI 客户端 Example: >>> from src.config.settings import get_settings >>> settings = get_settings() >>> client = create_classifier_from_config(settings.ai) >>> result = client.classify("待分析的文本") """ return AIClassifier.create_client( provider=ai_config.provider.value, api_key=ai_config.api_key, model=ai_config.model, temperature=ai_config.temperature, max_tokens=ai_config.max_tokens, timeout=ai_config.timeout, ) # 便捷函数 def classify_text(text: str, ai_config) -> ClassificationResult: """ 使用配置的 AI 服务对文本进行分类 Args: text: 待分类的文本 ai_config: AI 配置对象 Returns: ClassificationResult: 分类结果 Raises: AIError: 分类失败 """ client = create_classifier_from_config(ai_config) return client.classify(text)