Tensor的常见操作

针对Pytorch中的tensor,总结一下常用的操作

1、torch.max和torch.min

两个函数的实现类似,形参也相同,只是一个取最大一个取最小而已,下面以max为例,min同理。
(1) torch.max(a): 返回输入a中所有元素的最大值。
(2) torch.max(a, 0): 返回每一列的最大值,且返回索引(返回最大元素在各列的行索引)。
(3) torch.max()[0]: 只返回最大值。

a = torch.tensor(
    [
        [
            [
                [2, 3, 4],
                [5, 2, 9],
            ],
            [
                [9, 5, 3],
                [8, 0, 6],
            ],
            [
                [8, 6, 7],
                [4, 1, 3],
            ]
        ],

        [
            [
                [1, 5, 4],
                [6, 2, 3],
            ],
            [
                [1, 2, 3],
                [4, 5, 7],
            ],
            [
                [3, 8, 1],
                [3, 6, 2],
            ]
        ],

    ]
)
print(a.shape, "\n")
print(torch.max(a, 2), "\n")
print(torch.max(a, 2)[0].shape)

比如上边这个例子,a的shape为torch.Size([2, 3, 2, 3]),对应最常见的网络层输入形状(b,c,h,w),

运行torch.max(a, k),k就是在第几维上取最大值(0~3)。
比如,k=0,则就是在第0维上取最大值,得到的value的shape为(3, 2, 3),原始的第0维就没有了。
同理,k=1,得到的value.shape为(2, 2, 3),原始的第1维没有了,以此类推。

torch.max函数返回值有两个,第一个是最大值(values),第二个是最大值在原始张量中的索引(indices),取[0]就是得到最大值,取[1]就是得到最大值索引。以下为输出结果。

// 共三个打印结果,横线是方便观看而添加的
-----------------------------------------------------
torch.Size([2, 3, 2, 3]) 
-----------------------------------------------------
torch.return_types.max(
values=tensor([[[5, 3, 9],
         [9, 5, 6],
         [8, 6, 7]],

        [[6, 5, 4],
         [4, 5, 7],
         [3, 8, 2]]]),
indices=tensor([[[1, 0, 1],
         [0, 0, 1],
         [0, 0, 0]],

        [[1, 0, 0],
         [1, 1, 1],
         [1, 0, 1]]])) 
-----------------------------------------------------
torch.Size([2, 3, 3])
-----------------------------------------------------

2、torch.squeeze和torch.unsqueeze(常用的缩维扩维函数)

torch.squeeze(x,dim,out):对数据的维度进行压缩,去掉维数为1的维度,默认将a中所有为1的维度删掉,也可以通过dim指定位置,删掉指定位置的维数为1的维度。
torch.unsqueeze(x,dim,out):对数据的维度进行扩充,需要通过dim指定位置,给指定位置加上维数为1的维度。

注:压缩和扩充的都是维数为1的维度,张量中元素的总个数不变,其中unsqueeze常用于添加batch一维。

示例代码如下(注释即为打印结果):

a = torch.ones((1, 1, 4, 6))
a1 = a.squeeze()  # [4, 6]
a2 = a.squeeze(dim=0)  # [1, 4, 6]

b = torch.ones((3, 5))
b1 = b.unsqueeze(dim=0)  # [1, 3, 5]

3、torch.view、torch.permute、torch.transpose(常用的维度更换函数)

1)torch.view

用来改变tensor的展示形状,其原理是,先把原始tensor展成一维向量,再按设定拼成具体的形状。
所以此函数不会改变张量元素的个数。
还要注意view()返回的tensor和传入的tensor共享内存,修改其中一个,另一个数据也会变。
实例如下:

a = torch.arange(0, 16).reshape(-1, 2, 4)
av = a.view(-1, 8)

print(a.shape)
print(av.shape)
----------------------------------------
打印结果:
torch.size([2, 2, 4])
torch.size([2, 8])

2)torch.permute

更换维度(任意数量),permute就是置换序列的意思,只是把维度顺序调换一下而已,也不会改变张量元素个数。
实例如下:

a = torch.arange(0, 24).reshape(-1, 2, 4)
b = a.permute(1, 0, 2)

print(a.shape)
print(b.shape)
----------------------------------------
打印结果:
torch.size([3, 2, 4])
torch.size([2, 3, 4])

3)torch.transpose

更换维度(两维),与permute不同,transpose一次只能在两个维度间进行转置,所以要指定dim1和dim2(必要参数,二维张量也要指定),就是把dim1与dim2进行更换。注意与.T区分,.T是将所有维度颠倒。
实例如下:

a = torch.arange(0, 24).reshape(-1, 2, 3, 4)
b = a.transpose(0, 1).transpose(1, 2)
c = b.T
----------------------------------------
结果:
a的形状是 [1, 2, 3, 4]
b经过的变化是 [1, 2, 3, 4] -> [2, 1, 3, 4] -> [2, 3, 1, 4]
c的形状是 [4, 1, 3, 2]

4、torch.contiguous和torch.is_contiguous

torch.contiguous()方法就是将张量“相邻化”,is_contiguous就是判断张量是否是“相邻化”的张量,何谓“相邻化”?直观的解释是Tensor底层一维张量元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致。

torch.contiguous()经常与torch.permute()、torch.transpose()、torch.view()方法一起使用,尤其是经常放在.view()函数前,因为.view()要求张量元素连续。

详细操作是torch.contiguous()方法首先拷贝了一份张量在内存中的地址,然后将地址按照形状改变后的张量的语义进行排列。
实例如下:

a = torch.arange(0, 24).reshape(-1, 3, 4)  # [2, 3, 4]
a = a.permute(2, 0, 1)  # [4, 2, 3]

a = a.contiguous()
a = a.view(-1, 4)  # [6, 4]

如果把第三句a = a.contiguous()注释掉,则会报错
invalid argument 2: view size is not compatible with input tensor’s size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view().

5、torch.reshape(pytorch0.4以上版本)

前面讲了.transpose,.permute,.view,而且说了如果.transpose之后直接.view会报错,需要在中间加一个.contiguous,那么为何pytorch的开发者不在.view的实现方法开头直接加一个.contiguous呢?

答案:因为历史上view方法已经约定了共享底层数据内存,返回的Tensor底层数据不会使用新的内存,如果在view中调用了contiguous方法,则可能在返回Tensor底层数据中使用了新的内存,会破坏兼容性。

PyTorch在0.4版本以后提供了reshape方法,实现了类似于 tensor.contigous().view(*args)的功能,如果不关心底层数据是否使用了新的内存,则使用reshape方法更方便。

实例如下:

x = torch.arange(0, 24).reshape(2, 3, 4)
x1 = x.reshape(3, 2, 4)
x2 = x.permute(1, 0, 2)
x3 = x.transpose(0, 1)

print(x1.is_contiguous())  # True
print(x2.is_contiguous())  # False
print(x3.is_contiguous())  # False

x1,x2,x3均变为了(3,2,4)的张量,但只有x1然是contiguous的。