Files
cutThenThink/src/core/ocr.py
congsh 154d53dbfd feat: 添加轻量打包和OCR自动安装功能
- 添加Windows打包脚本 build.bat
- 更新打包文档 BUILD.md(轻量版方案)
- OCR模块:添加首次运行时自动安装PaddleOCR的功能
- 主窗口:添加OCR安装检测和提示逻辑
- 创建应用入口 src/main.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 10:14:10 +08:00

638 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
OCR 模块
提供文字识别功能,支持:
- 本地 PaddleOCR 识别
- 云端 OCR API 扩展
- 图片预处理增强
- 多语言支持(中/英/混合)
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
try:
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
except ImportError:
raise ImportError(
"请安装图像处理库: pip install pillow numpy"
)
try:
from paddleocr import PaddleOCR
except ImportError:
PaddleOCR = None
logging.warning("PaddleOCR 未安装,本地 OCR 功能不可用。首次运行时将自动安装。")
def ensure_paddleocr():
"""
确保 PaddleOCR 已安装,如果没有则安装
首次运行时自动下载安装 OCR 库
"""
global PaddleOCR
if PaddleOCR is None:
import subprocess
import sys
logging.info("正在安装 PaddleOCR...")
try:
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"--break-system-packages",
"paddleocr"
])
# 重新导入
from paddleocr import PaddleOCR
globals()["PaddleOCR"] = PaddleOCR
logging.info("PaddleOCR 安装成功!")
except subprocess.CalledProcessError as e:
logging.error(f"PaddleOCR 安装失败: {e}")
raise
# 配置日志
logger = logging.getLogger(__name__)
class OCRLanguage(str, Enum):
"""OCR 支持的语言"""
CHINESE = "ch" # 中文
ENGLISH = "en" # 英文
MIXED = "chinese_chinese" # 中英文混合
@dataclass
class OCRResult:
"""
OCR 识别结果
Attributes:
text: 识别的文本内容
confidence: 置信度 (0-1)
bbox: 文本框坐标 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
line_index: 行索引从0开始
"""
text: str
confidence: float
bbox: Optional[List[List[float]]] = None
line_index: int = 0
def __repr__(self) -> str:
return f"OCRResult(text='{self.text[:30]}...', confidence={self.confidence:.2f})"
@dataclass
class OCRBatchResult:
"""
OCR 批量识别结果
Attributes:
results: 所有的识别结果列表
full_text: 完整文本(所有行拼接)
total_confidence: 平均置信度
success: 是否识别成功
error_message: 错误信息(如果失败)
"""
results: List[OCRResult]
full_text: str
total_confidence: float
success: bool = True
error_message: Optional[str] = None
def __repr__(self) -> str:
return f"OCRBatchResult(lines={len(self.results)}, confidence={self.total_confidence:.2f})"
class ImagePreprocessor:
"""
图像预处理器
提供常见的图像增强和预处理功能,提高 OCR 识别准确率
"""
@staticmethod
def load_image(image_path: str) -> Image.Image:
"""
加载图像
Args:
image_path: 图像文件路径
Returns:
PIL Image 对象
"""
image = Image.open(image_path)
# 转换为 RGB 模式
if image.mode != 'RGB':
image = image.convert('RGB')
return image
@staticmethod
def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
"""
调整图像大小(保持宽高比)
Args:
image: PIL Image 对象
max_width: 最大宽度
Returns:
调整后的图像
"""
if image.width > max_width:
ratio = max_width / image.width
new_height = int(image.height * ratio)
image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)
return image
@staticmethod
def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
"""
增强对比度
Args:
image: PIL Image 对象
factor: 增强因子1.0 表示原始,>1.0 增强,<1.0 减弱
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Contrast(image)
return enhancer.enhance(factor)
@staticmethod
def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
"""
增强锐度
Args:
image: PIL Image 对象
factor: 锐化因子
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Sharpness(image)
return enhancer.enhance(factor)
@staticmethod
def enhance_brightness(image: Image.Image, factor: float = 1.1) -> Image.Image:
"""
调整亮度
Args:
image: PIL Image 对象
factor: 亮度因子
Returns:
处理后的图像
"""
enhancer = ImageEnhance.Brightness(image)
return enhancer.enhance(factor)
@staticmethod
def denoise(image: Image.Image) -> Image.Image:
"""
去噪(使用中值滤波)
Args:
image: PIL Image 对象
Returns:
处理后的图像
"""
return image.filter(ImageFilter.MedianFilter(size=3))
@staticmethod
def binarize(image: Image.Image, threshold: int = 127) -> Image.Image:
"""
二值化(转换为黑白图像)
Args:
image: PIL Image 对象
threshold: 二值化阈值
Returns:
处理后的图像
"""
# 先转为灰度图
gray = image.convert('L')
# 二值化
binary = gray.point(lambda x: 0 if x < threshold else 255, '1')
# 转回 RGB
return binary.convert('RGB')
@staticmethod
def preprocess(
image: Image.Image,
resize: bool = True,
enhance_contrast: bool = True,
enhance_sharpness: bool = True,
denoise: bool = False,
binarize: bool = False
) -> Image.Image:
"""
综合预处理(根据指定选项)
Args:
image: PIL Image 对象
resize: 是否调整大小
enhance_contrast: 是否增强对比度
enhance_sharpness: 是否增强锐度
denoise: 是否去噪
binarize: 是否二值化
Returns:
处理后的图像
"""
result = image.copy()
if resize:
result = ImagePreprocessor.resize_image(result)
if enhance_contrast:
result = ImagePreprocessor.enhance_contrast(result)
if enhance_sharpness:
result = ImagePreprocessor.enhance_sharpness(result)
if denoise:
result = ImagePreprocessor.denoise(result)
if binarize:
result = ImagePreprocessor.binarize(result)
return result
@staticmethod
def preprocess_from_path(
image_path: str,
**kwargs
) -> Image.Image:
"""
从文件路径加载并预处理图像
Args:
image_path: 图像文件路径
**kwargs: preprocess 方法的参数
Returns:
处理后的图像
"""
image = ImagePreprocessor.load_image(image_path)
return ImagePreprocessor.preprocess(image, **kwargs)
class BaseOCREngine(ABC):
"""
OCR 引擎基类
所有 OCR 实现必须继承此类并实现 recognize 方法
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化 OCR 引擎
Args:
config: OCR 配置字典
"""
self.config = config or {}
self.preprocessor = ImagePreprocessor()
@abstractmethod
def recognize(
self,
image,
preprocess: bool = True,
**kwargs
) -> OCRBatchResult:
"""
识别图像中的文本
Args:
image: 图像可以是路径、PIL Image 或 numpy 数组)
preprocess: 是否预处理图像
**kwargs: 其他参数
Returns:
OCRBatchResult: 识别结果
"""
pass
def _load_image(self, image) -> Image.Image:
"""
加载图像(支持多种输入格式)
Args:
image: 图像路径、PIL Image 或 numpy 数组)
Returns:
PIL Image 对象
"""
if isinstance(image, str) or isinstance(image, Path):
return self.preprocessor.load_image(str(image))
elif isinstance(image, Image.Image):
return image
elif isinstance(image, np.ndarray):
return Image.fromarray(image)
else:
raise ValueError(f"不支持的图像类型: {type(image)}")
def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
"""
计算平均置信度
Args:
results: OCR 结果列表
Returns:
平均置信度 (0-1)
"""
if not results:
return 0.0
return sum(r.confidence for r in results) / len(results)
class PaddleOCREngine(BaseOCREngine):
"""
PaddleOCR 本地识别引擎
使用 PaddleOCR 进行本地文字识别
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化 PaddleOCR 引擎
Args:
config: 配置字典,支持:
- use_gpu: 是否使用 GPU (默认 False)
- lang: 语言 (默认 "ch",支持 ch/en/chinese_chinese)
- show_log: 是否显示日志 (默认 False)
"""
super().__init__(config)
if PaddleOCR is None:
raise ImportError(
"PaddleOCR 未安装。请运行: pip install paddleocr paddlepaddle"
)
# 解析配置
self.use_gpu = self.config.get('use_gpu', False)
self.lang = self.config.get('lang', 'ch')
self.show_log = self.config.get('show_log', False)
# 初始化 PaddleOCR
logger.info(f"初始化 PaddleOCR (lang={self.lang}, gpu={self.use_gpu})")
self.ocr = PaddleOCR(
use_angle_cls=True, # 使用方向分类器
lang=self.lang,
use_gpu=self.use_gpu,
show_log=self.show_log
)
def recognize(
self,
image,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
使用 PaddleOCR 识别图像中的文本
Args:
image: 图像路径、PIL Image 或 numpy 数组)
preprocess: 是否预处理图像
**kwargs: 其他参数(未使用)
Returns:
OCRBatchResult: 识别结果
"""
try:
# 加载图像
pil_image = self._load_image(image)
# 预处理(如果启用)
if preprocess:
pil_image = self.preprocessor.preprocess(pil_image)
# 转换为 numpy 数组PaddleOCR 需要)
img_array = np.array(pil_image)
# 执行 OCR
result = self.ocr.ocr(img_array, cls=True)
# 解析结果
ocr_results = []
full_lines = []
if result and result[0]:
for line_index, line in enumerate(result[0]):
if line:
# PaddleOCR 返回格式: [[bbox], (text, confidence)]
bbox = line[0]
text_info = line[1]
text = text_info[0]
confidence = float(text_info[1])
ocr_result = OCRResult(
text=text,
confidence=confidence,
bbox=bbox,
line_index=line_index
)
ocr_results.append(ocr_result)
full_lines.append(text)
# 计算平均置信度
total_confidence = self._calculate_total_confidence(ocr_results)
# 拼接完整文本
full_text = '\n'.join(full_lines)
logger.info(f"OCR 识别完成: {len(ocr_results)} 行, 平均置信度 {total_confidence:.2f}")
return OCRBatchResult(
results=ocr_results,
full_text=full_text,
total_confidence=total_confidence,
success=True
)
except Exception as e:
logger.error(f"OCR 识别失败: {e}", exc_info=True)
return OCRBatchResult(
results=[],
full_text="",
total_confidence=0.0,
success=False,
error_message=str(e)
)
class CloudOCREngine(BaseOCREngine):
"""
云端 OCR 引擎(适配器)
预留接口,用于扩展云端 OCR 服务
支持:百度 OCR、腾讯 OCR、阿里云 OCR 等
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化云端 OCR 引擎
Args:
config: 配置字典,支持:
- api_endpoint: API 端点
- api_key: API 密钥
- provider: 提供商 (baidu/tencent/aliyun/custom)
- timeout: 超时时间(秒)
"""
super().__init__(config)
self.api_endpoint = self.config.get('api_endpoint', '')
self.api_key = self.config.get('api_key', '')
self.provider = self.config.get('provider', 'custom')
self.timeout = self.config.get('timeout', 30)
if not self.api_endpoint:
logger.warning("云端 OCR: api_endpoint 未配置")
def recognize(
self,
image,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
使用云端 API 识别图像中的文本
Args:
image: 图像路径、PIL Image
preprocess: 是否预处理图像
**kwargs: 其他参数
Returns:
OCRBatchResult: 识别结果
"""
# 这是一个占位实现
# 实际使用时需要根据具体的云端 OCR API 实现
logger.warning("云端 OCR 尚未实现,请使用本地 PaddleOCR 或自行实现")
return OCRBatchResult(
results=[],
full_text="",
total_confidence=0.0,
success=False,
error_message="云端 OCR 尚未实现"
)
def _send_request(self, image_data: bytes) -> Dict[str, Any]:
"""
发送 API 请求(待实现)
Args:
image_data: 图像二进制数据
Returns:
API 响应
"""
raise NotImplementedError("请根据具体云服务 API 实现此方法")
class OCRFactory:
"""
OCR 引擎工厂
根据配置创建对应的 OCR 引擎实例
"""
@staticmethod
def create_engine(
mode: str = "local",
config: Optional[Dict[str, Any]] = None
) -> BaseOCREngine:
"""
创建 OCR 引擎
Args:
mode: OCR 模式 ("local""cloud")
config: 配置字典
Returns:
BaseOCREngine: OCR 引擎实例
Raises:
ValueError: 不支持的 OCR 模式
"""
if mode == "local":
return PaddleOCREngine(config)
elif mode == "cloud":
return CloudOCREngine(config)
else:
raise ValueError(f"不支持的 OCR 模式: {mode}")
# 便捷函数
def recognize_text(
image,
mode: str = "local",
lang: str = "ch",
use_gpu: bool = False,
preprocess: bool = False,
**kwargs
) -> OCRBatchResult:
"""
快捷识别文本
Args:
image: 图像路径、PIL Image
mode: OCR 模式 ("local""cloud")
lang: 语言 (ch/en/chinese_chinese)
use_gpu: 是否使用 GPU仅本地模式
preprocess: 是否预处理图像
**kwargs: 其他配置
Returns:
OCRBatchResult: 识别结果
"""
config = {
'lang': lang,
'use_gpu': use_gpu,
**kwargs
}
engine = OCRFactory.create_engine(mode, config)
return engine.recognize(image, preprocess=preprocess)
def preprocess_image(
image_path: str,
output_path: Optional[str] = None,
**kwargs
) -> Image.Image:
"""
快捷预处理图像
Args:
image_path: 输入图像路径
output_path: 输出图像路径(如果指定,则保存)
**kwargs: 预处理参数
Returns:
PIL Image: 处理后的图像
"""
processed = ImagePreprocessor.preprocess_from_path(image_path, **kwargs)
if output_path:
processed.save(output_path)
logger.info(f"预处理图像已保存到: {output_path}")
return processed