学习总结
一、应用场景
栗子:torch.nn
只支持小批量处理 (mini-batches)。整个 torch.nn
包只支持小批量样本的输入,不支持单个样本的输入。比如,nn.Conv2d
接受一个4维的张量,即nSamples x nChannels x Height x Width
,如果是一个单独的样本,只需要使用input.unsqueeze(0)
来添加一个“假的”批大小维度。
PS:pytorch中,处理图片必须一个batch一个batch的操作,所以我们要准备的数据的格式是 [batch_size, n_channels, hight, width]
。
二、升维和降维
降维:squeeze(input, dim = None, out = None)
函数
(1)在不指定dim时,张量中形状为1的所有维都会除去。如input为(A, 1, B, 1, C, 1, D),output为(A, B, C, D)。
(2)如果要指定dim,降维操作只能在给定的维度上,如input为(A, 1, B)时:
错误用法:squeeze(input, dim = 0)
会发现shape没变化,如下:
结果为:
torch.unsqueeze
有两种写法:
结果为如下,即对第一维度升高后,b从a =【4】变为b=【【1, 2, 3,4】】:
Reference
https://zhuanlan.zhihu.com/p/86763381