简介
今天在使用torch中的topk的时候, 对于dim产生了一些疑问. 后面也是找到了规律, 但是还是很困惑他为什么是这么设计的, 即dim与tensor本身的行列是不一致的. 然后就查了一下, 真的找到了一篇很好的文章, 解决了我的困惑, 就想在这里记录一下.
我这一篇文章里的所有的动图, 都是来自与下面这篇文章, 写的非常直观.
原文链接(十分棒的文章), Understanding dimensions in PyTorch
关于这里设计的代码, 有一个完整的notebook的文档, 具体链接见Github, Pytorch维度介绍.ipynb
理解PyTorch维度概念
首先我们从最基础的开始, 当我们在Pytorch中定义一个二维的tensor的时候, 他包含行和列. 例如下面我们创建一个2*3的tensor.
1. x = torch.tensor([
2. [1,2,3],
3. [4,5,6]
4. ])
5. # 我们可以看到"行"是dim=0, "列"是dim=1
6. print(x.shape)
7. >> torch.Size([2, 3])
我们可以看到打印的结果显示:
- first dimension (dim=0) stays for rows, 第一个维度代表行, 因为是2, 实际x就是2行
- the second one (dim=1) for columns, 第二个维度代表列, 因为是3
于是, 我们会认为, torch.sum(x, dim=0)就是(1+2+3, 4+5+6)=tensor([6, 15]), 但是实际情况却不是这个样子的.
- torch.sum(x, dim=0)
- >> tensor([5, 7, 9])
我们可以看到按照dim=0求和, 其实是在按列相加, 也就是(1+4, 2+5, 3+6) =tensor([5, 7, 9]), 和我们想象的完全不一样. 我们再看一下按照dim=1进行求和.
- torch.sum(x, dim=1)
- >> tensor([ 6, 15])
可以看到, 在按照dim=1的时候求和的时候, 其实在按照按行进行求和, (1+2+3, 4+5+6)=tensor([6, 15]), 这就让人很困惑, 明明上面说的是dim=0代表是行.
于是, 原文作者在一篇介绍numpy维度的文章中, 找到了问题的关键所在. 也就是下面的这段话(numpy中的axis也就是这里的dim).
The way to understand the "axis" of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).
上面的话简单翻译就是, 当按照axis=0进行求和的时候, 其实可以想象为对axis=0这个维度进行挤压, 最后只剩下一行, 那一行就是结果, 也就是按列在相加.
是不是还是会有一些困惑, 我们还是对于上面的例子(tensor([[1,2,3], [4,5,6]])), 看一下在dim=0的时候, 为什么是列相加, 以及上面的collapse the specific axis(dim)的含义.
如上面的动图所示, 当dim=0的时候, 按每一行的元素进行相加, 最后的结果就是和按列求和.
对于三维向量
下面我们更进一步, 来看一下对于三维的tensor, 在各个维度进行sum操作的结果. 首先我们看一下每一个dim代表的含义.
1. # 看一下三维的
2. x = torch.tensor([
3. [
4. [1,2,3],
5. [4,5,6]
6. ],
7. [
8. [1,2,3],
9. [4,5,6]
10. ],
11. [
12. [1,2,3],
13. [4,5,6]
14. ]
15. ])
16. # 我们可以看到第三维是dim=0, "行"是dim=1, 列是dim=2
17. print(x.shape)
18. >> torch.Size([3, 2, 3])
可以看到此时dim=0是第三个维度, dim=1是行, dim=2是列.
1. torch.sum(x, dim=0)
2. >>
3. tensor([[ 3, 6, 9],
4. [12, 15, 18]])
我们可以将其看成是各个二维平面对应元素求和, 还是有点绕, 还是直接看下面的动图.
接着是对dim=1进行求和.
1. torch.sum(x, dim=1)
2. >>
3. tensor([[5, 7, 9],
4. [5, 7, 9],
5. [5, 7, 9]])
还是直接看下面的动图, 来进行理解.
最后按照dim=2来进行求和.
1. torch.sum(x, dim=2)
2. >>
3. tensor([[ 6, 15],
4. [ 6, 15],
5. [ 6, 15]])
还是使用动图来进行解释.