feat: 实现CutThenThink P0阶段核心功能
项目初始化 - 创建完整项目结构(src/, data/, docs/, examples/, tests/) - 配置requirements.txt依赖 - 创建.gitignore P0基础框架 - 数据库模型:Record模型,6种分类类型 - 配置管理:YAML配置,支持AI/OCR/云存储/UI配置 - OCR模块:PaddleOCR本地识别,支持云端扩展 - AI模块:支持OpenAI/Claude/通义/Ollama,6种分类 - 存储模块:完整CRUD,搜索,统计,导入导出 - 主窗口框架:侧边栏导航,米白配色方案 - 图片处理:截图/剪贴板/文件选择/图片预览 - 处理流程整合:OCR→AI→存储串联,Markdown展示,剪贴板复制 - 分类浏览:卡片网格展示,分类筛选,搜索,详情查看 技术栈 - PyQt6 + SQLAlchemy + PaddleOCR + OpenAI/Claude SDK - 共47个Python文件,4000+行代码 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
613
src/core/ocr.py
Normal file
613
src/core/ocr.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""
|
||||
OCR 模块
|
||||
|
||||
提供文字识别功能,支持:
|
||||
- 本地 PaddleOCR 识别
|
||||
- 云端 OCR API 扩展
|
||||
- 图片预处理增强
|
||||
- 多语言支持(中/英/混合)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
try:
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"请安装图像处理库: pip install pillow numpy"
|
||||
)
|
||||
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
except ImportError:
|
||||
PaddleOCR = None
|
||||
logging.warning("PaddleOCR 未安装,本地 OCR 功能不可用")
|
||||
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OCRLanguage(str, Enum):
    """Language codes understood by the OCR engines in this module.

    Members are plain strings (the class mixes in ``str``), so they can be
    passed wherever a language code string is expected.
    """

    # Chinese
    CHINESE = "ch"

    # English
    ENGLISH = "en"

    # Mixed Chinese and English.
    # NOTE(review): "chinese_chinese" does not look like a PaddleOCR language
    # code -- confirm against the PaddleOCR language list before relying on it.
    MIXED = "chinese_chinese"
|
||||
|
||||
|
||||
@dataclass
class OCRResult:
    """A single recognized text line.

    Attributes:
        text: Recognized text content.
        confidence: Recognition confidence (documented range 0-1).
        bbox: Corner coordinates of the text box,
            ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``, or ``None``.
        line_index: Zero-based index of the line.
    """

    text: str
    confidence: float
    bbox: Optional[List[List[float]]] = None
    line_index: int = 0

    def __repr__(self) -> str:
        """Return a compact debug string showing at most 30 characters of text."""
        # Append the ellipsis only when the text was actually cut, so short
        # texts are not shown as if they had been truncated.
        if len(self.text) > 30:
            preview = f"{self.text[:30]}..."
        else:
            preview = self.text
        return f"OCRResult(text='{preview}', confidence={self.confidence:.2f})"
|
||||
|
||||
|
||||
@dataclass
class OCRBatchResult:
    """Aggregated outcome of one OCR run over an image.

    Attributes:
        results: Every recognized line, in order.
        full_text: The complete text (all lines joined together).
        total_confidence: Average confidence across the lines.
        success: Whether recognition succeeded.
        error_message: Error description when recognition failed.
    """

    results: List[OCRResult]
    full_text: str
    total_confidence: float
    success: bool = True
    error_message: Optional[str] = None

    def __repr__(self) -> str:
        """Return a short summary: number of lines and average confidence."""
        line_count = len(self.results)
        mean_confidence = self.total_confidence
        return f"OCRBatchResult(lines={line_count}, confidence={mean_confidence:.2f})"
