神经网络view(),torch.flatten(),torch.nn.Flatten()

  • 1、view()
  • 2、torch.nn.Flatten()
  • 3、torch.flatten()



在神经网络中经常看到view(),torch.flatten(),torch.nn.Flatten()这几个方法。这几个方法一般用于改变tensor的形状。为日后方便使用下面就一一透彻的理解一下。

1、view()

view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),w代表的是列(想要变为几列)#这里所说并不严谨,只是为了更好理解,

view()的参数

作用

h

取值代表行数,当不知道要变为几行,但知道要变为几列时可取-1

w

取值代表列数,当不知道要变为几列,但知道要变为几行时可取-1

注意:元素个数要能整除行和列|
下面看几个例子就理解了。
1、把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor

import torch
a=torch.Tensor([[[1,2,3],[4,5,6],[7,8,9]]])
b=torch.Tensor([1,2,3,4,5,6,7,8,9])

#结果:
torch.Size([1, 3, 3])
tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
torch.Size([9])
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])

#a是[1,3,3]的tensor向量:
b是[9]的tensor向量

a1 = a.view(3,-1)
b1 = b.view(3,-1)

#a1和b1的结果:
torch.Size([3, 3])
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
torch.Size([3, 3])
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
结果一样

2、当知道要变成的tensor的行时:

import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=a.view(3,-1)
#a的结果:2行3列
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])

#b的结果:变为了3行2列
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])

2、当知道要变成的tensor的列时:

import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=a.view(-1,1)

#a的结果:2行三列
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])
#b的结果:变为了6行1列

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.]])

2、torch.nn.Flatten()

torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim代表合并的维度,开始的默认值为1,结束的默认值为-1,因此常被使用在神经网络当中,将每个batch的数据拉伸成一维。
下面举几个例子:
1、默认参数时:

import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten()
a1 = F(a)

a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 12288])
默认将第0维保留下来,其余拍成一维

2、有一个参数时(一个参数代表开始合并的维度):

import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(2)
a1 = F(a)


a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 3, 4096])
从第二维开始,拍成一维

3、有两个参数时(前一个参数代表开始合并的维度,后一个参数代表结束合并的维度)

import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(1,2)
a1 = F(a)

a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 192, 64])
将第一维到第二维拍成一维,其余不变

3、torch.flatten()

与 torch.nn.flatten 类似,都是用于展平 tensor 的,但是torch.flatten默认是从0开始的。
torch.flatten(t, start_dim=0, end_dim=-1)
t表示的时要展平的tensor,start_dim是开始展平的维度,end_dim是结束展平的维度
这里只举一个例子,其余与torch.nn.Flatten()是一样的。

import torch
a = torch.randn(8,3,64,64)
F = torch.flatten(a)

a的大小:
torch.Size([8, 3, 64, 64])

F的大小(默认从第0维展平):
torch.Size([98304])