"""
OCR module - cloud-only edition.

Provides text recognition through cloud APIs:
- Cloud OCR API calls (Baidu / Tencent Cloud / Aliyun, etc.)
- Image preprocessing / enhancement
- Multi-language support (Chinese / English / mixed)

Note: this edition does not bundle a local OCR engine; all OCR work is
performed by a cloud API.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass
from enum import Enum
import logging
import base64
import io

try:
    from PIL import Image, ImageEnhance, ImageFilter
except ImportError:
    raise ImportError(
        "请安装图像处理库: pip install pillow"
    )

try:
    import requests
except ImportError:
    raise ImportError(
        "请安装 requests 库: pip install requests"
    )

# Module-level logger.
logger = logging.getLogger(__name__)

# Accepted image inputs: a filesystem path or an already-loaded PIL image.
ImageInput = Union[str, Path, "Image.Image"]

# ``Image.Resampling`` only exists on newer Pillow releases; fall back to the
# constant on the ``Image`` module itself for older installations.
_LANCZOS = getattr(Image, "Resampling", Image).LANCZOS

# Image modes that are saved as JPEG without conversion; anything else
# (RGBA, P, LA, ...) is converted to RGB before saving.
_JPEG_SAFE_MODES = ("L", "RGB", "CMYK")


class OCRLanguage(str, Enum):
    """Languages supported by OCR."""

    CHINESE = "ch"      # Chinese
    ENGLISH = "en"      # English
    MIXED = "ch_en"     # mixed Chinese / English


class OCRProvider(str, Enum):
    """Cloud OCR service providers."""

    BAIDU = "baidu"      # Baidu OCR
    TENCENT = "tencent"  # Tencent Cloud OCR
    ALIYUN = "aliyun"    # Aliyun OCR
    CUSTOM = "custom"    # user-supplied API


@dataclass
class OCRResult:
    """
    A single recognized text line.

    Attributes:
        text: recognized text
        confidence: confidence score (0-1)
        bbox: text box corners [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        line_index: zero-based line index
    """

    text: str
    confidence: float
    bbox: Optional[List[List[float]]] = None
    line_index: int = 0

    def __repr__(self) -> str:
        # Only show an ellipsis when the text really was truncated.
        preview = self.text if len(self.text) <= 30 else f"{self.text[:30]}..."
        return f"OCRResult(text='{preview}', confidence={self.confidence:.2f})"


@dataclass
class OCRBatchResult:
    """
    Result of recognizing one image.

    Attributes:
        results: all recognized lines
        full_text: complete text (all lines joined)
        total_confidence: average confidence
        success: whether recognition succeeded
        error_message: error description (when it failed)
    """

    results: List[OCRResult]
    full_text: str
    total_confidence: float
    success: bool = True
    error_message: Optional[str] = None

    def __repr__(self) -> str:
        return f"OCRBatchResult(lines={len(self.results)}, confidence={self.total_confidence:.2f})"


class ImagePreprocessor:
    """
    Image preprocessor.

    Offers common enhancement steps intended to improve OCR accuracy.
    """

    @staticmethod
    def load_image(image_path: str) -> Image.Image:
        """
        Load an image from disk as an RGB image.

        Args:
            image_path: path of the image file

        Returns:
            PIL Image in RGB mode, detached from the underlying file
        """
        # Image.open is lazy; the ``with`` block guarantees the file handle is
        # released, and convert()/copy() both return a fully loaded new image.
        with Image.open(image_path) as image:
            if image.mode != 'RGB':
                return image.convert('RGB')
            return image.copy()

    @staticmethod
    def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
        """
        Shrink the image to at most ``max_width`` pixels wide (keeps aspect ratio).

        Args:
            image: PIL Image
            max_width: maximum width in pixels

        Returns:
            The resized image, or the input unchanged when it is narrow enough
        """
        if image.width > max_width:
            ratio = max_width / image.width
            # Very wide, thin images could otherwise round down to height 0.
            new_height = max(1, int(image.height * ratio))
            image = image.resize((max_width, new_height), _LANCZOS)
        return image

    @staticmethod
    def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust contrast.

        Args:
            image: PIL Image
            factor: 1.0 keeps the original, >1.0 increases, <1.0 decreases

        Returns:
            The processed image
        """
        return ImageEnhance.Contrast(image).enhance(factor)

    @staticmethod
    def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust sharpness.

        Args:
            image: PIL Image
            factor: sharpening factor

        Returns:
            The processed image
        """
        return ImageEnhance.Sharpness(image).enhance(factor)

    @staticmethod
    def denoise(image: Image.Image) -> Image.Image:
        """
        Remove noise with a 3x3 median filter.

        Args:
            image: PIL Image

        Returns:
            The processed image
        """
        return image.filter(ImageFilter.MedianFilter(size=3))

    @staticmethod
    def preprocess(
        image: Image.Image,
        resize: bool = True,
        enhance_contrast: bool = True,
        enhance_sharpness: bool = True,
        denoise: bool = False
    ) -> Image.Image:
        """
        Run the selected preprocessing steps on a copy of the image.

        Args:
            image: PIL Image (left untouched)
            resize: whether to limit the width
            enhance_contrast: whether to boost contrast
            enhance_sharpness: whether to boost sharpness
            denoise: whether to apply the median filter

        Returns:
            The processed image
        """
        result = image.copy()

        if resize:
            result = ImagePreprocessor.resize_image(result)
        if enhance_contrast:
            result = ImagePreprocessor.enhance_contrast(result)
        if enhance_sharpness:
            result = ImagePreprocessor.enhance_sharpness(result)
        if denoise:
            result = ImagePreprocessor.denoise(result)

        return result

    @staticmethod
    def preprocess_from_path(image_path: str, **kwargs) -> Image.Image:
        """
        Load an image from disk and preprocess it.

        Args:
            image_path: path of the image file
            **kwargs: options forwarded to :meth:`preprocess`

        Returns:
            The processed image
        """
        image = ImagePreprocessor.load_image(str(image_path))
        return ImagePreprocessor.preprocess(image, **kwargs)

    @staticmethod
    def image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
        """
        Encode a PIL Image as a base64 string.

        Args:
            image: PIL Image
            format: image format (JPEG/PNG)

        Returns:
            base64-encoded image data
        """
        # Modes outside _JPEG_SAFE_MODES (e.g. RGBA or P) are converted to RGB
        # first, because saving them as JPEG would fail.
        if format.upper() == "JPEG" and image.mode not in _JPEG_SAFE_MODES:
            image = image.convert('RGB')

        buffer = io.BytesIO()
        image.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


class BaseOCREngine(ABC):
    """
    Base class for OCR engines.

    Every OCR implementation must subclass this and implement ``recognize``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the engine.

        Args:
            config: OCR configuration dictionary
        """
        self.config = config or {}
        self.preprocessor = ImagePreprocessor()

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognize the text in an image.

        Args:
            image: image (path or PIL Image)
            preprocess: whether to preprocess the image first
            **kwargs: additional options

        Returns:
            OCRBatchResult: recognition result
        """
        pass

    def _load_image(self, image: ImageInput) -> Image.Image:
        """
        Normalize the supported inputs to an RGB PIL Image.

        Args:
            image: image (path or PIL Image)

        Returns:
            PIL Image in RGB mode

        Raises:
            ValueError: the input type is not supported
        """
        if isinstance(image, (str, Path)):
            return self.preprocessor.load_image(str(image))
        if isinstance(image, Image.Image):
            # Keep PIL inputs consistent with path inputs, which are always RGB.
            return image if image.mode == 'RGB' else image.convert('RGB')
        raise ValueError(f"不支持的图像类型: {type(image)}")

    def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
        """
        Compute the average confidence.

        Args:
            results: OCR result list

        Returns:
            Average confidence (0-1); 0.0 for an empty list
        """
        if not results:
            return 0.0
        return sum(r.confidence for r in results) / len(results)

    @staticmethod
    def _failure_result(message: str) -> OCRBatchResult:
        """Build an empty, unsuccessful result carrying ``message``."""
        return OCRBatchResult(
            results=[],
            full_text="",
            total_confidence=0.0,
            success=False,
            error_message=message
        )

    def _success_result(self, results: List[OCRResult]) -> OCRBatchResult:
        """Build a successful result from recognized lines (joined by newlines)."""
        return OCRBatchResult(
            results=results,
            full_text='\n'.join(r.text for r in results),
            total_confidence=self._calculate_total_confidence(results),
            success=True
        )


class CloudOCREngine(BaseOCREngine):
    """
    Cloud OCR engine.

    Supported services:
    - Baidu OCR
    - Tencent Cloud OCR (placeholder, not implemented)
    - Aliyun OCR (placeholder, not implemented)
    - custom API
    """

    _BAIDU_OCR_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
    _BAIDU_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the cloud OCR engine.

        Args:
            config: configuration dictionary, supporting:
                - api_endpoint: API endpoint
                - api_key: API key
                - api_secret: API secret (needed by some providers)
                - provider: provider (baidu/tencent/aliyun/custom)
                - timeout: request timeout in seconds
        """
        super().__init__(config)

        self.api_endpoint = self.config.get('api_endpoint', '')
        self.api_key = self.config.get('api_key', '')
        self.api_secret = self.config.get('api_secret', '')
        self.provider = self.config.get('provider', 'custom')
        self.timeout = self.config.get('timeout', 30)

        if not self.api_endpoint:
            logger.warning("云端 OCR: api_endpoint 未配置,OCR 功能将不可用")

    def _resolve_provider(self) -> OCRProvider:
        """Map the configured provider to an OCRProvider; unknown values mean CUSTOM."""
        raw = self.provider
        if isinstance(raw, OCRProvider):
            return raw
        try:
            # Accept values such as "Baidu" or " BAIDU ".
            return OCRProvider(str(raw).strip().lower())
        except ValueError:
            return OCRProvider.CUSTOM

    def recognize(
        self,
        image: ImageInput,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognize the text in an image through the configured cloud API.

        Never raises: any failure is reported through ``success=False``.

        Args:
            image: image (path or PIL Image)
            preprocess: whether to preprocess the image first
            **kwargs: additional options (currently unused)

        Returns:
            OCRBatchResult: recognition result
        """
        try:
            pil_image = self._load_image(image)

            if preprocess:
                pil_image = self.preprocessor.preprocess(pil_image)

            img_base64 = self.preprocessor.image_to_base64(pil_image)

            # Dispatch on the provider; anything unrecognized uses the custom API.
            handlers = {
                OCRProvider.BAIDU: self._baidu_ocr,
                OCRProvider.TENCENT: self._tencent_ocr,
                OCRProvider.ALIYUN: self._aliyun_ocr,
            }
            handler = handlers.get(self._resolve_provider(), self._custom_api_ocr)
            return handler(img_base64)

        except Exception as e:
            # Top-level boundary: callers rely on a result object, not an exception.
            logger.error(f"云端 OCR 识别失败: {e}", exc_info=True)
            return self._failure_result(str(e))

    def _baidu_ocr(self, img_base64: str) -> OCRBatchResult:
        """
        Call the Baidu general OCR API.

        Args:
            img_base64: base64-encoded image

        Returns:
            OCRBatchResult: recognition result (``success=False`` on any error)
        """
        try:
            query: Dict[str, str] = {}

            # Fetch an access_token (simplified; a real deployment should cache it).
            if self.api_key and self.api_secret:
                token_resp = requests.get(
                    self._BAIDU_TOKEN_URL,
                    # ``params`` URL-encodes the credentials for us.
                    params={
                        'grant_type': 'client_credentials',
                        'client_id': self.api_key,
                        'client_secret': self.api_secret,
                    },
                    timeout=self.timeout
                )
                access_token = ''
                if token_resp.status_code == 200:
                    access_token = token_resp.json().get('access_token', '')
                if not access_token:
                    # Do not continue with an unauthenticated request.
                    return self._failure_result(
                        f"获取百度 access_token 失败: HTTP {token_resp.status_code}"
                    )
                query['access_token'] = access_token

            response = requests.post(
                self._BAIDU_OCR_URL,
                params=query,
                data={'image': img_base64},
                timeout=self.timeout
            )
            result = response.json()

            if 'words_result' not in result:
                return self._failure_result(result.get('error_msg', '未知错误'))

            # No per-line confidence is read from the response, so a fixed value is used.
            ocr_results = [
                OCRResult(text=item.get('words', ''), confidence=0.95, line_index=idx)
                for idx, item in enumerate(result['words_result'])
            ]
            logger.info(f"百度 OCR 识别完成: {len(ocr_results)} 行")
            return self._success_result(ocr_results)

        except Exception as e:
            logger.error(f"百度 OCR 调用失败: {e}")
            return self._failure_result(f"百度 OCR 调用失败: {str(e)}")

    def _tencent_ocr(self, img_base64: str) -> OCRBatchResult:
        """Tencent Cloud OCR API (placeholder: always reports "not implemented")."""
        logger.warning("腾讯云 OCR 尚未实现")
        return self._failure_result("腾讯云 OCR 尚未实现")

    def _aliyun_ocr(self, img_base64: str) -> OCRBatchResult:
        """Aliyun OCR API (placeholder: always reports "not implemented")."""
        logger.warning("阿里云 OCR 尚未实现")
        return self._failure_result("阿里云 OCR 尚未实现")

    def _custom_api_ocr(self, img_base64: str) -> OCRBatchResult:
        """
        Call a user-configured OCR endpoint.

        The image is POSTed as JSON ``{"image": <base64>, "format": "base64"}``.
        Two response shapes are understood: ``{"text": ...}`` and
        ``{"lines": [{"text": ..., "confidence": ...}, ...]}``.

        Args:
            img_base64: base64-encoded image

        Returns:
            OCRBatchResult: recognition result (``success=False`` on any error)
        """
        if not self.api_endpoint:
            return self._failure_result("未配置云端 OCR API endpoint")

        try:
            headers = {
                'Content-Type': 'application/json',
            }
            # Attach the API key when one is configured.
            if self.api_key:
                headers['Authorization'] = f'Bearer {self.api_key}'

            response = requests.post(
                self.api_endpoint,
                json={'image': img_base64, 'format': 'base64'},
                headers=headers,
                timeout=self.timeout
            )

            if response.status_code != 200:
                return self._failure_result(f"API 请求失败: HTTP {response.status_code}")

            result = response.json()
            if not isinstance(result, dict):
                return self._failure_result(f"未知的响应格式: {type(result).__name__}")

            if 'text' in result:
                # Plain-text shape: the whole text is reported as a single line.
                full_text = result['text']
                return OCRBatchResult(
                    results=[OCRResult(text=full_text, confidence=0.9)],
                    full_text=full_text,
                    total_confidence=0.9,
                    success=True
                )

            if 'lines' in result:
                # Multi-line shape; tolerate entries that are bare strings.
                ocr_results = []
                for idx, line in enumerate(result['lines']):
                    if isinstance(line, dict):
                        text = line.get('text', '')
                        conf = line.get('confidence', 0.9)
                    else:
                        text, conf = str(line), 0.9
                    ocr_results.append(OCRResult(text=text, confidence=conf, line_index=idx))
                return self._success_result(ocr_results)

            return self._failure_result(f"未知的响应格式: {list(result.keys())}")

        except Exception as e:
            logger.error(f"自定义 API OCR 调用失败: {e}")
            return self._failure_result(f"API 调用失败: {str(e)}")


class OCRFactory:
    """
    OCR engine factory.

    Creates the OCR engine instance matching the requested mode.
    """

    @staticmethod
    def create_engine(
        mode: str = "cloud",
        config: Optional[Dict[str, Any]] = None
    ) -> BaseOCREngine:
        """
        Create an OCR engine.

        Args:
            mode: OCR mode (only "cloud" is supported; any other value is
                treated as deprecated and still yields the cloud engine)
            config: configuration dictionary

        Returns:
            BaseOCREngine: OCR engine instance
        """
        if mode != "cloud":
            # Backward compatibility: legacy modes fall back to the cloud engine.
            logger.warning(f"OCR 模式 '{mode}' 已弃用,使用云端 OCR")
        return CloudOCREngine(config)


# Convenience functions
def recognize_text(
    image: ImageInput,
    mode: str = "cloud",
    preprocess: bool = True,
    **kwargs
) -> OCRBatchResult:
    """
    Recognize text with a freshly created engine.

    Args:
        image: image (path or PIL Image)
        mode: OCR mode (only "cloud" is supported)
        preprocess: whether to preprocess the image first
        **kwargs: used as the engine configuration (api_endpoint, api_key, ...)

    Returns:
        OCRBatchResult: recognition result
    """
    engine = OCRFactory.create_engine(mode, dict(kwargs))
    return engine.recognize(image, preprocess=preprocess)


def preprocess_image(
    image_path: str,
    output_path: Optional[str] = None,
    **kwargs
) -> Image.Image:
    """
    Load and preprocess an image, optionally saving the result.

    Args:
        image_path: input image path
        output_path: output image path (the result is saved when given)
        **kwargs: options forwarded to ``ImagePreprocessor.preprocess``

    Returns:
        PIL Image: the processed image
    """
    processed = ImagePreprocessor.preprocess_from_path(image_path, **kwargs)

    if output_path:
        processed.save(output_path)
        logger.info(f"预处理图像已保存到: {output_path}")

    return processed