- 添加Windows打包脚本 build.bat - 更新打包文档 BUILD.md(轻量版方案) - OCR模块:添加首次运行时自动安装PaddleOCR的功能 - 主窗口:添加OCR安装检测和提示逻辑 - 创建应用入口 src/main.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
638 lines
17 KiB
Python
638 lines
17 KiB
Python
"""
|
||
OCR 模块
|
||
|
||
提供文字识别功能,支持:
|
||
- 本地 PaddleOCR 识别
|
||
- 云端 OCR API 扩展
|
||
- 图片预处理增强
|
||
- 多语言支持(中/英/混合)
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from pathlib import Path
|
||
from typing import List, Optional, Dict, Any, Tuple
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
import logging
|
||
|
||
try:
|
||
from PIL import Image, ImageEnhance, ImageFilter
|
||
import numpy as np
|
||
except ImportError:
|
||
raise ImportError(
|
||
"请安装图像处理库: pip install pillow numpy"
|
||
)
|
||
|
||
try:
|
||
from paddleocr import PaddleOCR
|
||
except ImportError:
|
||
PaddleOCR = None
|
||
logging.warning("PaddleOCR 未安装,本地 OCR 功能不可用。首次运行时将自动安装。")
|
||
|
||
def ensure_paddleocr():
    """
    Make sure PaddleOCR is importable, installing it with pip on first use.

    Returns immediately when the module-level ``PaddleOCR`` name is already
    bound. Otherwise runs ``pip install paddleocr`` with the current
    interpreter, imports the freshly installed package and rebinds the
    module-level ``PaddleOCR`` name to its ``PaddleOCR`` class.

    Raises:
        ImportError: When running from a frozen executable (there is no
            ``python -m pip`` to call), or when the package still cannot be
            imported after pip reported success.
        subprocess.CalledProcessError: When pip exits with a non-zero status.
    """
    global PaddleOCR
    if PaddleOCR is not None:
        return

    import importlib
    import subprocess
    import sys

    # In a frozen bundle sys.executable is the application itself, so
    # "<exe> -m pip ..." would start another copy of the app instead of pip.
    if getattr(sys, "frozen", False):
        message = "打包环境下无法自动安装 PaddleOCR,请手动安装: pip install paddleocr"
        logging.error(message)
        raise ImportError(message)

    logging.info("正在安装 PaddleOCR...")
    try:
        # NOTE(review): --break-system-packages is only understood by recent
        # pip releases; confirm the minimum pip version shipped with the app.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "--break-system-packages",
            "paddleocr",
        ])
    except subprocess.CalledProcessError as e:
        logging.error(f"PaddleOCR 安装失败: {e}")
        raise

    # The package was written to disk after this interpreter started, so the
    # import system's directory caches must be refreshed before importing.
    importlib.invalidate_caches()
    try:
        installed_module = importlib.import_module("paddleocr")
        installed_class = installed_module.PaddleOCR
    except (ImportError, AttributeError) as e:
        logging.error(f"PaddleOCR 安装后仍无法导入: {e}")
        raise ImportError(f"PaddleOCR 安装后仍无法导入: {e}") from e

    PaddleOCR = installed_class
    logging.info("PaddleOCR 安装成功!")
|
||
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class OCRLanguage(str, Enum):
    """Language codes understood by the OCR helpers in this module."""

    # Chinese
    CHINESE = "ch"

    # English
    ENGLISH = "en"

    # Mixed Chinese and English text.
    # NOTE(review): "chinese_chinese" does not look like a PaddleOCR language
    # code — confirm how engines are expected to translate it.
    MIXED = "chinese_chinese"
|
||
|
||
|
||
@dataclass
class OCRResult:
    """
    A single recognised text line.

    Attributes:
        text: Recognised text content.
        confidence: Recognition confidence (0-1).
        bbox: Text box corners [[x1, y1], [x2, y2], [x3, y3], [x4, y4]].
        line_index: Zero-based index of the line.
    """
    text: str
    confidence: float
    bbox: Optional[List[List[float]]] = None
    line_index: int = 0

    def __repr__(self) -> str:
        """Return a short debug form showing at most 30 characters of text."""
        preview_limit = 30
        preview = self.text[:preview_limit]
        # Only mark the preview as truncated when text was actually cut off.
        if len(self.text) > preview_limit:
            preview += "..."
        return f"OCRResult(text='{preview}', confidence={self.confidence:.2f})"
|
||
|
||
|
||
@dataclass
class OCRBatchResult:
    """
    Outcome of running OCR over one image.

    Attributes:
        results: Every recognised line, in order.
        full_text: All lines joined into one string.
        total_confidence: Average confidence over the lines.
        success: Whether recognition succeeded.
        error_message: Failure description, set when recognition failed.
    """
    results: List[OCRResult]
    full_text: str
    total_confidence: float
    success: bool = True
    error_message: Optional[str] = None

    def __repr__(self) -> str:
        """Return a compact summary: number of lines and average confidence."""
        line_count = len(self.results)
        return "OCRBatchResult(lines={}, confidence={:.2f})".format(
            line_count, self.total_confidence
        )
|
||
|
||
|
||
class ImagePreprocessor:
    """
    Image preprocessing helpers.

    A collection of stateless, Pillow-based operations meant to improve OCR
    accuracy. All methods are static.
    """

    @staticmethod
    def load_image(image_path: str) -> Image.Image:
        """
        Load an image file and return it in RGB mode.

        Args:
            image_path: Path of the image file.

        Returns:
            An RGB PIL image whose pixel data is already decoded.
        """
        # Image.open is lazy and keeps the file open until the data is read.
        # Converting inside the context manager decodes the pixels and lets
        # the file handle be released before returning.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    @staticmethod
    def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
        """
        Shrink an image to at most ``max_width`` pixels wide, keeping the
        aspect ratio. Images that are already narrow enough are returned
        unchanged.

        Args:
            image: PIL image.
            max_width: Maximum allowed width in pixels.

        Returns:
            The resized image, or the input image when no resize is needed.
        """
        if image.width > max_width:
            ratio = max_width / image.width
            # Clamp to 1 so extremely wide images never get a zero height.
            new_height = max(1, int(image.height * ratio))
            image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
        return image

    @staticmethod
    def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust contrast.

        Args:
            image: PIL image.
            factor: 1.0 keeps the original, >1.0 increases, <1.0 decreases.

        Returns:
            The adjusted image.
        """
        enhancer = ImageEnhance.Contrast(image)
        return enhancer.enhance(factor)

    @staticmethod
    def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust sharpness.

        Args:
            image: PIL image.
            factor: Sharpening factor passed to ``ImageEnhance.Sharpness``.

        Returns:
            The adjusted image.
        """
        enhancer = ImageEnhance.Sharpness(image)
        return enhancer.enhance(factor)

    @staticmethod
    def enhance_brightness(image: Image.Image, factor: float = 1.1) -> Image.Image:
        """
        Adjust brightness.

        Args:
            image: PIL image.
            factor: Brightness factor passed to ``ImageEnhance.Brightness``.

        Returns:
            The adjusted image.
        """
        enhancer = ImageEnhance.Brightness(image)
        return enhancer.enhance(factor)

    @staticmethod
    def denoise(image: Image.Image) -> Image.Image:
        """
        Reduce noise with a 3x3 median filter.

        Args:
            image: PIL image.

        Returns:
            The filtered image.
        """
        return image.filter(ImageFilter.MedianFilter(size=3))

    @staticmethod
    def binarize(image: Image.Image, threshold: int = 127) -> Image.Image:
        """
        Convert an image to pure black and white.

        Args:
            image: PIL image.
            threshold: Grey levels below this value become black, all others
                become white.

        Returns:
            The black-and-white image, converted back to RGB mode.
        """
        # Greyscale first so a single threshold applies to every pixel.
        gray = image.convert('L')
        binary = gray.point(lambda x: 0 if x < threshold else 255, '1')
        # Back to RGB so later steps see the same mode as other outputs.
        return binary.convert('RGB')

    @staticmethod
    def preprocess(
        image: Image.Image,
        resize: bool = True,
        enhance_contrast: bool = True,
        enhance_sharpness: bool = True,
        denoise: bool = False,
        binarize: bool = False
    ) -> Image.Image:
        """
        Run the selected preprocessing steps on a copy of the image.

        Steps are applied in a fixed order: resize, contrast, sharpness,
        denoise, binarize.

        Args:
            image: PIL image; it is copied first and not modified.
            resize: Whether to limit the image width.
            enhance_contrast: Whether to increase contrast.
            enhance_sharpness: Whether to increase sharpness.
            denoise: Whether to apply the median filter.
            binarize: Whether to convert to black and white.

        Returns:
            The processed image.
        """
        result = image.copy()

        # The boolean parameters shadow the method names, so every step is
        # called through the class explicitly.
        if resize:
            result = ImagePreprocessor.resize_image(result)

        if enhance_contrast:
            result = ImagePreprocessor.enhance_contrast(result)

        if enhance_sharpness:
            result = ImagePreprocessor.enhance_sharpness(result)

        if denoise:
            result = ImagePreprocessor.denoise(result)

        if binarize:
            result = ImagePreprocessor.binarize(result)

        return result

    @staticmethod
    def preprocess_from_path(
        image_path: str,
        **kwargs
    ) -> Image.Image:
        """
        Load an image from disk and preprocess it.

        Args:
            image_path: Path of the image file.
            **kwargs: Options forwarded to ``preprocess``.

        Returns:
            The processed image.
        """
        image = ImagePreprocessor.load_image(image_path)
        return ImagePreprocessor.preprocess(image, **kwargs)
|
||
|
||
|
||
class BaseOCREngine(ABC):
    """
    Abstract base class for OCR engines.

    Concrete engines inherit from this class and implement ``recognize``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Store the configuration and create the shared image preprocessor.

        Args:
            config: Engine configuration; a missing or empty value becomes
                an empty dict.
        """
        self.config = config or {}
        self.preprocessor = ImagePreprocessor()

    @abstractmethod
    def recognize(
        self,
        image,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognise the text contained in an image.

        Args:
            image: A file path, a PIL image or a numpy array.
            preprocess: Whether to preprocess the image first.
            **kwargs: Engine-specific options.

        Returns:
            OCRBatchResult: The recognition outcome.
        """

    def _load_image(self, image) -> Image.Image:
        """
        Turn any supported input into a PIL image.

        Args:
            image: A file path (``str`` or ``Path``), a PIL image or a numpy
                array.

        Returns:
            A PIL image. Paths are loaded through the preprocessor, PIL
            images are returned as they are, arrays go through
            ``Image.fromarray``.

        Raises:
            ValueError: For any other input type.
        """
        if isinstance(image, (str, Path)):
            return self.preprocessor.load_image(str(image))
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        raise ValueError(f"不支持的图像类型: {type(image)}")

    def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
        """
        Average the confidence of the given results.

        Args:
            results: Recognised lines.

        Returns:
            The mean confidence, or 0.0 when there are no results.
        """
        if not results:
            return 0.0
        confidence_sum = sum(item.confidence for item in results)
        return confidence_sum / len(results)
