本文主要介绍Pytorch中Tensor的储存机制,在搞懂了Tensor在计算机中是如何存储之后我们会进一步来探究tensor.view()、tensor.reshape()、tensor.reszie_(),她们都是改变了一个tensor的“形状”,但是他们之间又有着些许的不同,这些不同常常会导致我们程序之中出现很多的BUG。

一、Tensor的储存机制

 tensor在电脑的储存,分为两个部分(也就是说一个tensor占用了两个内存位置),一个内存储存了这个tensor的形状size、步长stride、数据的索引等信息,我们把这一部分称之为头信息区(Tensor);另一个内存储的就是真正的数据,我们称为存储区 (Storage)。换句话说,一旦定义了一个tensor,那这个tensor将会占据两个内存位置,用于存储。

  要注意,如果我们把一个tensorA进行切片,截取,修改之后通过"="赋值给B,那么这个时候tensorB其实是和tensorA是共享存储区 (Storage),唯一不同的是头信息区(Tensor)不同。下面我们直接看代码来理解。其中tensor.storage().data_ptr()是用于获取tensor储存区的首元素内存地址的。

A = torch.arange(5)  # tensor([0, 1, 2, 3, 4])
B = A[2:]            # 对A进行截取获得:tensor([2, 3, 4])
print(A)
print(B)
tensor([0, 1, 2, 3, 4])
tensor([2, 3, 4])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
2076006947200
2076006947200

  我们可以很直观的看到,A和B的储存区的内存地址是一样的,因此她们是共享数据的,下面这个例子更加直观。

import torch
A = torch.arange(5)  # tensor([0, 1, 2, 3, 4])
B = A[2:]            # 对A进行截取获得:tensor([2, 3, 4])

B[1] = 100     # 修改B的第2位置元素为100
print(A)
print(B)
tensor([  0,   1,   2, 100,   4])
tensor([  2, 100,   4])

  因此我们可以得出结论,通过=直接赋值的操作其实就是“浅拷贝”(这里注意和list的切片区分,list使用A[2:],是可以得到新的一个list的)

二、tensor的stride()属性、storage_offset()属性

为了更好的解释tensor的reshape(),以及view()的操作,我们还需要了解下tensor的stride属性。刚才上面我们提到了,tensor为了节约内存,很多操作其实都是在更改tensor的头信息区(Tensor),因为头信息区里面包含了如何组织数据,以及从哪里开始组织。其中stride()和storage_offset()属性分别代表的就是步长以及初始偏移量。

storage_offset()属性

表示tensor的第一个元素与真实存储区(storage)的第一个元素的偏移量。例如下面的例子:

import torch
A = torch.arange(5) 
B = A[2:]
C = A[1:]
print(A)
print(B)
print(C)
tensor([0, 1, 2, 3, 4])
tensor([2, 3, 4])
tensor([1, 2, 3, 4])
print(B.storage_offset())
print(C.storage_offset())
2
1

  我们可以看到tensorB和tensorC都是从A切片而来的,她们俩的存储区 (Storage)是和A共享的,只不过B的第一个元素,与存储区 (Storage)的首元素相差了2个位置(也就是储存区的index=2开始),C的第一个元素与存储区 (Storage)的首元素相差了1个位置。

stride()属性

这个属性比较难理解,直接翻译官方文档就是:stride是在指定维度dim中从一个元素跳到下一个元素所必需的步长。直接上例子:

import torch
A = torch.rand(2, 3)  # 生成2*3的随机数
print(A)
print(A.storage())    # 打印A的储存区真实的数据
打印A: tensor([[0.8438, 0.2782, 0.9584],
        [0.2089, 0.0259, 0.3666]])
 0.8437800407409668
 0.2781521677970886
 0.9583932757377625
 0.2088671326637268
 0.025857746601104736
 0.366576611995697
