Tensor的常见操作
针对Pytorch中的tensor,总结一下常用的操作
1、torch.max和torch.min
两个函数的实现类似,形参也相同,只是一个取最大一个取最小而已,下面以max为例,min同理。
(1) torch.max(a): 返回输入a中所有元素的最大值。
(2) torch.max(a, 0): 返回每一列的最大值,且返回索引(返回最大元素在各列的行索引)。
(3) torch.max()[0]: 只返回最大值。
a = torch.tensor(
[
[
[
[2, 3, 4],
[5, 2, 9],
],
[
[9, 5, 3],
[8, 0, 6],
],
[
[8, 6, 7],
[4, 1, 3],
]
],
[
[
[1, 5, 4],
[6, 2, 3],
],
[
[1, 2, 3],
[4, 5, 7],
],
[
[3, 8, 1],
[3, 6, 2],
]
],
]
)
print(a.shape, "\n")
print(torch.max(a, 2), "\n")
print(torch.max(a, 2)[0].shape)
比如上边这个例子,a的shape为torch.Size([2, 3, 2, 3]),对应最常见的网络层输入形状(b,c,h,w),
运行torch.max(a, k),k就是在第几维上取最大值(0~3)。
比如,k=0,则就是在第0维上取最大值,得到的value的shape为(3, 2, 3),原始的第0维就没有了。
同理,k=1,得到的value.shape为(2, 2, 3),原始的第1维没有了,以此类推。
torch.max函数返回值有两个,第一个是最大值(values),第二个是最大值在原始张量中的索引(indices),取[0]就是得到最大值,取[1]就是得到最大值索引。以下为输出结果。
// 共三个打印结果,横线是方便观看而添加的
-----------------------------------------------------
torch.Size([2, 3, 2, 3])
-----------------------------------------------------
torch.return_types.max(
values=tensor([[[5, 3, 9],
[9, 5, 6],
[8, 6, 7]],
[[6, 5, 4],
[4, 5, 7],
[3, 8, 2]]]),
indices=tensor([[[1, 0, 1],
[0, 0, 1],
[0, 0, 0]],
[[1, 0, 0],
[1, 1, 1],
[1, 0, 1]]]))
-----------------------------------------------------
torch.Size([2, 3, 3])
-----------------------------------------------------
2、torch.squeeze和torch.unsqueeze(常用的缩维扩维函数)
torch.squeeze(x,dim,out):对数据的维度进行压缩,去掉维数为1的维度,默认将a中所有为1的维度删掉,也可以通过dim指定位置,删掉指定位置的维数为1的维度。
torch.unsqueeze(x,dim,out):对数据的维度进行扩充,需要通过dim指定位置,给指定位置加上维数为1的维度。
注:压缩和扩充的都是维数为1的维度,张量中元素的总个数不变,其中unsqueeze常用于添加batch一维。
示例代码如下(注释即为打印结果):
a = torch.ones((1, 1, 4, 6))
a1 = a.squeeze() # [4, 6]
a2 = a.squeeze(dim=0) # [1, 4, 6]
b = torch.ones((3, 5))
b1 = b.unsqueeze(dim=0) # [1, 3, 5]
3、torch.view、torch.permute、torch.transpose(常用的维度更换函数)
1)torch.view
用来改变tensor的展示形状,其原理是,先把原始tensor展成一维向量,再按设定拼成具体的形状。
所以此函数不会改变张量元素的个数。
还要注意view()返回的tensor和传入的tensor共享内存,修改其中一个,另一个数据也会变。
实例如下:
a = torch.arange(0, 16).reshape(-1, 2, 4)
av = a.view(-1, 8)
print(a.shape)
print(av.shape)
----------------------------------------
打印结果:
torch.size([2, 2, 4])
torch.size([2, 8])
2)torch.permute
更换维度(任意数量),permute就是置换序列的意思,只是把维度顺序调换一下而已,也不会改变张量元素个数。
实例如下:
a = torch.arange(0, 24).reshape(-1, 2, 4)
b = a.permute(1, 0, 2)
print(a.shape)
print(b.shape)
----------------------------------------
打印结果:
torch.size([3, 2, 4])
torch.size([2, 3, 4])
3)torch.transpose
更换维度(两维),与permute不同,transpose一次只能在两个维度间进行转置,所以要指定dim1和dim2(必要参数,二维张量也要指定),就是把dim1与dim2进行更换。注意与.T区分,.T是将所有维度颠倒。
实例如下:
a = torch.arange(0, 24).reshape(-1, 2, 3, 4)
b = a.transpose(0, 1).transpose(1, 2)
c = b.T
----------------------------------------
结果:
a的形状是 [1, 2, 3, 4]
b经过的变化是 [1, 2, 3, 4] -> [2, 1, 3, 4] -> [2, 3, 1, 4]
c的形状是 [4, 1, 3, 2]
4、torch.contiguous和torch.is_contiguous
torch.contiguous()方法就是将张量“相邻化”,is_contiguous就是判断张量是否是“相邻化”的张量,何谓“相邻化”?直观的解释是Tensor底层一维张量元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致。
torch.contiguous()经常与torch.permute()、torch.transpose()、torch.view()方法一起使用,尤其是经常放在.view()函数前,因为.view()要求张量元素连续。
详细操作是torch.contiguous()方法首先拷贝了一份张量在内存中的地址,然后将地址按照形状改变后的张量的语义进行排列。
实例如下:
a = torch.arange(0, 24).reshape(-1, 3, 4) # [2, 3, 4]
a = a.permute(2, 0, 1) # [4, 2, 3]
a = a.contiguous()
a = a.view(-1, 4) # [6, 4]
如果把第三句a = a.contiguous()注释掉,则会报错
invalid argument 2: view size is not compatible with input tensor’s size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view().
5、torch.reshape(pytorch0.4以上版本)
前面讲了.transpose,.permute,.view,而且说了如果.transpose之后直接.view会报错,需要在中间加一个.contiguous,那么为何pytorch的开发者不在.view的实现方法开头直接加一个.contiguous呢?
答案:因为历史上view方法已经约定了共享底层数据内存,返回的Tensor底层数据不会使用新的内存,如果在view中调用了contiguous方法,则可能在返回Tensor底层数据中使用了新的内存,会破坏兼容性。
PyTorch在0.4版本以后提供了reshape方法,实现了类似于 tensor.contigous().view(*args)的功能,如果不关心底层数据是否使用了新的内存,则使用reshape方法更方便。
实例如下:
x = torch.arange(0, 24).reshape(2, 3, 4)
x1 = x.reshape(3, 2, 4)
x2 = x.permute(1, 0, 2)
x3 = x.transpose(0, 1)
print(x1.is_contiguous()) # True
print(x2.is_contiguous()) # False
print(x3.is_contiguous()) # False
x1,x2,x3均变为了(3,2,4)的张量,但只有x1然是contiguous的。