PaddleOCR DB Text-Line Detection: Converting to ONNX and Running Inference

This article walks through converting PaddleOCR's DB text detection model to ONNX format and running text-line detection in Python with ONNX Runtime, covering image preprocessing, ONNX model inference, and the post-processing that extracts the detection boxes.


DB Text-Line Detection

(1) Converting the PaddleOCR DB model to ONNX

The following links are useful references; a sample conversion command follows the list.

  1. https://blog.csdn.net/favorxin/article/details/113838901
  2. https://github.com/PaddlePaddle/Paddle2ONNX
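
For reference, a typical conversion command looks like the sketch below. It assumes you have already exported the PP-OCRv3 detection inference model into an example directory ch_PP-OCRv3_det_infer; the flag names follow the Paddle2ONNX README, but verify them against your installed version:

paddle2onnx --model_dir ./ch_PP-OCRv3_det_infer \
            --model_filename inference.pdmodel \
            --params_filename inference.pdiparams \
            --save_file ./models/ppocrv3_det.onnx \
            --opset_version 11 \
            --enable_onnx_checker True

The output path ./models/ppocrv3_det.onnx matches the model file loaded in section (4) below.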

(2) Import the required packages

import os
import sys
import cv2
import glob
import onnxruntime
import numpy as np
import pyclipper
from shapely.geometry import Polygon

(3) Define the helper classes and functions for PaddleOCR text-line detection

class NormalizeImage(object):
    """ normalize image such as substract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
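
A quick sanity check of the normalization (a minimal sketch; the mean/std values are the standard ImageNet statistics used as defaults above):

norm = NormalizeImage(scale='1./255.', mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225], order='hwc')
out = norm({'image': np.full((2, 2, 3), 128, dtype=np.uint8)})['image']
print(out[0, 0])  # roughly [0.074, 0.205, 0.427] = (128/255 - mean) / std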


class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data


class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list

class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.limit_side_len = kwargs['limit_side_len']
        # stored for config compatibility only; resize_image_type0 below always limits the longest side
        self.limit_type = kwargs.get('limit_type', 'min')

    def __call__(self, data):
        img = data['image']
        src_h, src_w, _ = img.shape
        img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, _ = img.shape

        # limit the max side
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except:
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
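
A minimal sanity check of the resize op (using an arbitrary 700x500 dummy image):

# both sides are rounded to multiples of 32: (700, 500, 3) -> (704, 512, 3)
op = DetResizeForTest(limit_side_len=2000, limit_type='max')
data = op({'image': np.zeros((700, 500, 3), dtype=np.uint8)})
print(data['image'].shape, data['shape'])  # shape = [src_h, src_w, ratio_h, ratio_w]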

### Post-processing of the detection results (extracting the detection boxes)
class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single segmentation map with shape (H, W),
                whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        # OpenCV 3.x returns (image, contours, hierarchy); OpenCV 4.x returns (contours, hierarchy)
        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            box = self.unclip(points).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box):
        unclip_ratio = self.unclip_ratio
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int_), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int_), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int_), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int_), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                   src_w, src_h)

            boxes_batch.append({'points': boxes})
        return boxes_batch
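
A quick smoke test of the post-processor on a synthetic probability map (per the __call__ signature above, pred has shape (N, 1, H, W) and each shape_list row is [src_h, src_w, ratio_h, ratio_w]):

post_op = DBPostProcess(thresh=0.3, box_thresh=0.4, max_candidates=100, unclip_ratio=1.6)
fake_pred = np.zeros((1, 1, 64, 64), dtype=np.float32)
fake_pred[0, 0, 20:30, 10:50] = 0.9              # one bright block = one "text line"
shape_list = np.array([[64.0, 64.0, 1.0, 1.0]])  # src_h, src_w, ratio_h, ratio_w
boxes = post_op(fake_pred, shape_list)[0]['points']
print(boxes.shape)  # expect (1, 4, 2): one quadrilateral around the bright block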


## Image preprocessing pipeline
def transform(data, ops=None):
    """ transform """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data

def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config

    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops

### Post-processing of the detected boxes
def order_points_clockwise(pts):
    """
    reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
    # sort the points based on their x-coordinates
    """
    xSorted = pts[np.argsort(pts[:, 0]), :]

    # grab the left-most and right-most points from the sorted
    # x-coordinate points
    leftMost = xSorted[:2, :]
    rightMost = xSorted[2:, :]

    # now, sort the left-most coordinates according to their
    # y-coordinates so we can grab the top-left and bottom-left
    # points, respectively
    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    (tl, bl) = leftMost

    rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
    (tr, br) = rightMost

    rect = np.array([tl, tr, br, bl], dtype="float32")
    return rect
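
For example, four corner points given in arbitrary order come back as top-left, top-right, bottom-right, bottom-left:

pts = np.array([[10, 10], [10, 0], [0, 0], [0, 10]], dtype='float32')
print(order_points_clockwise(pts))
# -> [[ 0.  0.] [10.  0.] [10. 10.] [ 0. 10.]], i.e. tl, tr, br, bl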

def clip_det_res(points, img_height, img_width):
    for pno in range(points.shape[0]):
        points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
        points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
    return points

def filter_tag_det_res(dt_boxes, image_shape):
    img_height, img_width = image_shape[0:2]
    dt_boxes_new = []
    for box in dt_boxes:
        box = order_points_clockwise(box)
        box = clip_det_res(box, img_height, img_width)
        rect_width = int(np.linalg.norm(box[0] - box[1]))
        rect_height = int(np.linalg.norm(box[0] - box[3]))
        if rect_width <= 3 or rect_height <= 3:
            continue
        dt_boxes_new.append(box)
    dt_boxes = np.array(dt_boxes_new)
    return dt_boxes

### Build the image preprocessing ops and the detection post-processing op
def get_process():
    det_db_thresh = 0.3
    det_db_box_thresh = 0.4
    max_candidates = 4000
    unclip_ratio = 1.6
    use_dilation = True

    pre_process_list = [{
        'DetResizeForTest': {
            'limit_side_len': 2000,
            'limit_type': 'max'
        }
    }, {
        'NormalizeImage': {
            'std': [0.229, 0.224, 0.225],
            'mean': [0.485, 0.456, 0.406],
            'scale': '1./255.',
            'order': 'hwc'
        }
    }, {
        'ToCHWImage': None
    }, {
        'KeepKeys': {
            'keep_keys': ['image', 'shape']
        }
    }]

    infer_before_process_op = create_operators(pre_process_list)
    det_re_process_op = DBPostProcess(det_db_thresh, det_db_box_thresh, max_candidates, unclip_ratio, use_dilation)
    return infer_before_process_op, det_re_process_op

def sorted_boxes(dt_boxes):
    """
    Sort text boxes from top to bottom, then left to right within a line
    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2]
    return:
        list of sorted boxes, each with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
            tmp = _boxes[i]
            _boxes[i] = _boxes[i + 1]
            _boxes[i + 1] = tmp
    return _boxes
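
A small illustration of the line-aware ordering (three hand-made boxes, two on the same line):

boxes = np.array([
    [[300, 12], [400, 12], [400, 40], [300, 40]],   # same line, right
    [[10, 15], [110, 15], [110, 43], [10, 43]],     # same line, left
    [[10, 80], [110, 80], [110, 108], [10, 108]],   # next line
])
print([b[0].tolist() for b in sorted_boxes(boxes)])
# -> [[10, 15], [300, 12], [10, 80]]: the same-line boxes are re-ordered left to right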
   
def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    if isinstance(img_path, str):
        img_path = cv2.imread(img_path)
        # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
    img_path = img_path.copy()
    for point in result:
        point = point.astype(int)

        cv2.polylines(img_path, [point], True, color, thickness)
    return img_path

def get_images(pic_dir):
    """获得文件夹下所有图片路径,并存放在列表中.
    """
    pic_paths = []
    for ext in ['jpg', 'png', 'bmp', 'jpeg', '.jfif']:
        pic_paths.extend(glob.glob(os.path.join(pic_dir, '*.{}'.format(ext))))
    return pic_paths

# Resize an image for on-screen display
def imshow_resize(image, need_w=1680, need_h=1280):
    img = image.copy()
    h, w = image.shape[:2]
    ratio = w / h
    if need_w / ratio > need_h:
        new_h = need_h
        new_w = int(need_h*ratio)
    else:
        new_h = int(need_w/ratio)
        new_w = need_w
    img = cv2.resize(img, (new_w, new_h))
    return img

def load_image(img_path):
    # np.fromfile + cv2.imdecode also handles non-ASCII paths (e.g. Chinese paths on Windows)
    image = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
    if image.ndim == 2:
        # grayscale image: expand to 3 channels for the detection pipeline
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        # OpenCV decodes images with alpha as BGRA, so drop the alpha channel accordingly
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image

(4) ONNX text-line detection code

curdir = os.path.abspath(os.path.dirname(__file__))

large_det_file = os.path.join(curdir, 'models', 'ppocrv3_det.onnx')
onet_det_session = onnxruntime.InferenceSession(large_det_file)
infer_before_process_op, det_re_process_op = get_process()

## Run detection on a single image and return the text-line boxes
def get_boxes(image):
    img_part = image.copy()
    data_part = {'image': img_part}
    data_part = transform(data_part, infer_before_process_op)
    img_part, shape_part_list = data_part
    img_part = np.expand_dims(img_part, axis=0)
    shape_part_list = np.expand_dims(shape_part_list, axis=0)
    inputs_part = {onet_det_session.get_inputs()[0].name: img_part}
    outs_part = onet_det_session.run(None, inputs_part)
    post_res_part = det_re_process_op(outs_part[0], shape_part_list)
    dt_boxes_part = post_res_part[0]['points']
    dt_boxes_part = filter_tag_det_res(dt_boxes_part, image.shape)
    dt_boxes_part = sorted_boxes(dt_boxes_part)
    return dt_boxes_part

(5) Test example

if __name__=="__main__":

    imgs_dir = r'E:\datasets\OCR\test_imgs'
    imgs_path = get_images(imgs_dir)
    for img_path in imgs_path:
        image = load_image(img_path)
        # raw detection results on the original image
        img = image.copy()
        boxes = get_boxes(img)
        img_show = image.copy()
        img_show = draw_bbox(img_show, np.array(boxes))
        img_show = imshow_resize(img_show)
        cv2.imshow("RealSense", img_show)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