[torch.FloatStorage of size 6]
print(A.stride())
(3, 1)

  主要是理解这个(3,1)指的是什么意思。这里的3指的是A[i][j]到A[i+1][j]这两个数字在存储区真实数据排列中是相差3的(例如A[0][0]=0.8438与A[1][0]=0.2089这两个数字在储存区中位次相差了3);这里的1是指A[i][j]与A[i][j+1]这两个数字在储存区的真实数据排列中相差1(例如A[0][0]=0.8438与A[0][1]=0.2781这两个数字在储存区中位次相差1)。如果还没有理解,加下来我们试一下对于3维数据看看他们的stride()属性。

import torch
A = torch.rand(2, 3, 4)  # 生成2*3*4的随机数
print("打印A:",A)
print(A.storage())    # 打印A的储存区真实的数据
打印A: tensor([[[0.4303, 0.7474, 0.8649, 0.5006],
         [0.2716, 0.9966, 0.7765, 0.6737],
         [0.5515, 0.2274, 0.9791, 0.1940]],

        [[0.6401, 0.7746, 0.5124, 0.0258],
         [0.8576, 0.9118, 0.9504, 0.4675],
         [0.9359, 0.0687, 0.2457, 0.3604]]])
 0.4302864074707031
 0.747403085231781
 0.8648527264595032
 0.500649631023407
 0.2716004252433777
 0.9965775609016418
 0.7765441536903381
 0.6737198233604431
 0.5515168905258179
 0.2273930311203003
 0.9791405200958252
 0.19399094581604004
 0.6401097774505615
 0.7746065855026245
 0.512383759021759
 0.02578103542327881
 0.8575518727302551
 0.911821186542511
 0.9503545165061951
 0.4674733877182007
 0.9358749389648438
 0.06866037845611572
 0.24573636054992676
 0.3603515625
print(A.stride())
(12, 4, 1)

  输出有点长,大家对照着看,由于我们A的size是3维度的,因此我们A.stride()也是个三元组,那如果A是4维呢?(A.stride()一定就是4元组了)。这里的12表示就是A[i][j][k]与A[i+1][j][k]这两个数字在真实储存区的数据排布中相差12,大家可以对照的找几个数字试试。同样的道理这里的4表示A[i][j][k]与A[i][j+1][k]这两个数字在真实储存区的数据排布中相差4。最后1表示什么我就不说啦。

  好了终于说完这个很难的知识点了,接下来就进入正题,view()、reshape()、reszie_()三者的关系和区别。

三、view()、reshape()、reszie_()三者的关系和区别

其中view()和reshape()是官方比较推荐使用的方式,而resize_()官方在文档中说到不太推荐使用,具体原因一会说到。这三个方法都是可以完成对以一个tensor重新排列,没错是重新排列,其实她们本质上都没有改变tensor的存储区 (Storage)的真实数据的排列(除了一些特殊情况下会使得存储区发生改变,这就是她们间的区别)。

view()

从字面上来说就是"视图"的意思,就是把存储区 (Storage)的真实数据,根据某种排列方式”展示“给你看罢了,也就是仅仅改变了头信息区(Tensor),真实数据的储存地址是没有改变的。直接上例子。

import torch
A = torch.arange(6) 
B = A.view(2,3)
print(A)
print(B)
tensor([0, 1, 2, 3, 4, 5])
tensor([[0, 1, 2],
        [3, 4, 5]])
print(A.storage().data_ptr())
print(B.storage().data_ptr())
1881582170752
1881582170752

  可以看到,A和B的真实数据的内存地址都是一样的,下面我们进一步打印一下A,B两个tensor真实数据的排列。

print(A.storage())
print(B.storage())
0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

  可以看到,是完全一样的。更进一步打印一下A,B的stride()属性

print(A.stride())
print(B.stride())
(1,)
(3, 1)

  没问题和前面说的是一样的。

  总结一下,view()函数主要就是更改了tensor中的stride()属性,这样从而影响了tensor的显示,但是从本质上来说A,B还是共用真实数据的存储区 (Storag)的。

reshape()