|
||||
|
||||
|
||||
class ImagePreprocessor:
    """Image pre-processing helpers.

    Collects common enhancement steps intended to improve OCR accuracy.
    Every method is a static method that takes a PIL image (or a path) and
    returns a new PIL image.
    """

    @staticmethod
    def load_image(image_path: str) -> Image.Image:
        """Load an image file and return it in RGB mode.

        Args:
            image_path: Path of the image file.

        Returns:
            A PIL Image in ``'RGB'`` mode.
        """
        # Image.open is lazy and keeps the underlying file open. Converting
        # inside the context manager forces the pixel data to be read and
        # yields an independent image, so the file handle is released here
        # instead of lingering until garbage collection.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    @staticmethod
    def resize_image(image: Image.Image, max_width: int = 2000) -> Image.Image:
        """Shrink an image to at most ``max_width`` pixels wide, keeping aspect ratio.

        Images that are already narrow enough are returned unchanged.

        Args:
            image: PIL Image to resize.
            max_width: Maximum allowed width in pixels.

        Returns:
            The resized image, or the input image if no resize was needed.
        """
        if image.width <= max_width:
            return image

        ratio = max_width / image.width
        # Clamp to 1 so extremely wide, very short images do not end up with
        # a zero height, which resize() rejects.
        new_height = max(1, int(image.height * ratio))
        return image.resize((max_width, new_height), Image.Resampling.LANCZOS)

    @staticmethod
    def enhance_contrast(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """Adjust contrast.

        Args:
            image: PIL Image to process.
            factor: Enhancement factor; 1.0 keeps the original, values above
                1.0 increase contrast, values below 1.0 reduce it.

        Returns:
            The processed image.
        """
        return ImageEnhance.Contrast(image).enhance(factor)

    @staticmethod
    def enhance_sharpness(image: Image.Image, factor: float = 1.5) -> Image.Image:
        """Adjust sharpness.

        Args:
            image: PIL Image to process.
            factor: Sharpening factor.

        Returns:
            The processed image.
        """
        return ImageEnhance.Sharpness(image).enhance(factor)

    @staticmethod
    def enhance_brightness(image: Image.Image, factor: float = 1.1) -> Image.Image:
        """Adjust brightness.

        Args:
            image: PIL Image to process.
            factor: Brightness factor.

        Returns:
            The processed image.
        """
        return ImageEnhance.Brightness(image).enhance(factor)

    @staticmethod
    def denoise(image: Image.Image) -> Image.Image:
        """Reduce noise with a 3x3 median filter.

        Args:
            image: PIL Image to process.

        Returns:
            The processed image.
        """
        return image.filter(ImageFilter.MedianFilter(size=3))

    @staticmethod
    def binarize(image: Image.Image, threshold: int = 127) -> Image.Image:
        """Convert an image to pure black and white.

        Args:
            image: PIL Image to process.
            threshold: Grey level cut-off; pixels below it become black (0),
                all others become white (255).

        Returns:
            The binarized image, converted back to RGB mode.
        """
        # Go through greyscale first, threshold into 1-bit mode, then return
        # to RGB so the result has the same mode as the other steps produce.
        gray = image.convert('L')
        binary = gray.point(lambda x: 0 if x < threshold else 255, '1')
        return binary.convert('RGB')

    @staticmethod
    def preprocess(
        image: Image.Image,
        resize: bool = True,
        enhance_contrast: bool = True,
        enhance_sharpness: bool = True,
        denoise: bool = False,
        binarize: bool = False
    ) -> Image.Image:
        """Run the selected pre-processing steps on a copy of the image.

        Steps are applied in a fixed order: resize, contrast, sharpness,
        denoise, binarize. Each enabled step uses that step's default
        parameters.

        Args:
            image: PIL Image to process (left unmodified).
            resize: Whether to limit the image width.
            enhance_contrast: Whether to enhance contrast.
            enhance_sharpness: Whether to enhance sharpness.
            denoise: Whether to apply median-filter denoising.
            binarize: Whether to binarize the image.

        Returns:
            The processed image.
        """
        result = image.copy()

        # The boolean parameters shadow the method names inside this
        # function, so the steps are reached through the class explicitly.
        if resize:
            result = ImagePreprocessor.resize_image(result)
        if enhance_contrast:
            result = ImagePreprocessor.enhance_contrast(result)
        if enhance_sharpness:
            result = ImagePreprocessor.enhance_sharpness(result)
        if denoise:
            result = ImagePreprocessor.denoise(result)
        if binarize:
            result = ImagePreprocessor.binarize(result)

        return result

    @staticmethod
    def preprocess_from_path(
        image_path: str,
        **kwargs
    ) -> Image.Image:
        """Load an image from disk and pre-process it.

        Args:
            image_path: Path of the image file.
            **kwargs: Options forwarded to :meth:`preprocess`.

        Returns:
            The processed image.
        """
        image = ImagePreprocessor.load_image(image_path)
        return ImagePreprocessor.preprocess(image, **kwargs)
|
||||
|
||||
|
||||
class BaseOCREngine(ABC):
    """Base class for OCR engines.

    Every OCR implementation must inherit from this class and implement
    :meth:`recognize`.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the engine.

        Args:
            config: OCR configuration dictionary; ``None`` (or any falsy
                value) is replaced by an empty dict.
        """
        self.config = config or {}
        self.preprocessor = ImagePreprocessor()

    @abstractmethod
    def recognize(
        self,
        image,
        preprocess: bool = True,
        **kwargs
    ) -> OCRBatchResult:
        """Recognize the text in an image.

        Args:
            image: The image (a path, a PIL Image or a numpy array).
            preprocess: Whether to pre-process the image first.
            **kwargs: Additional engine-specific options.

        Returns:
            OCRBatchResult: The recognition result.
        """
        pass

    def _load_image(self, image) -> Image.Image:
        """Turn any supported input into an RGB PIL Image.

        Args:
            image: A file path (``str`` or ``Path``), a PIL Image or a numpy
                array.

        Returns:
            A PIL Image in ``'RGB'`` mode.

        Raises:
            ValueError: If ``image`` is of an unsupported type.
        """
        if isinstance(image, (str, Path)):
            # The preprocessor's loader already returns RGB.
            return self.preprocessor.load_image(str(image))

        if isinstance(image, Image.Image):
            loaded = image
        elif isinstance(image, np.ndarray):
            loaded = Image.fromarray(image)
        else:
            raise ValueError(f"不支持的图像类型: {type(image)}")

        # Path inputs are normalized to RGB by the loader; do the same for
        # in-memory inputs so every source reaches the engine in one mode
        # (e.g. RGBA, palette or greyscale images are converted here).
        if loaded.mode != 'RGB':
            loaded = loaded.convert('RGB')
        return loaded

    def _calculate_total_confidence(self, results: List[OCRResult]) -> float:
        """Compute the mean confidence of a list of results.

        Args:
            results: OCR results to average.

        Returns:
            The arithmetic mean of the confidences, or 0.0 for an empty list.
        """
        if not results:
            return 0.0
        return sum(r.confidence for r in results) / len(results)
|
||||
|
||||
|
||||
class PaddleOCREngine(BaseOCREngine):
    """Local recognition engine backed by PaddleOCR."""

    # This module advertises "chinese_chinese" (OCRLanguage.MIXED) for mixed
    # Chinese/English text, but that string is passed straight to PaddleOCR.
    # NOTE(review): PaddleOCR's language list appears to use "ch" for the
    # combined Chinese + English model and has no "chinese_chinese" entry --
    # confirm against the installed PaddleOCR version.
    _LANG_ALIASES = {'chinese_chinese': 'ch'}

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the PaddleOCR engine.

        Args:
            config: Configuration dictionary. Recognized keys:
                - use_gpu: Whether to use the GPU (default False).
                - lang: Language code (default "ch"; ch/en/chinese_chinese).
                - show_log: Whether PaddleOCR should print logs
                  (default False).

        Raises:
            ImportError: If the ``paddleocr`` package could not be imported.
        """
        super().__init__(config)

        if PaddleOCR is None:
            raise ImportError(
                "PaddleOCR 未安装。请运行: pip install paddleocr paddlepaddle"
            )

        # Read settings from the configuration
        self.use_gpu = self.config.get('use_gpu', False)
        self.lang = self.config.get('lang', 'ch')
        self.show_log = self.config.get('show_log', False)

        # Create the PaddleOCR instance. ``self.lang`` keeps the configured
        # value; only the code handed to PaddleOCR is translated.
        logger.info(f"初始化 PaddleOCR (lang={self.lang}, gpu={self.use_gpu})")
        self.ocr = PaddleOCR(
            use_angle_cls=True,  # enable the text direction classifier
            lang=self._LANG_ALIASES.get(self.lang, self.lang),
            use_gpu=self.use_gpu,
            show_log=self.show_log
        )

    @staticmethod
    def _parse_lines(raw_result) -> List[OCRResult]:
        """Convert PaddleOCR output for the first image into OCRResult items."""
        if not raw_result or not raw_result[0]:
            return []

        parsed = []
        for line_index, line in enumerate(raw_result[0]):
            if not line:
                continue
            # Each entry has the shape [bbox, (text, confidence)]
            bbox = line[0]
            text_info = line[1]
            parsed.append(
                OCRResult(
                    text=text_info[0],
                    confidence=float(text_info[1]),
                    bbox=bbox,
                    line_index=line_index
                )
            )
        return parsed

    def recognize(
        self,
        image,
        preprocess: bool = False,
        **kwargs
    ) -> OCRBatchResult:
        """Recognize the text in an image with PaddleOCR.

        Any exception raised while loading, pre-processing or recognizing is
        logged and reported through a failed result instead of propagating.

        Args:
            image: The image (a path, a PIL Image or a numpy array).
            preprocess: Whether to pre-process the image first.
            **kwargs: Additional options (unused).

        Returns:
            OCRBatchResult: The recognition result; ``success`` is False and
            ``error_message`` is set when an error occurred.
        """
        try:
            pil_image = self._load_image(image)

            if preprocess:
                pil_image = self.preprocessor.preprocess(pil_image)

            # PaddleOCR is given a numpy array.
            # NOTE(review): the array carries PIL's channel order; confirm
            # whether PaddleOCR expects OpenCV-style BGR for array input.
            img_array = np.array(pil_image)

            result = self.ocr.ocr(img_array, cls=True)

            ocr_results = self._parse_lines(result)
            total_confidence = self._calculate_total_confidence(ocr_results)
            full_text = '\n'.join(item.text for item in ocr_results)

            logger.info(f"OCR 识别完成: {len(ocr_results)} 行, 平均置信度 {total_confidence:.2f}")

            return OCRBatchResult(
                results=ocr_results,
                full_text=full_text,
                total_confidence=total_confidence,
                success=True
            )

        except Exception as e:
            # Deliberate catch-all: callers receive a failed result object
            # rather than an exception.
            logger.error(f"OCR 识别失败: {e}", exc_info=True)
            return OCRBatchResult(
                results=[],
                full_text="",
                total_confidence=0.0,
                success=False,
                error_message=str(e)
            )
|
||||
|
||||
|
||||
class CloudOCREngine(BaseOCREngine):
    """Adapter for cloud OCR services.

    A reserved extension point for cloud OCR providers such as Baidu OCR,
    Tencent OCR or Alibaba Cloud OCR. Recognition is not implemented yet.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the cloud OCR engine.

        Args:
            config: Configuration dictionary. Recognized keys:
                - api_endpoint: API endpoint.
                - api_key: API key.
                - provider: Provider name (baidu/tencent/aliyun/custom).
                - timeout: Timeout in seconds.
        """
        super().__init__(config)

        settings = self.config
        self.api_endpoint = settings.get('api_endpoint', '')
        self.api_key = settings.get('api_key', '')
        self.provider = settings.get('provider', 'custom')
        self.timeout = settings.get('timeout', 30)

        # A missing endpoint is only reported, not treated as an error.
        endpoint_missing = not self.api_endpoint
        if endpoint_missing:
            logger.warning("云端 OCR: api_endpoint 未配置")

    def recognize(
        self,
        image,
        preprocess: bool = False,
        **kwargs
    ) -> OCRBatchResult:
        """Recognize the text in an image through a cloud API.

        Placeholder implementation: it logs a warning and always returns a
        failed result. A concrete provider integration has to be supplied
        before this engine can be used.

        Args:
            image: The image (a path or a PIL Image).
            preprocess: Whether to pre-process the image first.
            **kwargs: Additional options.

        Returns:
            OCRBatchResult: Always an empty, unsuccessful result.
        """
        logger.warning("云端 OCR 尚未实现,请使用本地 PaddleOCR 或自行实现")

        not_implemented = OCRBatchResult(
            results=[],
            full_text="",
            total_confidence=0.0,
            success=False,
            error_message="云端 OCR 尚未实现"
        )
        return not_implemented

    def _send_request(self, image_data: bytes) -> Dict[str, Any]:
        """Send the API request (not implemented yet).

        Args:
            image_data: Raw image bytes.

        Returns:
            The API response.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("请根据具体云服务 API 实现此方法")
|
||||
|
||||
|
||||
class OCRFactory:
    """Factory that builds the OCR engine matching a requested mode."""

    @staticmethod
    def create_engine(
        mode: str = "local",
        config: Optional[Dict[str, Any]] = None
    ) -> BaseOCREngine:
        """Create an OCR engine.

        Args:
            mode: OCR mode, either "local" or "cloud".
            config: Configuration dictionary handed to the engine.

        Returns:
            BaseOCREngine: A ``PaddleOCREngine`` for "local" or a
            ``CloudOCREngine`` for "cloud".

        Raises:
            ValueError: If ``mode`` is not a supported OCR mode.
        """
        # Pick the engine class first, then instantiate it in one place.
        if mode == "local":
            engine_cls = PaddleOCREngine
        elif mode == "cloud":
            engine_cls = CloudOCREngine
        else:
            raise ValueError(f"不支持的 OCR 模式: {mode}")

        return engine_cls(config)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def recognize_text(
    image,
    mode: str = "local",
    lang: str = "ch",
    use_gpu: bool = False,
    preprocess: bool = False,
    **kwargs
) -> OCRBatchResult:
    """Recognize text in one call.

    Builds an engine through ``OCRFactory.create_engine`` on every call and
    runs it on the image.

    Args:
        image: The image (a path or a PIL Image).
        mode: OCR mode, either "local" or "cloud".
        lang: Language code (ch/en/chinese_chinese).
        use_gpu: Whether to use the GPU (local mode only).
        preprocess: Whether to pre-process the image first.
        **kwargs: Extra entries added to the engine configuration.

    Returns:
        OCRBatchResult: The recognition result.
    """
    # Explicit arguments first, then any additional configuration entries.
    engine_config = {'lang': lang, 'use_gpu': use_gpu}
    engine_config.update(kwargs)

    ocr_engine = OCRFactory.create_engine(mode, engine_config)
    outcome = ocr_engine.recognize(image, preprocess=preprocess)
    return outcome
|
||||
|
||||
|
||||
def preprocess_image(
    image_path: str,
    output_path: Optional[str] = None,
    **kwargs
) -> Image.Image:
    """Pre-process an image file in one call.

    Args:
        image_path: Path of the input image.
        output_path: Where to save the processed image; nothing is written
            when this is empty or ``None``.
        **kwargs: Pre-processing options forwarded to
            ``ImagePreprocessor.preprocess_from_path``.

    Returns:
        PIL Image: The processed image.
    """
    result_image = ImagePreprocessor.preprocess_from_path(image_path, **kwargs)

    # Nothing to persist: hand the image back directly.
    if not output_path:
        return result_image

    result_image.save(output_path)
    logger.info(f"预处理图像已保存到: {output_path}")
    return result_image
|
||||
Reference in New Issue
Block a user