一维tensor作为索引的情况
# a是待索引矩阵
a = torch.tensor([[[2,3,4,5],
[5,6,7,8],
[4,3,5,2]],
[[1,2,6,8],
[8,7,6,9],
[9,8,2,3]]]) # 2x2x4 假设第一维度为batch维度
# index是一维索引矩阵
index = torch.tensor([0,2])
1.如果想要索引每个batch的行向量,即index表示我们想要索引第一个batch的第1行,第二个batch的第3行。
可以使用如下方法:
>>> a[[0,1],index,:]
tensor([[2, 3, 4, 5],
[9, 8, 2, 3]])
#或者对batch维度使用更方便的操作
>>> a[torch.range(0,a.shape[0]-1).long(),index,:]
tensor([[2, 3, 4, 5],
[9, 8, 2, 3]])
2.如果想要索引每个batch的列向量,即index表示我们想要索引第一个batch的第1列,第二个batch的第3列。
可以使用如下方法:
>>> a[torch.range(0,a.shape[0]-1).long(),:,index]
tensor([[2, 5, 4],
[6, 6, 2]])
这种使用一维tensor对三维tensor进行索引的方式,需要注意的是在第一维度需要按照index中索引的顺序指定,如果我们不指定batch维度,那么就会出现这种情况,以第1种行索引为例:
>>> a[:,index,:]
tensor([[[2, 3, 4, 5],
[4, 3, 5, 2]],
[[1, 2, 6, 8],
[9, 8, 2, 3]]])
索引的结果是:对每一个batch维度,都索引第1行和第3行。
二维tensor作为索引的情况
当涉及到二维tensor索引三维tensor的情况,上述这种简单的方式就不再适用了(笔者没有成功过,如果有更好的建议欢迎指出!)。这时需要用到torch.gather
作为索引的方式,但是由于官方文档写的比较模糊,建议看一下这篇博客,我觉得写的比官方要好:
pytorch.gather()函数深入理解(dim=1,2,3三种维度分析)
这里同样是简单举一个比较常用的例子:
# a是待索引矩阵
a = torch.tensor([[[2,3,4,5],
[5,6,7,8],
[4,3,5,2]],
[[1,2,6,8],
[8,7,6,9],
[9,8,2,3]]]) # 2x2x4 假设第一维度为batch维度
# index是二维索引矩阵
index=torch.tensor([[2,3,1],[0,2,1]])
index表示我们想要索引第一个batch的第1行的第3个元素,第2行的第4个元素,第3行的第2个元素。第二个batch的第1行的第1个元素,第2行的第3个元素,第3行的第2个元素。
torch.gather(a, dim=2, index=index.long().unsqueeze(-1))
tensor([[[4],
[8],
[3]],
[[1],
[6],
[8]]])#unqueeze(-1)的目的是使用index的形状和a保持一致。
这里指定dim=2
的意思是index中的第1个维度对应于a的第1个维度,index的第2个维度对应于a的第2个维度,index的每一个元素对a的第3个维度进行索引。注意这里要保证index的维度为2x3x1,才能对dim=2
进行索引。
再举一个dim=1
,在列维度进行索引的例子:
index=torch.tensor([[0,2,1,1],[2,0,1,2]])
index表示我们想要索引第一个batch的第1列的第0个元素,第2列的第3个元素,第3列的第1个元素,第4列的第1个元素。第二个batch的第1列的第3个元素,第2列的第1个元素,第3列的第2个元素,第4列的第3个元素。
>>> torch.gather(a, dim=1, index=index.long().unsqueeze(1))
tensor([[[2, 3, 7, 8]],
[[9, 2, 6, 3]]])
注意,这里dim=1
表示在a的第2个维度上进行索引,因此要保证index的维度为2x1x4,才能对dim=1
索引。
torch.gather的简单理解
torch.gather两种形式:
tensor0.gather(dim, index) 或者 torch.gather(tensor0, dim , index)
index和tensor0具有相同数量的维度,假设都是4维,如果不是相同数量,需要增加维度到相同。
tensor0是[b, c, h, w]维度,index是[b k 1 1]
那么加上我们想根据index索引tensor0的第1维度,那么就需要设定dim=1。
理解为在tensor0的dim=1维度,按照index的值进行排序,其他的维度则是0~(该维度的长度-1)。图像化可以表示为
dim 0 1 2 3
index. b index[0]. h w
index. b index[1]. h w
index. b index[2]. h w
index. b index[3]. h w
```