为了解释view()和reshape()的区别,我们还需要知道一个知识:tensor的连续性。tensor又不是函数哪里来什么连续性?其实tensor的连续性说的就是stride()属性和size()属性(tensor维度)之间的关系。

  前一小结已经说了对于一个高维的tensor,stride()指的是:指定维度dim中从一个元素跳到下一个元素所必需的步长。一般来说我们最后一个维度步长应该是1(其实我们前面的例子我们应该也能发现,例子中所有tensor.stride()返回的元组最后一个元素都是1),对吧,因为是按顺序排列的嘛。但是当一个tensor涉及到转置(tensor.t(),tensor.transpose(),tensor.permute())这些操作都会使得tensor失去连续性这个性质。我们直接来看看例子吧。

a = torch.arange(6).view(2, 3)
b = a.t()
c = a.transpose(1,0)
d = a.permute(1,0)

print('b是:',b)
print('c是:',c)
print('d是:',d)
b是: tensor([[0, 3],
        [1, 4],
        [2, 5]])
c是: tensor([[0, 3],
        [1, 4],
        [2, 5]])
d是: tensor([[0, 3],
        [1, 4],
        [2, 5]])
print(a.stride())
print(b.stride())
print(c.stride())
print(d.stride())
(3, 1)
(1, 3)
(1, 3)
(1, 3)

这里我就不验证她们是不是同一个存储区 (Storage)了,大家下来可以验证下(其实就是同一个)。我们可以看到b,c,d三个tensor的stride()属性和a是不一样的,根据stride()的定义大家应该是很容易知道b,c,d返回的stride()是什么意思吧。那为什么说b,c,d的tensor就不连续了呢?是因为她们不满足张量的连续性条件了。连续性条件如下:

Pytorch Tensorflow 数据 pytorch tensor.view_数据

 

 这是什么意思呢?拿b举例就是,b的stride=(3,1),b的size=(3,2),那么stride[0] != stride[1] * size[1]的,因此b是不满足连续性条件的。如果从直观上来感觉来"连续"的意思就是,“我”旁边的数字就应该是“我”真实储存区旁边的数据,例如b[0][0]=0,但是b[0][1]=3,0和3这两个数字在真实的存储区 (Storage)不是挨着的啊,所以叫做不连续。

  那不满足连续性有什么后果呢?后果就是不满足连续性的tensor是无法使用view()方法的。换句话说,上面例子中的b,c,d都无法再使用view()方法了。

e = b.view(1,6)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

  c,d大家自己下来试一试。所以对于一个tensor是不是连续就意味着他能不能使用view()方法。

  那有什么办法让b使用view()方法呢?那就是把b连续化(使用tensor.contiguous()方法)。上例子。

a = torch.arange(6).view(2, 3)
b = a.t()
c = b.contiguous()
print(a.storage())
print(c.storage())
0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
 0
 3
 1
 4
 2
 5
[torch.LongStorage of size 6]
print(a.storage().data_ptr())
print(c.storage().data_ptr())
1881582182144
1881582172928

  其实tensor.contiguous()方法是创造了一个新的tensor(全新的,连存储区都不共用的tensor),这里的c就是从b得到的连续的tensor了,大家可以打印下c.stride(),会得到(2,1),这样再根据c的size就能发现,c是满足上面提到的连续性公式的。

  了解以上知识之后,reshape()和view()的差别就来了,view()是没法对非连续性的tensor使用的(会报错),但是reshape()是可以对非连续性tensor使用的。换句话说

  • 当tensor满足连续性要求时,reshape() = view(),和原来tensor共用内存
  • 当tensor不满足连续性要求时,reshape() = contiguous() + view(),会产生新的存储区的tensor,与原来tensor不共用内存

  这就是view()和reshape()的差别了。

reszie_()

 那这一个又和前面那俩有啥关系的呢?从官方文档上来说,它是不希望我们使用这个resize_()的,如图。

 

