torch.onnx.export 导出自定义网络
时间: 2025-07-23 09:18:34 AIGC 浏览: 22
使用 `torch.onnx.export` 将自定义的 PyTorch 模型导出为 ONNX 格式涉及多个关键步骤,包括模型准备、输入数据构造、调用导出函数以及可选的模型优化。以下是详细说明:
### 1. 准备模型
确保模型处于评估模式,并且没有使用任何仅在训练阶段需要的功能(如 Dropout 或 BatchNorm 的更新)。通过 `model.eval()` 设置模型为推理模式。
```python
import torch
# 假设 CustomModel 是用户自定义的神经网络类
model = CustomModel()
model.load_state_dict(torch.load("model.pth")) # 加载训练好的权重
model.eval() # 设置为评估模式
```
### 2. 构造输入数据
提供一个与模型期望输入匹配的张量。该输入将用于执行模型前向传播并生成计算图。
```python
# 假设模型接受形状为 (batch_size, channels, height, width) 的输入
dummy_input = torch.randn(1, 3, 224, 224) # batch_size=1, 3通道, 224x224图像
```
### 3. 调用 `torch.onnx.export`
指定必要的参数,例如输入名称 (`input_names`) 和输出名称 (`output_names`),这些名称将在 ONNX 模型中作为输入/输出节点的标识符。如果模型支持动态维度,可以使用 `dynamic_axes` 参数指定。
```python
import torch.onnx
# 导出模型
torch.onnx.export(
model,
dummy_input,
"model.onnx",
export_params=True, # 存储训练参数
opset_version=13, # ONNX 算子集版本
do_constant_folding=True, # 优化常量
input_names=["input"], # 输入节点名称
output_names=["output"], # 输出节点名称
dynamic_axes={
"input": {0: "batch_size"}, # 动态维度
"output": {0: "batch_size"}
}
)
```
### 4. 验证 ONNX 模型
使用 ONNX Runtime 运行推理以验证导出的模型是否正确。
```python
import onnx
import onnxruntime as ort
# 加载 ONNX 模型
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
# 使用 ONNX Runtime 推理
ort_session = ort.InferenceSession("model.onnx")
outputs = ort_session.run(
None,
{"input": dummy_input.numpy()}
)
```
### 5. 模型优化(可选)
使用 `onnx-simplifier` 工具简化模型结构,提升推理效率[^1]。
```bash
pip install onnx-simplifier
python -m onnxsim model.onnx model_simplified.onnx
```
### 总结
- `input_names` 和 `output_names` 可由用户自定义,也可以根据网络定义确定,通常通过查看模型前向传播中的输入和输出变量来决定。
- `dynamic_axes` 支持动态输入尺寸,适用于不同批量大小或序列长度的应用场景。
- 导出后建议进行验证,确保 ONNX 模型行为与原始 PyTorch 模型一致。
阅读全文
相关推荐




















