2026-02-11 18:21:31 +08:00
|
|
|
|
"""
|
2026-02-12 13:42:46 +08:00
|
|
|
|
OCR 模块 - 纯云端版本
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
2026-02-12 13:42:46 +08:00
|
|
|
|
提供云端 API 文字识别功能:
|
|
|
|
|
|
- 云端 OCR API 调用(百度/腾讯/阿里云等)
|
2026-02-11 18:21:31 +08:00
|
|
|
|
- 图片预处理增强
|
|
|
|
|
|
- 多语言支持(中/英/混合)
|
2026-02-12 13:42:46 +08:00
|
|
|
|
|
|
|
|
|
|
注意:本版本不包含本地 OCR 引擎,所有 OCR 处理通过云端 API 完成。
|
2026-02-11 18:21:31 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
from pathlib import Path
|
2026-02-12 13:42:46 +08:00
|
|
|
|
from typing import List, Optional, Dict, Any
|
2026-02-11 18:21:31 +08:00
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
|
import logging
|
2026-02-12 13:42:46 +08:00
|
|
|
|
import base64
|
|
|
|
|
|
import io
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
from PIL import Image, ImageEnhance, ImageFilter
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
raise ImportError(
|
2026-02-12 13:42:46 +08:00
|
|
|
|
"请安装图像处理库: pip install pillow"
|
2026-02-11 18:21:31 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
2026-02-12 13:42:46 +08:00
|
|
|
|
import requests
|
2026-02-11 18:21:31 +08:00
|
|
|
|
except ImportError:
|
2026-02-12 13:42:46 +08:00
|
|
|
|
raise ImportError(
|
|
|
|
|
|
"请安装 requests 库: pip install requests"
|
|
|
|
|
|
)
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OCRLanguage(str, Enum):
    """Languages the OCR layer can be asked to recognize.

    Members are also plain strings, so they compare equal to their values.
    """

    # Chinese
    CHINESE = "ch"
    # English
    ENGLISH = "en"
    # Mixed Chinese and English
    MIXED = "ch_en"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OCRProvider(str, Enum):
    """Cloud OCR service providers.

    Members are also plain strings, so a provider read from a config dict
    (e.g. ``"baidu"``) compares equal to the matching member.
    """

    # Baidu OCR
    BAIDU = "baidu"
    # Tencent Cloud OCR
    TENCENT = "tencent"
    # Alibaba Cloud OCR
    ALIYUN = "aliyun"
    # Custom HTTP API
    CUSTOM = "custom"
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OCRResult:
    """
    A single recognized text line.

    Attributes:
        text: Recognized text content.
        confidence: Confidence score (0-1).
        bbox: Text box corners [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], if known.
        line_index: Line index, starting at 0.
    """
    text: str
    confidence: float
    bbox: Optional[List[List[float]]] = None
    line_index: int = 0

    def __repr__(self) -> str:
        # Show at most 30 characters of text. Append an ellipsis only when the
        # text was actually truncated, so short lines are not shown as cut off.
        preview = self.text if len(self.text) <= 30 else f"{self.text[:30]}..."
        return f"OCRResult(text='{preview}', confidence={self.confidence:.2f})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OCRBatchResult:
    """
    Aggregated result of one OCR request.

    Attributes:
        results: Every recognized line, in order.
        full_text: All line texts joined into one string.
        total_confidence: Mean confidence over ``results``.
        success: Whether recognition succeeded.
        error_message: Failure description when ``success`` is False.
    """
    results: List[OCRResult]
    full_text: str
    total_confidence: float
    success: bool = True
    error_message: Optional[str] = None

    def __repr__(self) -> str:
        line_count = len(self.results)
        return "OCRBatchResult(lines={}, confidence={:.2f})".format(
            line_count, self.total_confidence
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImagePreprocessor:
    """
    Image preprocessor.

    Provides common enhancement and preprocessing steps intended to improve
    OCR accuracy. All methods are static and return new images rather than
    modifying their input in place.
    """

    @staticmethod
    def load_image(image_path: str) -> Image.Image:
        """
        Load an image file and return it in RGB mode.

        ``Image.open`` is lazy and keeps the file open until the pixel data is
        read. Here the file is opened in a ``with`` block and converted before
        the block exits, so no file handle is left open.

        Args:
            image_path: Path of the image file.

        Returns:
            A PIL Image in RGB mode that no longer depends on the file.
        """
        with Image.open(image_path) as image:
            # convert() loads the pixel data and always returns a new image
            # (a plain copy when the mode is already RGB), so the returned
            # object stays valid after the file is closed.
            return image.convert('RGB')

    @staticmethod
    def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
        """
        Downscale the image to at most ``max_width`` pixels wide, keeping the
        aspect ratio.

        Images that are already narrow enough are returned unchanged.

        Args:
            image: PIL Image object.
            max_width: Maximum width in pixels.

        Returns:
            The resized image, or the original image if no resize was needed.
        """
        if image.width <= max_width:
            return image
        ratio = max_width / image.width
        # Clamp to at least 1 pixel. A very wide, very short image would
        # otherwise truncate to height 0, which Image.resize rejects.
        new_height = max(1, int(image.height * ratio))
        return image.resize((max_width, new_height), Image.Resampling.LANCZOS)

    @staticmethod
    def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust contrast.

        Args:
            image: PIL Image object.
            factor: 1.0 keeps the original, >1.0 increases, <1.0 decreases.

        Returns:
            The processed image.
        """
        enhancer = ImageEnhance.Contrast(image)
        return enhancer.enhance(factor)

    @staticmethod
    def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """
        Adjust sharpness.

        Args:
            image: PIL Image object.
            factor: Sharpening factor (1.0 keeps the original).

        Returns:
            The processed image.
        """
        enhancer = ImageEnhance.Sharpness(image)
        return enhancer.enhance(factor)

    @staticmethod
    def denoise(image: Image.Image) -> Image.Image:
        """
        Denoise with a 3x3 median filter.

        Args:
            image: PIL Image object.

        Returns:
            The processed image.
        """
        return image.filter(ImageFilter.MedianFilter(size=3))

    @staticmethod
    def preprocess(
        image: Image.Image,
        resize: bool = True,
        enhance_contrast: bool = True,
        enhance_sharpness: bool = True,
        denoise: bool = False
    ) -> Image.Image:
        """
        Run the selected preprocessing steps in a fixed order:
        resize -> contrast -> sharpness -> denoise.

        Args:
            image: PIL Image object (not modified; a copy is processed).
            resize: Downscale to the default maximum width.
            enhance_contrast: Increase contrast.
            enhance_sharpness: Increase sharpness.
            denoise: Apply median-filter denoising.

        Returns:
            The processed image.
        """
        # The boolean parameters shadow the method names, so the steps are
        # called through the class explicitly.
        result = image.copy()

        if resize:
            result = ImagePreprocessor.resize_image(result)

        if enhance_contrast:
            result = ImagePreprocessor.enhance_contrast(result)

        if enhance_sharpness:
            result = ImagePreprocessor.enhance_sharpness(result)

        if denoise:
            result = ImagePreprocessor.denoise(result)

        return result

    @staticmethod
    def preprocess_from_path(image_path: str, **kwargs) -> Image.Image:
        """
        Load an image file and preprocess it.

        Args:
            image_path: Path of the image file.
            **kwargs: Forwarded to ``preprocess`` (resize, enhance_contrast,
                enhance_sharpness, denoise).

        Returns:
            The processed RGB image.
        """
        image = ImagePreprocessor.load_image(image_path)
        return ImagePreprocessor.preprocess(image, **kwargs)

    @staticmethod
    def image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
        """
        Encode a PIL Image as a base64 string.

        For JPEG output, images in modes the JPEG encoder cannot write
        (e.g. RGBA, LA, P) are converted to RGB first instead of failing.

        Args:
            image: PIL Image object.
            format: Image format passed to ``Image.save`` (e.g. JPEG/PNG).

        Returns:
            The base64-encoded image bytes as an ASCII string.
        """
        # Pillow's JPEG writer accepts only these modes; anything else (alpha
        # or palette images) would raise OSError on save.
        jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
        if format.upper() == "JPEG" and image.mode not in jpeg_modes:
            image = image.convert('RGB')

        buffer = io.BytesIO()
        image.save(buffer, format=format)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseOCREngine(ABC):
    """
    Base class for OCR engines.

    Every OCR implementation must subclass this and implement ``recognize``.
    Shared helpers for image loading and confidence averaging live here.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the OCR engine.

        Args:
            config: OCR configuration dict (an empty dict is used when None or empty).
        """
        self.config = config or {}
        self.preprocessor = ImagePreprocessor()

    @abstractmethod
    def recognize(
        self,
        image,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognize the text in an image.

        Args:
            image: Image as a file path (str/Path), raw encoded bytes, or a PIL Image.
            preprocess: Whether to preprocess the image first.
            **kwargs: Implementation-specific options.

        Returns:
            OCRBatchResult: The recognition result.
        """
        pass

    def _load_image(self, image) -> Image.Image:
        """
        Load an image from any supported input and return it in RGB mode.

        Args:
            image: File path (str or Path), encoded image bytes, or a PIL Image.

        Returns:
            A PIL Image in RGB mode.

        Raises:
            ValueError: The input type is not supported.
        """
        if isinstance(image, (str, Path)):
            return self.preprocessor.load_image(str(image))
        if isinstance(image, (bytes, bytearray)):
            # Encoded file contents (e.g. an uploaded JPEG/PNG) held in memory.
            with Image.open(io.BytesIO(image)) as opened:
                return opened.convert('RGB')
        if isinstance(image, Image.Image):
            # Normalize to RGB, as the path branch does. Otherwise RGBA/P
            # inputs cannot be encoded as JPEG further down the pipeline.
            return image if image.mode == 'RGB' else image.convert('RGB')
        raise ValueError(f"不支持的图像类型: {type(image)}")

    def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
        """
        Compute the mean confidence of a list of results.

        Args:
            results: OCR result list.

        Returns:
            Mean confidence (0-1), or 0.0 for an empty list.
        """
        if not results:
            return 0.0
        return sum(r.confidence for r in results) / len(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-12 13:42:46 +08:00
|
|
|
|
class CloudOCREngine(BaseOCREngine):
    """
    Cloud OCR engine.

    Sends the (optionally preprocessed) image, base64-encoded, to the cloud
    OCR service selected by the ``provider`` config value:
        - Baidu OCR ("baidu")
        - Tencent Cloud OCR ("tencent"), placeholder, not implemented yet
        - Alibaba Cloud OCR ("aliyun"), placeholder, not implemented yet
        - Custom HTTP API (any other value; the default is "custom")

    ``recognize`` never raises. Failures are reported through
    ``OCRBatchResult.success=False`` and ``error_message``.
    """

    # Baidu general text recognition endpoint and its OAuth token endpoint.
    _BAIDU_OCR_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
    _BAIDU_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    # Seconds subtracted from the token lifetime reported by Baidu, so that a
    # cached token is refreshed slightly before it actually expires.
    _TOKEN_REFRESH_MARGIN = 60.0
    # Providers with built-in URLs that do not use ``api_endpoint``.
    _BUILTIN_PROVIDERS = (OCRProvider.BAIDU, OCRProvider.TENCENT, OCRProvider.ALIYUN)

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the cloud OCR engine.

        Args:
            config: Configuration dict. Recognized keys:
                - api_endpoint: HTTP endpoint (used by the custom provider)
                - api_key: API key (Baidu client_id, or Bearer token for custom)
                - api_secret: API secret (Baidu client_secret)
                - provider: baidu/tencent/aliyun/custom (default "custom")
                - timeout: Request timeout in seconds (default 30)
        """
        super().__init__(config)

        self.api_endpoint = self.config.get('api_endpoint', '')
        self.api_key = self.config.get('api_key', '')
        self.api_secret = self.config.get('api_secret', '')
        self.provider = self.config.get('provider', 'custom')
        self.timeout = self.config.get('timeout', 30)

        # Cached Baidu access_token, plus the time.monotonic() value after
        # which it must be refreshed. An empty token means nothing is cached.
        self._baidu_token = ''
        self._baidu_token_expiry = 0.0

        # Only the custom provider posts to api_endpoint. The built-in
        # providers have fixed URLs, so a missing endpoint does not affect them.
        if not self.api_endpoint and self.provider not in self._BUILTIN_PROVIDERS:
            logger.warning("云端 OCR: api_endpoint 未配置,OCR 功能将不可用")

    def recognize(
        self,
        image,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """
        Recognize the text in an image through the configured cloud API.

        Args:
            image: Image as a path or PIL Image (see ``_load_image``).
            preprocess: Whether to preprocess the image before uploading.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            OCRBatchResult: The recognition result. Any exception is caught,
            logged and reported as a failed result.
        """
        try:
            pil_image = self._load_image(image)

            if preprocess:
                pil_image = self.preprocessor.preprocess(pil_image)

            img_base64 = self.preprocessor.image_to_base64(pil_image)

            # Dispatch on provider; unknown providers fall back to the custom API.
            if self.provider == OCRProvider.BAIDU:
                return self._baidu_ocr(img_base64)
            if self.provider == OCRProvider.TENCENT:
                return self._tencent_ocr(img_base64)
            if self.provider == OCRProvider.ALIYUN:
                return self._aliyun_ocr(img_base64)
            return self._custom_api_ocr(img_base64)

        except Exception as e:
            logger.error(f"云端 OCR 识别失败: {e}", exc_info=True)
            return self._failure_result(str(e))

    @staticmethod
    def _failure_result(message: str) -> OCRBatchResult:
        """Build a failed OCRBatchResult that carries ``message``."""
        return OCRBatchResult(
            results=[],
            full_text="",
            total_confidence=0.0,
            success=False,
            error_message=message
        )

    def _build_result(self, lines) -> OCRBatchResult:
        """
        Build a successful OCRBatchResult.

        Args:
            lines: Iterable of ``(text, confidence)`` pairs, in reading order.

        Returns:
            OCRBatchResult whose full_text is the line texts joined by newlines
            and whose total_confidence is the mean line confidence.
        """
        ocr_results = [
            OCRResult(text=text, confidence=conf, line_index=idx)
            for idx, (text, conf) in enumerate(lines)
        ]
        return OCRBatchResult(
            results=ocr_results,
            full_text='\n'.join(r.text for r in ocr_results),
            total_confidence=self._calculate_total_confidence(ocr_results),
            success=True
        )

    def _get_baidu_access_token(self) -> str:
        """
        Return a Baidu OAuth access_token, reusing the cached one while it is valid.

        Returns:
            The access_token string.

        Raises:
            RuntimeError: api_key/api_secret are not configured, or the token
                endpoint response contains no access_token.
            requests.RequestException: The token request itself failed.
        """
        import time  # function-scope import; only the token cache needs it

        now = time.monotonic()
        if self._baidu_token and now < self._baidu_token_expiry:
            return self._baidu_token

        if not (self.api_key and self.api_secret):
            raise RuntimeError("百度 OCR 需要配置 api_key 和 api_secret")

        # params= lets requests URL-encode the credentials instead of splicing
        # them verbatim into the query string.
        token_resp = requests.get(
            self._BAIDU_TOKEN_URL,
            params={
                'grant_type': 'client_credentials',
                'client_id': self.api_key,
                'client_secret': self.api_secret,
            },
            timeout=self.timeout
        )
        try:
            payload = token_resp.json()
        except ValueError:
            payload = {}
        if not isinstance(payload, dict):
            payload = {}

        access_token = payload.get('access_token')
        if not access_token:
            detail = payload.get('error_description') or f"HTTP {token_resp.status_code}"
            raise RuntimeError(f"获取百度 access_token 失败: {detail}")

        # Without a numeric lifetime, the expiry is "now". The token is still
        # used for this call, but the next call fetches a fresh one.
        expires_in = payload.get('expires_in')
        lifetime = float(expires_in) if isinstance(expires_in, (int, float)) else 0.0
        self._baidu_token = access_token
        self._baidu_token_expiry = now + max(0.0, lifetime - self._TOKEN_REFRESH_MARGIN)
        return access_token

    def _baidu_ocr(self, img_base64: str) -> OCRBatchResult:
        """
        Call the Baidu general_basic OCR API.

        Args:
            img_base64: Base64-encoded image.

        Returns:
            OCRBatchResult. Errors (token, network, API) become failed results.
        """
        try:
            access_token = self._get_baidu_access_token()

            response = requests.post(
                self._BAIDU_OCR_URL,
                params={'access_token': access_token},
                data={'image': img_base64},
                timeout=self.timeout
            )
            result = response.json()

            if 'words_result' not in result:
                # Drop the cached token so that a revoked or expired token is
                # not reused on the next call.
                self._baidu_token = ''
                self._baidu_token_expiry = 0.0
                return self._failure_result(result.get('error_msg', '未知错误'))

            # This endpoint reports no per-line confidence, so every line gets
            # the fixed placeholder value 0.95.
            batch = self._build_result(
                (item.get('words', ''), 0.95) for item in result['words_result']
            )
            logger.info(f"百度 OCR 识别完成: {len(batch.results)} 行")
            return batch

        except Exception as e:
            logger.error(f"百度 OCR 调用失败: {e}")
            return self._failure_result(f"百度 OCR 调用失败: {str(e)}")

    def _tencent_ocr(self, img_base64: str) -> OCRBatchResult:
        """Tencent Cloud OCR API (placeholder; always returns a failed result)."""
        logger.warning("腾讯云 OCR 尚未实现")
        return self._failure_result("腾讯云 OCR 尚未实现")

    def _aliyun_ocr(self, img_base64: str) -> OCRBatchResult:
        """Alibaba Cloud OCR API (placeholder; always returns a failed result)."""
        logger.warning("阿里云 OCR 尚未实现")
        return self._failure_result("阿里云 OCR 尚未实现")

    def _custom_api_ocr(self, img_base64: str) -> OCRBatchResult:
        """
        Call a custom HTTP OCR API.

        The image is POSTed as JSON ``{"image": <base64>, "format": "base64"}``,
        with a Bearer token when ``api_key`` is set. Two response shapes are
        understood:
            - ``{"text": "...", "confidence": <optional float>}``
            - ``{"lines": [{"text": "...", "confidence": <optional float>}, ...]}``
        A missing confidence defaults to 0.9.

        Args:
            img_base64: Base64-encoded image.

        Returns:
            OCRBatchResult. Errors are reported as failed results.
        """
        if not self.api_endpoint:
            return self._failure_result("未配置云端 OCR API endpoint")

        try:
            headers = {
                'Content-Type': 'application/json',
            }
            if self.api_key:
                headers['Authorization'] = f'Bearer {self.api_key}'

            response = requests.post(
                self.api_endpoint,
                json={'image': img_base64, 'format': 'base64'},
                headers=headers,
                timeout=self.timeout
            )

            if response.status_code != 200:
                return self._failure_result(f"API 请求失败: HTTP {response.status_code}")

            result = response.json()
            if not isinstance(result, dict):
                return self._failure_result(f"未知的响应格式: {type(result).__name__}")

            if 'text' in result:
                # Single-block plain-text format.
                return self._build_result([(result['text'], result.get('confidence', 0.9))])

            if 'lines' in result:
                # Multi-line format.
                return self._build_result(
                    (line.get('text', ''), line.get('confidence', 0.9))
                    for line in result['lines']
                )

            return self._failure_result(f"未知的响应格式: {list(result.keys())}")

        except Exception as e:
            logger.error(f"自定义 API OCR 调用失败: {e}")
            return self._failure_result(f"API 调用失败: {str(e)}")
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OCRFactory:
    """
    Factory for OCR engines.

    Builds the OCR engine instance that matches a mode and configuration.
    """

    @staticmethod
    def create_engine(
        mode: str = "cloud",
        config: Optional[Dict[str, Any]] = None
    ) -> BaseOCREngine:
        """
        Create an OCR engine.

        Only the cloud engine exists. Any mode other than "cloud" is accepted
        for backward compatibility: a deprecation warning is logged and a
        cloud engine is returned anyway. This method does not raise for
        unknown modes.

        Args:
            mode: OCR mode; only "cloud" is meaningful.
            config: Configuration dict forwarded to the engine.

        Returns:
            BaseOCREngine: A CloudOCREngine instance.
        """
        if mode != "cloud":
            logger.warning(f"OCR 模式 '{mode}' 已弃用,使用云端 OCR")
        return CloudOCREngine(config)
|
2026-02-11 18:21:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Convenience functions
def recognize_text(
    image,
    mode: str = "cloud",
    preprocess: bool = True,
    **kwargs
) -> OCRBatchResult:
    """
    Recognize text in a single call.

    A new engine is created from ``mode`` and ``kwargs`` on every call.

    Args:
        image: Image as a path or PIL Image.
        mode: OCR mode (only "cloud" is supported).
        preprocess: Whether to preprocess the image.
        **kwargs: Engine configuration (api_endpoint, api_key, provider, ...).

    Returns:
        OCRBatchResult: The recognition result.
    """
    engine_config = dict(kwargs)
    return OCRFactory.create_engine(mode, engine_config).recognize(
        image, preprocess=preprocess
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_image(
    image_path: str,
    output_path: Optional[str] = None,
    **kwargs
) -> Image.Image:
    """
    Load an image file, preprocess it and optionally save the result.

    Args:
        image_path: Input image path.
        output_path: If given, the processed image is saved to this path.
        **kwargs: Preprocessing options forwarded to
            ``ImagePreprocessor.preprocess`` (resize, enhance_contrast,
            enhance_sharpness, denoise).

    Returns:
        PIL Image: The processed image.
    """
    # Load and preprocess explicitly using the two core ImagePreprocessor steps.
    image = ImagePreprocessor.load_image(image_path)
    processed = ImagePreprocessor.preprocess(image, **kwargs)

    if output_path:
        processed.save(output_path)
        logger.info(f"预处理图像已保存到: {output_path}")

    return processed
|