|
||
|
||
|
||
class PaddleOCREngine(BaseOCREngine):
    """
    Local OCR engine backed by PaddleOCR.
    """

    # Language codes used by this module that are not PaddleOCR language
    # codes, mapped to the PaddleOCR code to request instead. PaddleOCR's
    # "ch" model is documented as covering mixed Chinese/English text.
    _LANG_ALIASES = {'chinese_chinese': 'ch'}

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Create the underlying PaddleOCR instance.

        Args:
            config: Configuration dict. Supported keys:
                - use_gpu: Whether to run on GPU (default False).
                - lang: Language code (default "ch"); "chinese_chinese" is
                  accepted and translated to "ch".
                - show_log: Whether PaddleOCR prints its log (default False).

        Raises:
            ImportError: When PaddleOCR is not installed.
        """
        super().__init__(config)

        if PaddleOCR is None:
            raise ImportError(
                "PaddleOCR 未安装。请运行: pip install paddleocr paddlepaddle"
            )

        # Read the configuration.
        self.use_gpu = self.config.get('use_gpu', False)
        requested_lang = self.config.get('lang', 'ch')
        # Translate module-specific codes before they reach PaddleOCR, which
        # rejects language codes it does not know.
        self.lang = self._LANG_ALIASES.get(requested_lang, requested_lang)
        self.show_log = self.config.get('show_log', False)

        # NOTE(review): use_gpu/show_log are keyword arguments of the
        # PaddleOCR 2.x constructor — confirm against the installed version.
        logger.info(f"初始化 PaddleOCR (lang={self.lang}, gpu={self.use_gpu})")
        self.ocr = PaddleOCR(
            use_angle_cls=True,  # enable the text direction classifier
            lang=self.lang,
            use_gpu=self.use_gpu,
            show_log=self.show_log
        )

    @staticmethod
    def _parse_page(page) -> List[OCRResult]:
        """Convert one page of raw PaddleOCR output into OCRResult objects."""
        parsed = []
        if not page:
            return parsed
        for line_index, line in enumerate(page):
            if not line:
                continue
            # Each entry is expected to look like [bbox, (text, confidence)].
            bbox, text_info = line[0], line[1]
            parsed.append(
                OCRResult(
                    text=text_info[0],
                    confidence=float(text_info[1]),
                    bbox=bbox,
                    line_index=line_index
                )
            )
        return parsed

    def recognize(
        self,
        image,
        preprocess: bool = False,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognise the text in an image with PaddleOCR.

        Args:
            image: A file path, a PIL image or a numpy array.
            preprocess: Whether to preprocess the image first.
            **kwargs: Unused.

        Returns:
            OCRBatchResult: The recognised lines on success. Any exception
            raised along the way is logged and reported through
            ``success=False`` and ``error_message`` instead of propagating.
        """
        try:
            pil_image = self._load_image(image)

            if preprocess:
                pil_image = self.preprocessor.preprocess(pil_image)

            # PaddleOCR is fed a numpy array.
            img_array = np.array(pil_image)
            result = self.ocr.ocr(img_array, cls=True)

            # Only the first page of the output is used.
            ocr_results = self._parse_page(result[0] if result else None)

            total_confidence = self._calculate_total_confidence(ocr_results)
            full_text = '\n'.join(item.text for item in ocr_results)

            logger.info(f"OCR 识别完成: {len(ocr_results)} 行, 平均置信度 {total_confidence:.2f}")

            return OCRBatchResult(
                results=ocr_results,
                full_text=full_text,
                total_confidence=total_confidence,
                success=True
            )

        except Exception as e:
            # Best-effort boundary: callers receive a failed result object
            # rather than an exception.
            logger.error(f"OCR 识别失败: {e}", exc_info=True)
            return OCRBatchResult(
                results=[],
                full_text="",
                total_confidence=0.0,
                success=False,
                error_message=str(e)
            )
