简介

今天在使用torch中的topk的时候, 对于dim产生了一些疑问. 后面也是找到了规律, 但是还是很困惑他为什么是这么设计的, 即dim与tensor本身的行列是不一致的. 然后就查了一下, 真的找到了一篇很好的文章, 解决了我的困惑, 就想在这里记录一下.

我这一篇文章里的所有的动图, 都是来自与下面这篇文章, 写的非常直观.

原文链接(十分棒的文章), Understanding dimensions in PyTorch

关于这里设计的代码, 有一个完整的notebook的文档, 具体链接见GithubPytorch维度介绍.ipynb

理解PyTorch维度概念

首先我们从最基础的开始, 当我们在Pytorch中定义一个二维的tensor的时候, 他包含行和列. 例如下面我们创建一个2*3的tensor.

1. x = torch.tensor([
2.         [1,2,3],
3.         [4,5,6]
4.     ])
5. # 我们可以看到"行"是dim=0, "列"是dim=1
6. print(x.shape)
7. >> torch.Size([2, 3])

我们可以看到打印的结果显示:

  • first dimension (dim=0) stays for rows, 第一个维度代表行, 因为是2, 实际x就是2行
  • the second one (dim=1) for columns, 第二个维度代表列, 因为是3

于是, 我们会认为, torch.sum(x, dim=0)就是(1+2+3, 4+5+6)=tensor([6, 15]), 但是实际情况却不是这个样子的.

  1. torch.sum(x, dim=0)
  2. >> tensor([5, 7, 9])

我们可以看到按照dim=0求和, 其实是在按列相加, 也就是(1+4, 2+5, 3+6) =tensor([5, 7, 9]), 和我们想象的完全不一样. 我们再看一下按照dim=1进行求和.

  1. torch.sum(x, dim=1)
  2. >> tensor([ 6, 15])

可以看到, 在按照dim=1的时候求和的时候, 其实在按照按行进行求和,  (1+2+3, 4+5+6)=tensor([6, 15]), 这就让人很困惑, 明明上面说的是dim=0代表是行.

于是, 原文作者在一篇介绍numpy维度的文章中, 找到了问题的关键所在. 也就是下面的这段话(numpy中的axis也就是这里的dim).

The way to understand the "axis" of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).

上面的话简单翻译就是, 当按照axis=0进行求和的时候, 其实可以想象为对axis=0这个维度进行挤压, 最后只剩下一行, 那一行就是结果, 也就是按列在相加.

是不是还是会有一些困惑, 我们还是对于上面的例子(tensor([[1,2,3], [4,5,6]])), 看一下在dim=0的时候, 为什么是列相加, 以及上面的collapse the specific axis(dim)的含义.

pytorch 维度 pytorch维度理解_pytorch 维度

如上面的动图所示, 当dim=0的时候, 按每一行的元素进行相加, 最后的结果就是和按列求和.

对于三维向量

下面我们更进一步, 来看一下对于三维的tensor, 在各个维度进行sum操作的结果. 首先我们看一下每一个dim代表的含义.

1. # 看一下三维的
2. x = torch.tensor([
3.         [
4.          [1,2,3],
5.          [4,5,6]
6.         ],
7.         [
8.          [1,2,3],
9.          [4,5,6]
10.         ],
11.         [
12.          [1,2,3],
13.          [4,5,6]
14.         ]
15.     ])
16. # 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
17. print(x.shape)
18. >> torch.Size([3, 2, 3])

可以看到此时dim=0是第三个维度, dim=1是行, dim=2是列.

1. torch.sum(x, dim=0)
2. >>
3. tensor([[ 3,  6,  9],
4.         [12, 15, 18]])

我们可以将其看成是各个二维平面对应元素求和, 还是有点绕, 还是直接看下面的动图.

pytorch 维度 pytorch维度理解_二维_02

接着是对dim=1进行求和.

1. torch.sum(x, dim=1)
2. >>
3. tensor([[5, 7, 9],
4.         [5, 7, 9],
5.         [5, 7, 9]])

还是直接看下面的动图, 来进行理解.

pytorch 维度 pytorch维度理解_二维_03

最后按照dim=2来进行求和.

1. torch.sum(x, dim=2)
2. >>
3. tensor([[ 6, 15],
4.         [ 6, 15],
5.         [ 6, 15]])

还是使用动图来进行解释.

pytorch 维度 pytorch维度理解_pytorch 维度_04