好久没更新博客了,最近一直在忙,既有生活上的也有工作上的。道阻且长啊。
今天来水一文,说一说最近工作上遇到的一个函数:torch.gather() 。
文字理解
我遇到的代码是 NLP 相关的,代码中用 torch.gather()
来将一个 tensor 的 shape 从 (batch_size, seq_length, hidden_size)
转为 (batch_size, labels_length, hidden_size)
,其中 seq_length >= labels_length
。
torch.gather()
的官方解释是
Gathers values along an axis specified by dim.
就是在指定维度上 gather value。那么怎么 gather、gather 哪些 value 呢?这就要看其参数了。
torch.gather()
的必填也是最常用的参数有三个,下面引用官方解释:
-
input
(Tensor) – the source tensor -
dim
(int) – the axis along which to index -
index
(LongTensor) – the indices of elements to gather
所以一句话概括 gather 操作就是:根据 index
,在 input
的 dim
维度上收集 value。
具体来说,input
就是源 tensor,等会我们要在这个 tensor 上执行 gather 操作。如果 input
是一个一维数组,即 flat 列表,那么我们就可以直接根据 index
在 input
上取了,就像正常的列表/数组索引一样。但是由于 input
可能含有多个维度,是 N 维数组,所以我们需要知道在哪个维度上进行 gather,这就是 dim
的作用。
对于 dim
参数,一种更为具体的理解方式是替换法。假设 input
和 index
均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k]
,正常来说我们直接取 input[i, j, k]
作为 输出 tensor 对应位置的值即可,但是由于 dim
的存在以及 input.shape
可能不等于 index.shape
,所以直接取值可能就会报 IndexError
。所以我们是将索引列表的相应位置替换为 dim
,再去 input
取值。如果 dim=0
,我们就替换索引列表第 0 个值,即 [dim, j, k]
,依此类推。Pytorch 的官方文档的写法其实也是这个意思,但是看这么多个方括号可能会有点懵:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
但是可能你还有点迷糊,没关系接着看下面的直观理解部分,然后再回来看这段话,结合着看,相信你很快能明白。
由于我们是按照 index
来取值的,所以最终得到的 tensor 的 shape 也是和 index
一样的,就像我们在列表上按索引取值,得到的输出列表长度和索引相等一样。
直观理解
为便于理解,我们以一个具体例子来说明。我们使用反推法,根据 input
和输出推参数。这应该也是我们平常自己写代码的时候遇到比较多的情况。
假设 input
和我们想要的输出 output
如下:
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> output_tensor # shape: (2, 2, 4)
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]])
即,我们想让 shape 为 (2, 3, 4)
的 input_tensor
变成 shape 为 (2, 2, 4)
的 output_tensor
,丢弃维度 1 的第 2 个元素,即 [ 4, 5, 6, 7]
和 [16, 17, 18, 19]
。
我们应用替换法,重点是找出来 dim
和 index
的值。始终记住 index
和 output_tensor
的 shape 是一样的。
从 output_tensor
的第一个位置开始,由于 output_tensor[0, 0, :] = input_tensor[0, 0, :]
,所以此时 [i, j, k]
是一样的,我们看不出来 dim
应该是多少。
下一行 output_tensor[0, 1, 0] = input_tensor[0, 2, 0]
,这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim
应该是 1,而 index
应为 2, index_tensor[0, 1, 0]=2
。
此时 dim
已经明确。同理,output_tensor[0, 1, 1] = input_tensor[0, 2, 1]
,index_tensor[0, 1, 1]=2
,依此类推,得到 index_tensor[0, 1, :] = 2
。同时也可以明确 index_tensor[0, 0, :] = 0
。
所以
>>> dim = 0
>>> index_tensor
tensor([[[0, 0, 0, 0],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[2, 2, 2, 2]]])
简单可描述如下图:
为描述方便,假如我们把输入看作是 6 行,从上到下依次是 0-5。那么从事后诸葛亮的角度讲,输出相当于是把第 1 和第 4 行“抽掉”。如果输出和输入一样,那么原本的 index_tensor
就是如下:
tensor([[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]]])
“抽掉”后, index_tensor
也相应“抽掉”,那么就得到我们想要的结果了。而且由于这个“抽掉”的操作是在维度 1 上进行的,那么 dim
自然是 1。
numpy.take()
和 tf.gather
貌似也是同样功能,就不细说了。
Reference
- torch.gather — PyTorch 1.9.0 documentation
- numpy.take — NumPy v1.21 Manual
- tf.gather | TensorFlow Core v2.6.0
END