PyTorch API:
torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
Pay special attention to the as_tuple parameter of torch.nonzero.
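With as_tuple=False (the default), torch.nonzero returns one 2-D tensor of shape (n, ndim), where each row holds the coordinates of one nonzero element. With as_tuple=True it returns a tuple of ndim 1-D index tensors, one per dimension, which is the form torch.where(condition) returns. The sketch below models both layouts in plain Python on nested lists so that it runs without torch. The helper names nonzero_2d and where_1arg are hypothetical, and the sketch only handles 2-D input.

```python
# Plain-Python model of torch.nonzero's two output layouts (NOT torch itself).
# Assumption: the input is a 2-D nested list; real torch accepts any rank.

def nonzero_2d(matrix, as_tuple=False):
    """Mimic torch.nonzero on a 2-D nested list.

    as_tuple=False -> list of [row, col] pairs, i.e. shape (n, ndim).
    as_tuple=True  -> tuple of per-dimension index lists, like torch.where(cond).
    """
    # Row-major scan, so coordinates come out in the same order torch uses.
    coords = [
        [r, c]
        for r, row in enumerate(matrix)
        for c, value in enumerate(row)
        if value != 0
    ]
    if not as_tuple:
        return coords
    # Transpose the (n, 2) coordinate list into one index list per dimension.
    rows = [rc[0] for rc in coords]
    cols = [rc[1] for rc in coords]
    return (rows, cols)


def where_1arg(matrix):
    # torch.where(condition) is documented as identical to
    # torch.nonzero(condition, as_tuple=True).
    return nonzero_2d(matrix, as_tuple=True)


if __name__ == "__main__":
    a1 = [[1, 0, 0, 5, 0, 6]]
    print(nonzero_2d(a1))                 # [[0, 0], [0, 3], [0, 5]]
    print(nonzero_2d(a1, as_tuple=True))  # ([0, 0, 0], [0, 3, 5])
    print(where_1arg(a1))                 # ([0, 0, 0], [0, 3, 5])
```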
Enabling dynamic shape:
XLA_EXPERIMENTAL="nonzero"
XLA op definition: torch_xla/csrc/ops/nonzero.cpp
torch-xla's nonzero and how it is enabled: when XLA_EXPERIMENTAL="nonzero" is not set, nonzero goes through xla_cpu_fallback.
Understanding the MHLO algorithm
Python script
Script source: xla/issues/4432
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["PJRT_DEVICE"] = "XPU"
# os.environ["PJRT_DEVICE"] = "GPU"
os.environ["XLA_EXPERIMENTAL"]="nonzero"
# for nonzero cpu fallback log, when XLA_EXPERIMENTAL='nonzero' is not enabled
os.environ["TF_CPP_VMODULE"] = "aten_cpu_fallback=5"
import torch
import torch_xla  # needed for torch_xla._XLAC below
import torch_xla.core.xla_model as xm
a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=xm.xla_device())
a2 = torch.nonzero(a1)
"""
IR {
%0 = s64[1,6]{1,0} xla::device_data(), xla_shape=s64[1,6]{1,0}, device=Poplar:0
%1 = (s32[<=6,2]{1,0}, s32[]) aten::nonzero(%0), num_outputs=2, xla_shape=(s32[<=6,2]{1,0}, s32[]), ROOT=0
}
"""
print(torch_xla._XLAC._get_xla_tensors_text([a2]))
print(f'{a2.shape=}')  # a2.shape=torch.Size([<=6, 2])
print('a2=', a2)
[Figure: output of the script]
[Figure: output with as_tuple=True]
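The IR printed by the script gives aten::nonzero a bounded result s32[<=6,2] plus a scalar s32[]: the row dimension has a static upper bound equal to the element count (6), and the scalar presumably holds the actual number of nonzeros. The MHLO that follows builds this result by sorting. The plain-Python model below walks the same steps (0/1 mask, row and column iota, one stable sort keyed on the mask, largest first) on the same input and reproduces the intermediate values annotated in the MHLO comments. It is a sketch, not torch_xla code: nonzero_lowering is a hypothetical name, and the contents of the padding rows belong to this model only, not to a claim about what XLA stores there.

```python
# Plain-Python model (NOT torch_xla code) of the sort-based nonzero lowering.
# Assumptions: 2-D input; the sort is stable and orders by the mask, largest
# first; padding rows simply hold whatever the sort leaves at the tail.

def nonzero_lowering(matrix):
    """Return (sorted_mask, sorted_rows, sorted_cols, padded, count)."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    # %2..%4: compare NE against 0, convert to int32, flatten -> 0/1 mask.
    mask = [1 if v != 0 else 0 for row in matrix for v in row]
    # %5/%6: iota along dim 0 (row index of every element), flattened.
    row_idx = [r for r in range(n_rows) for _ in range(n_cols)]
    # %7/%8: iota along dim 1 (column index of every element), flattened.
    col_idx = [c for _ in range(n_rows) for c in range(n_cols)]
    # %9: one permutation, chosen by comparing mask values, is applied to all
    # three operands. Python's sort stays stable with reverse=True, so
    # positions with equal mask values keep their original relative order.
    order = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
    sorted_mask = [mask[i] for i in order]
    sorted_rows = [row_idx[i] for i in order]
    sorted_cols = [col_idx[i] for i in order]
    # Bounded result: numel rows of (row, col) pairs; only the first `count`
    # rows are real nonzero coordinates.
    padded = [[r, c] for r, c in zip(sorted_rows, sorted_cols)]
    count = sum(mask)
    return sorted_mask, sorted_rows, sorted_cols, padded, count


if __name__ == "__main__":
    a1 = [[1, 0, 0, 5, 0, 6]]
    m, rows, cols, padded, count = nonzero_lowering(a1)
    print(m)                   # [1, 1, 1, 0, 0, 0]
    print(cols)                # [0, 3, 5, 1, 2, 4]
    print(len(padded), count)  # 6 3
    print(padded[:count])      # [[0, 0], [0, 3], [0, 5]]
```

Because the mask is flattened in row-major order and the sort is stable, the first count rows come out in the same row-major order that torch.nonzero uses.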
MHLO code
module @SyncTensorsGraph.40 {
// %arg0: [[1, 0, 0, 5, 0, 6]]
func.func @main(%arg0: tensor<1x6xi64>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> {
%0 = mhlo.constant dense<0> : tensor<i64>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>) -> tensor<1x6xi64>
%2 = mhlo.compare NE, %arg0, %1 : (tensor<1x6xi64>, tensor<1x6xi64>) -> tensor<1x6xi1>
%3 = mhlo.convert(%2) : (tensor<1x6xi1>) -> tensor<1x6xi32>
// %4: [1, 0, 0, 1, 0, 1]
%4 = mhlo.reshape %3 : (tensor<1x6xi32>) -> tensor<6xi32>
%5 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x6xi32>
// %6: [0, 0, 0, 0, 0, 0], row indices
%6 = mhlo.reshape %5 : (tensor<1x6xi32>) -> tensor<6xi32>
%7 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x6xi32>
// %8: [0, 1, 2, 3, 4, 5], col indices
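%8 = mhlo.reshape %7 : (tensor<1x6xi32>) -> tensor<6xi32>
// mhlo.sort reorders all three operands (%4, %6, %8) with one shared permutation,
// producing 3 sorted tensors; the comparator is the region below.
// With 3 operands the comparator block takes 6 arguments, 2 per operand:
// %arg1/%arg2 come from %4, %arg3/%arg4 from %6, %arg5/%arg6 from %8.
// The comparison is %arg1 > %arg2, and %arg1/%arg2 are elements of %4,
// so %6 and %8 are reordered by the values of %4, largest first.
// That is, the elements at positions 0, 3 and 5 move to the front (the order
// shown below assumes a stable sort). The sorted results are:
// %4' = %9#0 = [1, 1, 1, 0, 0, 0]
// %6' = %9#1 = [0, 0, 0, 0, 0, 0], i.e. %10
// %8' = %9#2 = [0, 3, 5, 1, 2, 4], i.e. %11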
%9:3 = "mhlo.sort"(%4, %6, %8) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
// %20 is always true
%20 = mhlo.constant dense<true> : tensor<i1>
// Sort condition: %arg1 > %arg2
%21 = mhlo.compare GT, %arg1, %arg2 : (tensor<i3