Files
cutThenThink/src/core/ocr.py

614 lines
16 KiB
Python
Raw Normal View History

"""
OCR 模块
提供文字识别功能支持
- 本地 PaddleOCR 识别
- 云端 OCR API 扩展
- 图片预处理增强
- 多语言支持//混合
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
try:
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
except ImportError:
raise ImportError(
"请安装图像处理库: pip install pillow numpy"
)
try:
from paddleocr import PaddleOCR
except ImportError:
PaddleOCR = None
logging.warning("PaddleOCR 未安装,本地 OCR 功能不可用")
# 配置日志
logger = logging.getLogger(__name__)
class OCRLanguage(str, Enum):
"""OCR 支持的语言"""
CHINESE = "ch" # 中文
ENGLISH = "en" # 英文
MIXED = "chinese_chinese" # 中英文混合
@dataclass
class OCRResult:
"""
OCR 识别结果
Attributes:
text: 识别的文本内容
confidence: 置信度 (0-1)
bbox: 文本框坐标 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
line_index: 行索引从0开始
"""
text: str
confidence: float
bbox: Optional[List[List[float]]] = None
line_index: int = 0
def __repr__(self) -> str:
return f"OCRResult(text='{self.text[:30]}...', confidence={self.confidence:.2f})"
@dataclass
class OCRBatchResult:
"""
OCR 批量识别结果
Attributes:
results: 所有的识别结果列表
full_text: 完整文本所有行拼接
total_confidence: 平均置信度
success: 是否识别成功
error_message: 错误信息如果失败
"""
results: List[OCRResult]
full_text: str
total_confidence: float
success: bool = True
error_message: Optional[str] = None
def __repr__(self) -> str:
return f"OCRBatchResult(lines={len(self.results)}, confidence={self.total_confidence:.2f})"
class ImagePreprocessor:
"""
图像预处理器
提供常见的图像增强和预处理功能提高 OCR 识别准确率
"""
@staticmethod
def load_image(image_path: str) -> Image.Image:
"""
加载图像
Args:
image_path: 图像文件路径
Returns:
PIL Image 对象
"""
image = Image.open(image_path)
# 转换为 RGB 模式
if image.mode != 'RGB':
image = image.convert('RGB')
return image
@staticmethod
def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
"""
调整图像大小保持宽高比
Args:
image: PIL Image 对象
max_width: 最大宽度
Returns:
调整后的图像
"""
if image.width > max_width:
ratio = max_width / image.width
new_height = int(image.height * ratio)
image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
return image
@staticmethod
def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
"""
增强对比度
Args:
image: PIL Image 对象
factor: 增强因子1.0 表示原始>1.0 增强<1.0 减弱
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Contrast(image)
return enhancer.enhance(factor)
@staticmethod
def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
"""
增强锐度
Args:
image: PIL Image 对象
factor: 锐化因子
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Sharpness(image)
return enhancer.enhance(factor)
@staticmethod
def enhance_brightness(image: Image.Image, factor: float = 1.1) -> Image.Image:
"""
调整亮度
Args:
image: PIL Image 对象
factor: 亮度因子
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Brightness(image)
return enhancer.enhance(factor)
@staticmethod
def denoise(image: Image.Image) -> Image.Image:
"""
去噪使用中值滤波
Args:
image: PIL Image 对象
Returns:
处理后的图像
"""
return image.filter(ImageFilter.MedianFilter(size=3))
@staticmethod
def binarize(image: Image.Image, threshold: int = 127) -> Image.Image:
"""
二值化转换为黑白图像
Args:
image: PIL Image 对象
threshold: 二值化阈值
Returns:
处理后的图像
"""
# 先转为灰度图
gray = image.convert('L')
# 二值化
binary = gray.point(lambda x: 0 if x < threshold else 255, '1')
# 转回 RGB
return binary.convert('RGB')
@staticmethod
def preprocess(
image: Image.Image,
resize: bool = True,
enhance_contrast: bool = True,
enhance_sharpness: bool = True,
denoise: bool = False,
binarize: bool = False
) -> Image.Image:
"""
综合预处理根据指定选项
Args:
image: PIL Image 对象
resize: 是否调整大小
enhance_contrast: 是否增强对比度
enhance_sharpness: 是否增强锐度
denoise: 是否去噪
binarize: 是否二值化
Returns:
处理后的图像
"""
result = image.copy()
if resize:
result = ImagePreprocessor.resize_image(result)
if enhance_contrast:
result = ImagePreprocessor.enhance_contrast(result)
if enhance_sharpness:
result = ImagePreprocessor.enhance_sharpness(result)
if denoise:
result = ImagePreprocessor.denoise(result)
if binarize:
result = ImagePreprocessor.binarize(result)
return result
@staticmethod
def preprocess_from_path(
image_path: str,
**kwargs
) -> Image.Image:
"""
从文件路径加载并预处理图像
Args:
image_path: 图像文件路径
**kwargs: preprocess 方法的参数
Returns:
处理后的图像
"""
image = ImagePreprocessor.load_image(image_path)
return ImagePreprocessor.preprocess(image, **kwargs)
class BaseOCREngine(ABC):
"""
OCR 引擎基类
所有 OCR 实现必须继承此类并实现 recognize 方法
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化 OCR 引擎
Args:
config: OCR 配置字典
"""
self.config = config or {}
self.preprocessor = ImagePreprocessor()
@abstractmethod
def recognize(
self,
image,
preprocess: bool = True,
**kwargs
) -> OCRBatchResult:
"""
识别图像中的文本
Args:
image: 图像可以是路径PIL Image numpy 数组
preprocess: 是否预处理图像
**kwargs: 其他参数
Returns:
OCRBatchResult: 识别结果
"""
pass
def _load_image(self, image) -> Image.Image:
"""
加载图像支持多种输入格式
Args:
image: 图像路径PIL Image numpy 数组
Returns:
PIL Image 对象
"""
if isinstance(image, str) or isinstance(image, Path):
return self.preprocessor.load_image(str(image))
elif isinstance(image, Image.Image):
return image
elif isinstance(image, np.ndarray):
return Image.fromarray(image)
else:
raise ValueError(f"不支持的图像类型: {type(image)}")
def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
"""
计算平均置信度
Args:
results: OCR 结果列表
Returns:
平均置信度 (0-1)
"""
if not results:
return 0.0
return sum(r.confidence for r in results) / len(results)
class PaddleOCREngine(BaseOCREngine):
"""
PaddleOCR 本地识别引擎
使用 PaddleOCR 进行本地文字识别
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化 PaddleOCR 引擎
Args:
config: 配置字典支持:
- use_gpu: 是否使用 GPU (默认 False)
- lang: 语言 (默认 "ch"支持 ch/en/chinese_chinese)
- show_log: 是否显示日志 (默认 False)
"""
super().__init__(config)
if PaddleOCR is None:
raise ImportError(
"PaddleOCR 未安装。请运行: pip install paddleocr paddlepaddle"
)
# 解析配置
self.use_gpu = self.config.get('use_gpu', False)
self.lang = self.config.get('lang', 'ch')
self.show_log = self.config.get('show_log', False)
# 初始化 PaddleOCR
logger.info(f"初始化 PaddleOCR (lang={self.lang}, gpu={self.use_gpu})")
self.ocr = PaddleOCR(
use_angle_cls=True, # 使用方向分类器
lang=self.lang,
use_gpu=self.use_gpu,
show_log=self.show_log
)
def recognize(
self,
image,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
使用 PaddleOCR 识别图像中的文本
Args:
image: 图像路径PIL Image numpy 数组
preprocess: 是否预处理图像
**kwargs: 其他参数未使用
Returns:
OCRBatchResult: 识别结果
"""
try:
# 加载图像
pil_image = self._load_image(image)
# 预处理(如果启用)
if preprocess:
pil_image = self.preprocessor.preprocess(pil_image)
# 转换为 numpy 数组PaddleOCR 需要)
img_array = np.array(pil_image)
# 执行 OCR
result = self.ocr.ocr(img_array, cls=True)
# 解析结果
ocr_results = []
full_lines = []
if result and result[0]:
for line_index, line in enumerate(result[0]):
if line:
# PaddleOCR 返回格式: [[bbox], (text, confidence)]
bbox = line[0]
text_info = line[1]
text = text_info[0]
confidence = float(text_info[1])
ocr_result = OCRResult(
text=text,
confidence=confidence,
bbox=bbox,
line_index=line_index
)
ocr_results.append(ocr_result)
full_lines.append(text)
# 计算平均置信度
total_confidence = self._calculate_total_confidence(ocr_results)
# 拼接完整文本
full_text = '\n'.join(full_lines)
logger.info(f"OCR 识别完成: {len(ocr_results)} 行, 平均置信度 {total_confidence:.2f}")
return OCRBatchResult(
results=ocr_results,
full_text=full_text,
total_confidence=total_confidence,
success=True
)
except Exception as e:
logger.error(f"OCR 识别失败: {e}", exc_info=True)
return OCRBatchResult(
results=[],
full_text="",
total_confidence=0.0,
success=False,
error_message=str(e)
)
class CloudOCREngine(BaseOCREngine):
"""
云端 OCR 引擎适配器
预留接口用于扩展云端 OCR 服务
支持百度 OCR腾讯 OCR阿里云 OCR
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化云端 OCR 引擎
Args:
config: 配置字典支持:
- api_endpoint: API 端点
- api_key: API 密钥
- provider: 提供商 (baidu/tencent/aliyun/custom)
- timeout: 超时时间
"""
super().__init__(config)
self.api_endpoint = self.config.get('api_endpoint', '')
self.api_key = self.config.get('api_key', '')
self.provider = self.config.get('provider', 'custom')
self.timeout = self.config.get('timeout', 30)
if not self.api_endpoint:
logger.warning("云端 OCR: api_endpoint 未配置")
def recognize(
self,
image,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
使用云端 API 识别图像中的文本
Args:
image: 图像路径PIL Image
preprocess: 是否预处理图像
**kwargs: 其他参数
Returns:
OCRBatchResult: 识别结果
"""
# 这是一个占位实现
# 实际使用时需要根据具体的云端 OCR API 实现
logger.warning("云端 OCR 尚未实现,请使用本地 PaddleOCR 或自行实现")
return OCRBatchResult(
results=[],
full_text="",
total_confidence=0.0,
success=False,
error_message="云端 OCR 尚未实现"
)
def _send_request(self, image_data: bytes) -> Dict[str, Any]:
"""
发送 API 请求待实现
Args:
image_data: 图像二进制数据
Returns:
API 响应
"""
raise NotImplementedError("请根据具体云服务 API 实现此方法")
class OCRFactory:
"""
OCR 引擎工厂
根据配置创建对应的 OCR 引擎实例
"""
@staticmethod
def create_engine(
mode: str = "local",
config: Optional[Dict[str, Any]] = None
) -> BaseOCREngine:
"""
创建 OCR 引擎
Args:
mode: OCR 模式 ("local" "cloud")
config: 配置字典
Returns:
BaseOCREngine: OCR 引擎实例
Raises:
ValueError: 不支持的 OCR 模式
"""
if mode == "local":
return PaddleOCREngine(config)
elif mode == "cloud":
return CloudOCREngine(config)
else:
raise ValueError(f"不支持的 OCR 模式: {mode}")
# 便捷函数
def recognize_text(
image,
mode: str = "local",
lang: str = "ch",
use_gpu: bool = False,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
快捷识别文本
Args:
image: 图像路径PIL Image
mode: OCR 模式 ("local" "cloud")
lang: 语言 (ch/en/chinese_chinese)
use_gpu: 是否使用 GPU仅本地模式
preprocess: 是否预处理图像
**kwargs: 其他配置
Returns:
OCRBatchResult: 识别结果
"""
config = {
'lang': lang,
'use_gpu': use_gpu,
**kwargs
}
engine = OCRFactory.create_engine(mode, config)
return engine.recognize(image, preprocess=preprocess)
def preprocess_image(
image_path: str,
output_path: Optional[str] = None,
**kwargs
) -> Image.Image:
"""
快捷预处理图像
Args:
image_path: 输入图像路径
output_path: 输出图像路径如果指定则保存
**kwargs: 预处理参数
Returns:
PIL Image: 处理后的图像
"""
processed = ImagePreprocessor.preprocess_from_path(image_path, **kwargs)
if output_path:
processed.save(output_path)
logger.info(f"预处理图像已保存到: {output_path}")
return processed