pytorch中使用tensor作为另一个tensor的索引 (torch.gather)

本文详细介绍了如何在PyTorch中使用一维和二维tensor对三维tensor进行行向量和列向量索引,并通过torch.gather函数处理二维索引的复杂场景。通过实例演示了索引规则和torch.gather的使用方法,适合理解索引操作在多维度数据处理中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一维tensor作为索引的情况

# a是待索引矩阵
a = torch.tensor([[[2,3,4,5],
	  [5,6,7,8],
	  [4,3,5,2]],
	 [[1,2,6,8],
	  [8,7,6,9],
	  [9,8,2,3]]]) # 2x2x4 假设第一维度为batch维度
# index是一维索引矩阵
index = torch.tensor([0,2])

1.如果想要索引每个batch的行向量,即index表示我们想要索引第一个batch的第1行,第二个batch的第3行
可以使用如下方法:

>>> a[[0,1],index,:]
tensor([[2, 3, 4, 5],
        [9, 8, 2, 3]])
#或者对batch维度使用更方便的操作
>>> a[torch.range(0,a.shape[0]-1).long(),index,:]
tensor([[2, 3, 4, 5],
        [9, 8, 2, 3]])           					   

2.如果想要索引每个batch的列向量,即index表示我们想要索引第一个batch的第1列,第二个batch的第3列
可以使用如下方法:

>>> a[torch.range(0,a.shape[0]-1).long(),:,index]
tensor([[2, 5, 4],
        [6, 6, 2]])

这种使用一维tensor对三维tensor进行索引的方式,需要注意的是在第一维度需要按照index中索引的顺序指定,如果我们不指定batch维度,那么就会出现这种情况,以第1种行索引为例:

>>> a[:,index,:]
tensor([[[2, 3, 4, 5],
         [4, 3, 5, 2]],

        [[1, 2, 6, 8],
         [9, 8, 2, 3]]])

索引的结果是:对每一个batch维度,都索引第1行和第3行

二维tensor作为索引的情况

当涉及到二维tensor索引三维tensor的情况,上述这种简单的方式就不再适用了(笔者没有成功过,如果有更好的建议欢迎指出!)。这时需要用到torch.gather作为索引的方式,但是由于官方文档写的比较模糊,建议看一下这篇博客,我觉得写的比官方要好:
pytorch.gather()函数深入理解(dim=1,2,3三种维度分析)
这里同样是简单举一个比较常用的例子:

# a是待索引矩阵
a = torch.tensor([[[2,3,4,5],
	  [5,6,7,8],
	  [4,3,5,2]],
	 [[1,2,6,8],
	  [8,7,6,9],
	  [9,8,2,3]]]) # 2x2x4 假设第一维度为batch维度
# index是二维索引矩阵
index=torch.tensor([[2,3,1],[0,2,1]])

index表示我们想要索引第一个batch的第1行的第3个元素,第2行的第4个元素,第3行的第2个元素。第二个batch的第1行的第1个元素,第2行的第3个元素,第3行的第2个元素。

torch.gather(a, dim=2, index=index.long().unsqueeze(-1))
tensor([[[4],
         [8],
         [3]],

        [[1],
         [6],
         [8]]])#unqueeze(-1)的目的是使用index的形状和a保持一致。

这里指定dim=2的意思是index中的第1个维度对应于a的第1个维度,index的第2个维度对应于a的第2个维度,index的每一个元素对a的第3个维度进行索引。注意这里要保证index的维度为2x3x1,才能对dim=2进行索引。

再举一个dim=1,在列维度进行索引的例子:

index=torch.tensor([[0,2,1,1],[2,0,1,2]])

index表示我们想要索引第一个batch的第1列的第0个元素,第2列的第3个元素,第3列的第1个元素,第4列的第1个元素。第二个batch的第1列的第3个元素,第2列的第1个元素,第3列的第2个元素,第4列的第3个元素。

>>> torch.gather(a, dim=1, index=index.long().unsqueeze(1))
tensor([[[2, 3, 7, 8]],

        [[9, 2, 6, 3]]])

注意,这里dim=1表示在a的第2个维度上进行索引,因此要保证index的维度为2x1x4,才能对dim=1索引。

torch.gather的简单理解

torch.gather两种形式:
tensor0.gather(dim, index) 或者 torch.gather(tensor0, dim , index)
index和tensor0具有相同数量的维度,假设都是4维,如果不是相同数量,需要增加维度到相同。
tensor0是[b, c, h, w]维度,index是[b k 1 1]
那么加上我们想根据index索引tensor0的第1维度,那么就需要设定dim=1。
理解为在tensor0的dim=1维度,按照index的值进行排序,其他的维度则是0~(该维度的长度-1)。图像化可以表示为

dim      0        1          2      3 
index.   b     index[0].     h      w
index.   b     index[1].     h      w
index.   b     index[2].     h      w
index.   b     index[3].     h      w
             ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值