Pytorch Tensorflow 数据 pytorch tensor.view_Storage_02

   前面说到的reshape和view都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字们,不能多不能少。因此你就不能把只有12个数字的tensor强行reshape成2*5的维度的tensor。但是resize_()可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!上例子。

  多说一句a.resize_()是会改变a的哟,换句话说,a.resize_(2,3)之后,a就不再是1*7的维度了,而是2*3的维度了。但是a的储存区还是原来的储存区

a = torch.arange(7)
print("变换前a的储存区地址:",a.storage().data_ptr())
b = a.resize_(2,3)
print('这是新的a:',a)
变换前a的储存区地址: 1881579251648
这是新的a: tensor([[0, 1, 2],
        [3, 4, 5]])
print(a.storage())
print(b.storage())
0
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 7]
 0
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 7]
print('变换后a的储存区地址',a.storage().data_ptr())
print(b.storage().data_ptr())
变换后a的储存区地址 1881579251648
1881579251648

  你会发现尽管a的”长相“(数字个数也从7个变成了6个)被改变了,但是存储区依旧是没变的(要注意到真实存储区的个数也没变哟还是7个),因此我们可以说resize_()再进行变换的时候如果数字多余了,会截取我们需要的数据量,多余的数据量并没有被舍弃。

  再来看看,当我reszie_多于原来的数据的时候发生什么。

a = torch.arange(7)
print("变换前a的储存区地址:",a.storage().data_ptr())
b = a.resize_(3,4)
print(a.storage())
print(b.storage())
变换前a的储存区地址: 1881579250944

 0
 1
 2
 3
 4
 5
 6
 7667809
 6815836
[torch.LongStorage of size 9]

 0
 1
 2
 3
 4
 5
 6
 7667809
 6815836
[torch.LongStorage of size 9]
print('变换后a的储存区地址',a.storage().data_ptr())
print(b.storage().data_ptr())
变换后a的储存区地址 1881582026048
1881582026048

  这个时候resize_()前后a的储存区地址是发生了变化的哟。

下一个问题:resize_()可不可以对不连续的tensor使用呢?

  答案是可以,并且并不会改变原来tensor的内存。当tensor是不连续的时候,采用reshape()会生成个新的存储区的,采用resize_()则不会改变存储区。那这两者有啥区别呢?其实很好解释,reshape是尊重tensor,把存储区改了来将就tensor的reshape的长相,并使得连续。而resize_是:不改存储区,但是“用户”又想要看到想看到的长相,行,那我就把存储区的数按照你想看到的长相排列吧。直接上例子。

import torch
a = torch.arange(6).view(2, 3)
b = a.t()  
#b是这个样子的:tensor([[0, 3],
#                     [1, 4],
#                     [2, 5]])
c = b.reshape(1,6)
e = b.resize_(1,6)
print("c的存储区:",c.storage().data_ptr())
print('e的存储区:',e.storage().data_ptr())
c的存储区: 2237602017664
e的存储区: 2237602025472
print("c的存储区真实数据排布:",c.storage())
print("e的存储区真实数据排布:",e.storage())
c的存储区真实数据排布:  
 0
 3
 1
 4
 2
 5
[torch.LongStorage of size 6]
e的存储区真实数据排布:  
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
print('我是c:',c)
print('我是e:',e)
我是c: tensor([[0, 3, 1, 4, 2, 5]])
我是e: tensor([[0, 1, 2, 3, 4, 5]])

可以很直观的看出来,如果tensor是不连续的时候,reshape和resize_的差别了吧。

四、总结

最后总结一下view()、reshape()、reszie_()三者的关系和区别。

  • view()只能对满足连续性要求的tensor使用。
  • 当tensor满足连续性要求时,reshape() = view(),和原来tensor共用内存。
  • 当tensor不满足连续性要求时,reshape() = contiguous() + view(),会产生新的存储区的tensor,与原来tensor不共用内存。
  • resize_()可以随意的获取任意维度的tensor,不用在意真实数据的个数限制,但是不推荐使用。