|
||
|
||
|
||
class CloudOCREngine(BaseOCREngine):
    """
    Adapter for cloud OCR services.

    A reserved extension point for remote OCR providers such as Baidu,
    Tencent or Aliyun OCR. Recognition is not implemented yet.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Read the cloud settings from the configuration.

        Args:
            config: Configuration dict. Supported keys:
                - api_endpoint: API endpoint (default "").
                - api_key: API key (default "").
                - provider: Provider name, baidu/tencent/aliyun/custom
                  (default "custom").
                - timeout: Timeout in seconds (default 30).
        """
        super().__init__(config)

        settings = self.config
        self.api_endpoint = settings.get('api_endpoint', '')
        self.api_key = settings.get('api_key', '')
        self.provider = settings.get('provider', 'custom')
        self.timeout = settings.get('timeout', 30)

        # A missing endpoint is only reported, not treated as an error.
        if not self.api_endpoint:
            logger.warning("云端 OCR: api_endpoint 未配置")

    def recognize(
        self,
        image,
        preprocess: bool = False,
        **kwargs
    ) -> OCRBatchResult:
        """
        Placeholder for cloud recognition.

        Args:
            image: A file path or a PIL image (currently ignored).
            preprocess: Whether to preprocess the image (currently ignored).
            **kwargs: Other options (currently ignored).

        Returns:
            OCRBatchResult: Always a failed, empty result stating that cloud
            OCR is not implemented.
        """
        # Stub: a real implementation depends on the chosen cloud OCR API.
        logger.warning("云端 OCR 尚未实现,请使用本地 PaddleOCR 或自行实现")

        not_implemented = OCRBatchResult(
            results=[],
            full_text="",
            total_confidence=0.0,
            success=False,
            error_message="云端 OCR 尚未实现"
        )
        return not_implemented

    def _send_request(self, image_data: bytes) -> Dict[str, Any]:
        """
        Send the API request (to be implemented).

        Args:
            image_data: Raw image bytes.

        Returns:
            The API response.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("请根据具体云服务 API 实现此方法")
|
||
|
||
|
||
class OCRFactory:
    """
    Factory that builds the OCR engine matching a mode name.
    """

    @staticmethod
    def create_engine(
        mode: str = "local",
        config: Optional[Dict[str, Any]] = None
    ) -> BaseOCREngine:
        """
        Build an OCR engine.

        Args:
            mode: "local" for PaddleOCREngine, "cloud" for CloudOCREngine.
            config: Configuration dict handed to the engine constructor.

        Returns:
            BaseOCREngine: The new engine instance.

        Raises:
            ValueError: When ``mode`` is neither "local" nor "cloud".
        """
        if mode == "local":
            return PaddleOCREngine(config)
        if mode == "cloud":
            return CloudOCREngine(config)
        raise ValueError(f"不支持的 OCR 模式: {mode}")
|
||
|
||
|
||
# 便捷函数
|
||
def recognize_text(
    image,
    mode: str = "local",
    lang: str = "ch",
    use_gpu: bool = False,
    preprocess: bool = False,
    **kwargs
) -> OCRBatchResult:
    """
    One-call text recognition.

    Builds an engine through ``OCRFactory.create_engine`` on every call and
    runs it on the image.

    Args:
        image: Image input accepted by the engine (file path, PIL image).
        mode: OCR mode, "local" or "cloud".
        lang: Language code (ch/en/chinese_chinese).
        use_gpu: Whether to use the GPU (local mode only).
        preprocess: Whether to preprocess the image first.
        **kwargs: Extra entries added to the engine configuration.

    Returns:
        OCRBatchResult: The recognition outcome.
    """
    # Named options first, then any extra configuration entries.
    engine_config = {'lang': lang, 'use_gpu': use_gpu}
    engine_config.update(kwargs)

    return OCRFactory.create_engine(mode, engine_config).recognize(
        image, preprocess=preprocess
    )
|
||
|
||
|
||
def preprocess_image(
    image_path: str,
    output_path: Optional[str] = None,
    **kwargs
) -> Image.Image:
    """
    One-call image preprocessing.

    Args:
        image_path: Path of the input image.
        output_path: Where to save the processed image; nothing is saved
            when it is empty or None.
        **kwargs: Options forwarded to the preprocessing step.

    Returns:
        PIL Image: The processed image.
    """
    processed = ImagePreprocessor.preprocess_from_path(image_path, **kwargs)

    # Nothing to write out: hand the image straight back.
    if not output_path:
        return processed

    processed.save(output_path)
    logger.info(f"预处理图像已保存到: {output_path}")
    return processed
|