torch-xla动态shape——通过torch.nonzero分析mhlo实现

pytorch api:

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).

特别注意torch.nonzeroas_tuple参数:

在这里插入图片描述

动态shape开启方式

XLA_EXPERIMENTAL=“nonzero”

xla算子定义torch_xla/csrc/ops/nonzero.cpp
torch-xla nonzero及其开启方式,不开启时走xla_cpu_fallback
在这里插入图片描述

mhlo算法理解

在这里插入图片描述

python脚本

脚本来源xla/issues/4432

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["PJRT_DEVICE"] = "XPU"
# os.environ["PJRT_DEVICE"] = "GPU"
os.environ["XLA_EXPERIMENTAL"]="nonzero"
# for nonzero cpu fallback log, when XLA_EXPERIMENTAL='nonzero' is not enabled
os.environ["TF_CPP_VMODULE"] = "aten_cpu_fallback=5"

import torch
import torch_xla.core.xla_model as xm

a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=xm.xla_device())
a2 = torch.nonzero(a1)
"""
IR {
  %0 = s64[1,6]{1,0} xla::device_data(), xla_shape=s64[1,6]{1,0}, device=Poplar:0
  %1 = (s32[<=6,2]{1,0}, s32[]) aten::nonzero(%0), num_outputs=2, xla_shape=(s32[<=6,2]{1,0}, s32[]), ROOT=0
}

"""
print(torch_xla._XLAC._get_xla_tensors_text([a2]))
print(f'{
     
     a2.shape=}') # a2.shape=torch.Size([<=6, 2])
print('a2=', a2)

运行结果:
在这里插入图片描述
as_tuple=True的运行结果:
在这里插入图片描述

mhlo代码

module @SyncTensorsGraph.40 {
   
   
  // %arg0: [[1, 0, 0, 5, 0, 6]]
  func.func @main(%arg0: tensor<1x6xi64>) -> tuple<tensor<?x2xi32, #mhlo.type_extensions<bounds = [6, -1]>>> {
   
   
    %0 = mhlo.constant dense<0> : tensor<i64>
    %1 = "mhlo.broadcast_in_dim"(%0) {
   
   broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i64>) -> tensor<1x6xi64>
    %2 = mhlo.compare  NE, %arg0, %1 : (tensor<1x6xi64>, tensor<1x6xi64>) -> tensor<1x6xi1>
    %3 = mhlo.convert(%2) : (tensor<1x6xi1>) -> tensor<1x6xi32>
    // %4: [1, 0, 0, 1, 0, 1]
	%4 = mhlo.reshape %3 : (tensor<1x6xi32>) -> tensor<6xi32>
	%5 = "mhlo.iota"() {
   
   iota_dimension = 0 : i64} : () -> tensor<1x6xi32>
    // %6: [0, 0, 0, 0, 0, 0], row indices
    %6 = mhlo.reshape %5 : (tensor<1x6xi32>) -> tensor<6xi32>
    %7 = "mhlo.iota"() {
   
   iota_dimension = 1 : i64} : () -> tensor<1x6xi32>
    // %8: [0, 1, 2, 3, 4, 5], col indices
	%8 = mhlo.reshape %7 : (tensor<1x6xi32>) -> tensor<6xi32>
	"""
    对所有的operand(%4, %6, %8)分别排序得到3个排序后的tensor,且排序条件一致如下
	这里3个operand,对应block中有6个element,对应关系为:一个operand对应2个arg
	%arg1和%arg2为%4中的元素,所以这里操作的结果为:将%6和%8按照%4中元素的大小关系排序
	即:将第0,3,5位置的元素排到最前。排序后的结果为:
	%4'=%9#0 = [1, 1, 1, 0, 0, 0]
	%6'=%9#1 = [0, 0, 0, 0, 0, 0] 即%10
	%8'=%9#2 = [0, 3, 5, 1, 2, 4] 即%11
	"""
	%9:3 = "mhlo.sort"(%4, %6, %8) ({
   
   
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>):
      // %20总是true
	  %20 = mhlo.constant dense<true> : tensor<i1>
      // 排序条件: %arg1 > %arg2
	  %21 = mhlo.compare  GT, %arg1, %arg2 : (tensor<i3
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值