
Commit e7bc1d1

titaiwangms authored and pytorchmergebot committed
[ONNX] Update saved exported program in debugging report if the exporting passes run_decomposition() (#148617)
Prior to this PR, if the export path ran `run_decomposition()`, the report still showed the exported program from before decomposition, which made it harder for users to inspect the exported program that was actually used to translate to the ONNX graph. The following example is what we saw before this PR:

# PyTorch ONNX Conversion Report

```
✅ Obtain model graph with `torch.export.export(..., strict=False)`
⚪ Obtain model graph with `torch.export.export(..., strict=True)`
⚪ Obtain model graph with `torch.jit.trace`
✅ Decompose operators for ONNX compatibility
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy
```

## Error messages

```pytb
Traceback (most recent call last):
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 707, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 486, in _handle_call_function_node_with_lowering
    raise _errors.DispatchError(
torch.onnx._internal.exporter._errors.DispatchError: No ONNX function found for <OpOverload(op='aten.slice', overload='Tensor')>. Failure message: No decompositions registered for the complex-valued input

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 1371, in export
    onnx_program = _exported_program_to_onnx_program(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 1007, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 733, in _translate_fx_graph
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%_to_copy, 0, 0, 9223372036854775807), kwargs = {}). See the stack trace for more information.
```

## Exported program

```python
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4]"):
            # File: /home/titaiwang/pytorch/test_slice_complex.py:6 in forward, code: x_complex = x.to(torch.complex64)
            to: "c64[3, 4]" = torch.ops.aten.to.dtype(x, torch.complex64);  x = None

            # File: /home/titaiwang/pytorch/test_slice_complex.py:8 in forward, code: return x_complex[:, :2]
            slice_1: "c64[3, 4]" = torch.ops.aten.slice.Tensor(to, 0, 0, 9223372036854775807);  to = None
            slice_2: "c64[3, 2]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 2);  slice_1 = None
            return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}
```

## Analysis

PyTorch ONNX Conversion Analysis

## Model Information

The model has 0 parameters and 0 buffers (non-trainable parameters).

Number of parameters per dtype:

```python
defaultdict(<class 'int'>, {})
```

Number of buffers per dtype:

```python
defaultdict(<class 'int'>, {})
```

Inputs:
- `x`: `TensorMetadata(shape=torch.Size([3, 4]), dtype=torch.float32, requires_grad=False, stride=(4, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`

Outputs:
- `slice_2`: `TensorMetadata(shape=torch.Size([3, 2]), dtype=torch.complex64, requires_grad=False, stride=(4, 1), memory_format=None, is_quantized=False, qparams={})`

The FX graph has 5 nodes in total. Number of FX nodes per op:
- `placeholder`: 1
- `call_function`: 3
- `output`: 1

Of the call_function nodes, the counts of operators used are:
- `aten.slice.Tensor`: 2
- `aten.to.dtype`: 1

## ONNX Conversion Information

The model contains operators the dispatcher could not find registered ONNX decompositions for. This may be due to missing implementations, decompositions not registered correctly, or a bug in the dispatcher.

Errors grouped by operator:

- `aten.to.dtype`: No decompositions registered for the real-valued input. Example node: `%to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.complex64), kwargs = {})`. All nodes: `[to]`
- `aten.slice.Tensor`: No decompositions registered for the complex-valued input. Example node: `%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%to, 0, 0, 9223372036854775807), kwargs = {})`. All nodes: `[slice_1, slice_2]`

## Decomposition comparison

Ops exist only in the ExportedProgram before decomposition: `['aten.to.dtype']`

Ops exist only in the ExportedProgram after decomposition: `['aten._to_copy.default']`

Note the mismatch in the report above: the error message refers to the post-decomposition node `%_to_copy`, while the saved "Exported program" section still shows the pre-decomposition graph with `aten.to.dtype`. With this change, the report saves the decomposed exported program, so the two sections are consistent.

Pull Request resolved: #148617
Approved by: https://github.com/justinchuby
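For reference, the report above can be reproduced with a minimal script along the lines of the sketch below, pieced together from the source lines quoted in the exported program (`test_slice_complex.py`). The module name, input shape, and the use of `report=True` with the dynamo-based exporter are assumptions for illustration, not taken from this commit:

```python
import torch


class SliceComplex(torch.nn.Module):  # hypothetical name; body matches the quoted test_slice_complex.py lines
    def forward(self, x):
        x_complex = x.to(torch.complex64)
        return x_complex[:, :2]


model = SliceComplex()
x = torch.rand(3, 4)

# Before decomposition the exported graph contains aten.to.dtype ...
ep = torch.export.export(model, (x,), strict=False)
print(ep)

# ... after run_decomposition() the same op becomes aten._to_copy.default,
# which is the graph the ONNX translation actually consumes.
print(ep.run_decomposition())

# report=True (dynamo-based exporter) writes the markdown report quoted above;
# translating the complex-valued slice is expected to fail here.
try:
    torch.onnx.export(model, (x,), dynamo=True, report=True)
except Exception as e:
    print(f"Export failed as expected: {e}")
```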
1 parent ae6bb58 · commit e7bc1d1

File tree: 1 file changed, +1 −1 lines changed


torch/onnx/_internal/exporter/_core.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1402,7 +1402,7 @@ def export(
             _reporting.create_onnx_export_report(
                 report_path,
                 f"{_format_exceptions_for_all_strategies(failed_results)}\n\n{_format_exception(e)}",
-                program,
+                decomposed_program,
                 decomp_comparison=_reporting.format_decomp_comparison(
                     pre_decomp_unique_ops, post_decomp_unique_ops
                 ),
```
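The one-line change swaps the program embedded in the failure report: `decomposed_program` is passed where `program` used to be. A self-contained sketch of the pattern this enforces is below; `write_report` and the output file name are illustrative stand-ins for `_reporting.create_onnx_export_report`, not the exporter's real internals:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.complex64)[:, :2]


def write_report(path: str, message: str, program: torch.export.ExportedProgram) -> None:
    # Illustrative stand-in for _reporting.create_onnx_export_report.
    with open(path, "w") as f:
        f.write("# PyTorch ONNX Conversion Report\n\n")
        f.write(f"## Error messages\n\n```pytb\n{message}\n```\n\n")
        f.write(f"## Exported program\n\n```python\n{program}\n```\n")


program = torch.export.export(Model(), (torch.rand(3, 4),), strict=False)
decomposed_program = program.run_decomposition()

try:
    raise RuntimeError("stand-in for the ConversionError raised during translation")
except RuntimeError as e:
    # After this commit, the report embeds the decomposed program, so the saved
    # graph matches the nodes referenced by the error messages (e.g. %_to_copy).
    write_report("onnx_export_report.md", str(e), decomposed_program)
```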
