""" OCR 模块 提供文字识别功能,支持: - 本地 PaddleOCR 识别 - 云端 OCR API 扩展 - 图片预处理增强 - 多语言支持(中/英/混合) """ from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Dict, Any, Tuple from dataclasses import dataclass from enum import Enum import logging try: from PIL import Image, ImageEnhance, ImageFilter import numpy as np except ImportError: raise ImportError( "请安装图像处理库: pip install pillow numpy" ) try: from paddleocr import PaddleOCR except ImportError: PaddleOCR = None logging.warning("PaddleOCR 未安装,本地 OCR 功能不可用") # 配置日志 logger = logging.getLogger(__name__) class OCRLanguage(str, Enum): """OCR 支持的语言""" CHINESE = "ch" # 中文 ENGLISH = "en" # 英文 MIXED = "chinese_chinese" # 中英文混合 @dataclass class OCRResult: """ OCR 识别结果 Attributes: text: 识别的文本内容 confidence: 置信度 (0-1) bbox: 文本框坐标 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] line_index: 行索引(从0开始) """ text: str confidence: float bbox: Optional[List[List[float]]] = None line_index: int = 0 def __repr__(self) -> str: return f"OCRResult(text='{self.text[:30]}...', confidence={self.confidence:.2f})" @dataclass class OCRBatchResult: """ OCR 批量识别结果 Attributes: results: 所有的识别结果列表 full_text: 完整文本(所有行拼接) total_confidence: 平均置信度 success: 是否识别成功 error_message: 错误信息(如果失败) """ results: List[OCRResult] full_text: str total_confidence: float success: bool = True error_message: Optional[str] = None def __repr__(self) -> str: return f"OCRBatchResult(lines={len(self.results)}, confidence={self.total_confidence:.2f})" class ImagePreprocessor: """ 图像预处理器 提供常见的图像增强和预处理功能,提高 OCR 识别准确率 """ @staticmethod def load_image(image_path: str) -> Image.Image: """ 加载图像 Args: image_path: 图像文件路径 Returns: PIL Image 对象 """ image = Image.open(image_path) # 转换为 RGB 模式 if image.mode != 'RGB': image = image.convert('RGB') return image @staticmethod def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image: """ 调整图像大小(保持宽高比) Args: image: PIL Image 对象 max_width: 最大宽度 Returns: 调整后的图像 """ if image.width > max_width: ratio = max_width / image.width new_height = int(image.height * ratio) image = image.resize((max_width, new_height), Image.Resampling.LANCZOS) return image @staticmethod def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image: """ 增强对比度 Args: image: PIL Image 对象 factor: 增强因子,1.0 表示原始,>1.0 增强,<1.0 减弱 Returns: 处理后的图像 """ enhancer = ImageEnhance.Contrast(image) return enhancer.enhance(factor) @staticmethod def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image: """ 增强锐度 Args: image: PIL Image 对象 factor: 锐化因子 Returns: 处理后的图像 """ enhancer = ImageEnhance.Sharpness(image) return enhancer.enhance(factor) @staticmethod def enhance_brightness(image: Image.Image, factor: float = 1.1) -> Image.Image: """ 调整亮度 Args: image: PIL Image 对象 factor: 亮度因子 Returns: 处理后的图像 """ enhancer = ImageEnhance.Brightness(image) return enhancer.enhance(factor) @staticmethod def denoise(image: Image.Image) -> Image.Image: """ 去噪(使用中值滤波) Args: image: PIL Image 对象 Returns: 处理后的图像 """ return image.filter(ImageFilter.MedianFilter(size=3)) @staticmethod def binarize(image: Image.Image, threshold: int = 127) -> Image.Image: """ 二值化(转换为黑白图像) Args: image: PIL Image 对象 threshold: 二值化阈值 Returns: 处理后的图像 """ # 先转为灰度图 gray = image.convert('L') # 二值化 binary = gray.point(lambda x: 0 if x < threshold else 255, '1') # 转回 RGB return binary.convert('RGB') @staticmethod def preprocess( image: Image.Image, resize: bool = True, enhance_contrast: bool = True, enhance_sharpness: bool = True, denoise: bool = False, binarize: bool = False ) -> Image.Image: """ 综合预处理(根据指定选项) Args: image: PIL Image 对象 resize: 是否调整大小 enhance_contrast: 是否增强对比度 enhance_sharpness: 是否增强锐度 denoise: 是否去噪 binarize: 是否二值化 Returns: 处理后的图像 """ result = image.copy() if resize: result = ImagePreprocessor.resize_image(result) if enhance_contrast: result = ImagePreprocessor.enhance_contrast(result) if enhance_sharpness: result = ImagePreprocessor.enhance_sharpness(result) if denoise: result = ImagePreprocessor.denoise(result) if binarize: result = ImagePreprocessor.binarize(result) return result @staticmethod def preprocess_from_path( image_path: str, **kwargs ) -> Image.Image: """ 从文件路径加载并预处理图像 Args: image_path: 图像文件路径 **kwargs: preprocess 方法的参数 Returns: 处理后的图像 """ image = ImagePreprocessor.load_image(image_path) return ImagePreprocessor.preprocess(image, **kwargs) class BaseOCREngine(ABC): """ OCR 引擎基类 所有 OCR 实现必须继承此类并实现 recognize 方法 """ def __init__(self, config: Optional[Dict[str, Any]] = None): """ 初始化 OCR 引擎 Args: config: OCR 配置字典 """ self.config = config or {} self.preprocessor = ImagePreprocessor() @abstractmethod def recognize( self, image, preprocess: bool = True, **kwargs ) -> OCRBatchResult: """ 识别图像中的文本 Args: image: 图像(可以是路径、PIL Image 或 numpy 数组) preprocess: 是否预处理图像 **kwargs: 其他参数 Returns: OCRBatchResult: 识别结果 """ pass def _load_image(self, image) -> Image.Image: """ 加载图像(支持多种输入格式) Args: image: 图像(路径、PIL Image 或 numpy 数组) Returns: PIL Image 对象 """ if isinstance(image, str) or isinstance(image, Path): return self.preprocessor.load_image(str(image)) elif isinstance(image, Image.Image): return image elif isinstance(image, np.ndarray): return Image.fromarray(image) else: raise ValueError(f"不支持的图像类型: {type(image)}") def _calculate_total_confidence(self, results: List[OCRResult]) -> float: """ 计算平均置信度 Args: results: OCR 结果列表 Returns: 平均置信度 (0-1) """ if not results: return 0.0 return sum(r.confidence for r in results) / len(results) class PaddleOCREngine(BaseOCREngine): """ PaddleOCR 本地识别引擎 使用 PaddleOCR 进行本地文字识别 """ def __init__(self, config: Optional[Dict[str, Any]] = None): """ 初始化 PaddleOCR 引擎 Args: config: 配置字典,支持: - use_gpu: 是否使用 GPU (默认 False) - lang: 语言 (默认 "ch",支持 ch/en/chinese_chinese) - show_log: 是否显示日志 (默认 False) """ super().__init__(config) if PaddleOCR is None: raise ImportError( "PaddleOCR 未安装。请运行: pip install paddleocr paddlepaddle" ) # 解析配置 self.use_gpu = self.config.get('use_gpu', False) self.lang = self.config.get('lang', 'ch') self.show_log = self.config.get('show_log', False) # 初始化 PaddleOCR logger.info(f"初始化 PaddleOCR (lang={self.lang}, gpu={self.use_gpu})") self.ocr = PaddleOCR( use_angle_cls=True, # 使用方向分类器 lang=self.lang, use_gpu=self.use_gpu, show_log=self.show_log ) def recognize( self, image, preprocess: bool = False, **kwargs ) -> OCRBatchResult: """ 使用 PaddleOCR 识别图像中的文本 Args: image: 图像(路径、PIL Image 或 numpy 数组) preprocess: 是否预处理图像 **kwargs: 其他参数(未使用) Returns: OCRBatchResult: 识别结果 """ try: # 加载图像 pil_image = self._load_image(image) # 预处理(如果启用) if preprocess: pil_image = self.preprocessor.preprocess(pil_image) # 转换为 numpy 数组(PaddleOCR 需要) img_array = np.array(pil_image) # 执行 OCR result = self.ocr.ocr(img_array, cls=True) # 解析结果 ocr_results = [] full_lines = [] if result and result[0]: for line_index, line in enumerate(result[0]): if line: # PaddleOCR 返回格式: [[bbox], (text, confidence)] bbox = line[0] text_info = line[1] text = text_info[0] confidence = float(text_info[1]) ocr_result = OCRResult( text=text, confidence=confidence, bbox=bbox, line_index=line_index ) ocr_results.append(ocr_result) full_lines.append(text) # 计算平均置信度 total_confidence = self._calculate_total_confidence(ocr_results) # 拼接完整文本 full_text = '\n'.join(full_lines) logger.info(f"OCR 识别完成: {len(ocr_results)} 行, 平均置信度 {total_confidence:.2f}") return OCRBatchResult( results=ocr_results, full_text=full_text, total_confidence=total_confidence, success=True ) except Exception as e: logger.error(f"OCR 识别失败: {e}", exc_info=True) return OCRBatchResult( results=[], full_text="", total_confidence=0.0, success=False, error_message=str(e) ) class CloudOCREngine(BaseOCREngine): """ 云端 OCR 引擎(适配器) 预留接口,用于扩展云端 OCR 服务 支持:百度 OCR、腾讯 OCR、阿里云 OCR 等 """ def __init__(self, config: Optional[Dict[str, Any]] = None): """ 初始化云端 OCR 引擎 Args: config: 配置字典,支持: - api_endpoint: API 端点 - api_key: API 密钥 - provider: 提供商 (baidu/tencent/aliyun/custom) - timeout: 超时时间(秒) """ super().__init__(config) self.api_endpoint = self.config.get('api_endpoint', '') self.api_key = self.config.get('api_key', '') self.provider = self.config.get('provider', 'custom') self.timeout = self.config.get('timeout', 30) if not self.api_endpoint: logger.warning("云端 OCR: api_endpoint 未配置") def recognize( self, image, preprocess: bool = False, **kwargs ) -> OCRBatchResult: """ 使用云端 API 识别图像中的文本 Args: image: 图像(路径、PIL Image) preprocess: 是否预处理图像 **kwargs: 其他参数 Returns: OCRBatchResult: 识别结果 """ # 这是一个占位实现 # 实际使用时需要根据具体的云端 OCR API 实现 logger.warning("云端 OCR 尚未实现,请使用本地 PaddleOCR 或自行实现") return OCRBatchResult( results=[], full_text="", total_confidence=0.0, success=False, error_message="云端 OCR 尚未实现" ) def _send_request(self, image_data: bytes) -> Dict[str, Any]: """ 发送 API 请求(待实现) Args: image_data: 图像二进制数据 Returns: API 响应 """ raise NotImplementedError("请根据具体云服务 API 实现此方法") class OCRFactory: """ OCR 引擎工厂 根据配置创建对应的 OCR 引擎实例 """ @staticmethod def create_engine( mode: str = "local", config: Optional[Dict[str, Any]] = None ) -> BaseOCREngine: """ 创建 OCR 引擎 Args: mode: OCR 模式 ("local" 或 "cloud") config: 配置字典 Returns: BaseOCREngine: OCR 引擎实例 Raises: ValueError: 不支持的 OCR 模式 """ if mode == "local": return PaddleOCREngine(config) elif mode == "cloud": return CloudOCREngine(config) else: raise ValueError(f"不支持的 OCR 模式: {mode}") # 便捷函数 def recognize_text( image, mode: str = "local", lang: str = "ch", use_gpu: bool = False, preprocess: bool = False, **kwargs ) -> OCRBatchResult: """ 快捷识别文本 Args: image: 图像(路径、PIL Image) mode: OCR 模式 ("local" 或 "cloud") lang: 语言 (ch/en/chinese_chinese) use_gpu: 是否使用 GPU(仅本地模式) preprocess: 是否预处理图像 **kwargs: 其他配置 Returns: OCRBatchResult: 识别结果 """ config = { 'lang': lang, 'use_gpu': use_gpu, **kwargs } engine = OCRFactory.create_engine(mode, config) return engine.recognize(image, preprocess=preprocess) def preprocess_image( image_path: str, output_path: Optional[str] = None, **kwargs ) -> Image.Image: """ 快捷预处理图像 Args: image_path: 输入图像路径 output_path: 输出图像路径(如果指定,则保存) **kwargs: 预处理参数 Returns: PIL Image: 处理后的图像 """ processed = ImagePreprocessor.preprocess_from_path(image_path, **kwargs) if output_path: processed.save(output_path) logger.info(f"预处理图像已保存到: {output_path}") return processed