在mindee/doctr中优化和导出模型用于生产环境

在mindee/doctr中优化和导出模型用于生产环境

前言

在完成模型训练后,如何将模型优化并部署到生产环境是每个开发者都需要面对的问题。本文将详细介绍如何在mindee/doctr项目中优化模型性能,并将其导出为适合生产环境使用的格式。

模型优化技术

半精度推理(FP16)

半精度(FP16)是一种16位的浮点数格式,相比标准的32位浮点数(FP32),它可以带来显著的性能提升:

  • 优势
    • 推理速度更快
    • 显存占用更少
    • 特别适合GPU加速

注意:目前仅支持在GPU设备上使用PyTorch和TensorFlow后端进行半精度推理。

PyTorch实现
import torch
predictor = ocr_predictor(
    reco_arch="crnn_mobilenet_v3_small",
    det_arch="linknet_resnet34",
    pretrained=True
).cuda().half()  # 关键是将模型转换为半精度
res = predictor(doc)
TensorFlow实现
import tensorflow as tf
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')  # 设置全局混合精度策略
predictor = ocr_predictor(
    reco_arch="crnn_mobilenet_v3_small",
    det_arch="linknet_resnet34",
    pretrained=True
)

PyTorch模型编译优化

PyTorch 2.0引入了torch.compile功能,可以将模型转换为图表示并进行优化:

  • 优势
    • 显著提升推理速度
    • 减少内存开销
    • 支持多种后端优化器

注意事项

  1. 仅支持PyTorch后端
  2. 目前不支持master识别架构
  3. 官方默认支持inductor后端
import torch
from doctr.models import (
    ocr_predictor,
    vitstr_small,
    fast_base,
    mobilenet_v3_small_crop_orientation,
    mobilenet_v3_small_page_orientation,
    crop_orientation_predictor,
    page_orientation_predictor
)

# 编译各个组件模型
detection_model = torch.compile(fast_base(pretrained=True).eval())
recognition_model = torch.compile(vitstr_small(pretrained=True).eval())
crop_orientation_model = torch.compile(
    mobilenet_v3_small_crop_orientation(pretrained=True).eval()
)
page_orientation_model = torch.compile(
    mobilenet_v3_small_page_orientation(pretrained=True).eval()
)

# 构建预测器
predictor = models.ocr_predictor(
    detection_model, recognition_model, assume_straight_pages=False
)

# 设置方向预测器(非直线页面需要)
predictor.crop_orientation_predictor = crop_orientation_predictor(crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(page_orientation_model)

# 执行推理
compiled_out = predictor(doc)

模型导出为ONNX格式

ONNX是一种开放的神经网络交换格式,可以实现跨框架的模型部署。

PyTorch模型导出

import torch
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx

batch_size = 1
input_shape = (3, 32, 128)  # 输入张量形状
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
model_path = export_model_to_onnx(
    model,
    model_name="vitstr.onnx",  # 输出文件名
    dummy_input=dummy_input  # 示例输入
)

TensorFlow模型导出

import tensorflow as tf
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx

batch_size = 1
input_shape = (32, 128, 3)  # TensorFlow的通道顺序不同
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
model_path, output = export_model_to_onnx(
    model,
    model_name="vitstr.onnx",
    dummy_input=dummy_input
)

使用导出的ONNX模型

导出的ONNX模型可以通过专门的轻量级包进行推理,无需安装PyTorch或TensorFlow。

安装依赖

pip install onnxtr[cpu]  # 或[gpu]如果需要GPU支持

加载和使用模型

from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor, parseq, linknet_resnet18

# 加载文档
single_img_doc = DocumentFile.from_images("path/to/your/img.jpg")

# 加载自定义导出模型
reco_model = parseq("path_to_custom_model.onnx", vocab="ABC")
det_model = linknet_resnet18("path_to_custom_model.onnx")
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model)

# 或者使用预训练模型
predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch="parseq")

# 执行推理
res = predictor(single_img_doc)

总结

本文详细介绍了在mindee/doctr项目中优化和导出模型的完整流程,包括:

  1. 半精度推理优化
  2. PyTorch模型编译
  3. ONNX格式导出
  4. 轻量级推理部署

通过这些技术,开发者可以显著提升模型在生产环境中的性能表现,同时实现更灵活的部署方案。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高慈鹃Faye

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值