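We discuss how to simplify a network's architecture by fusing a frozen batch normalization layer with the preceding convolutional layer, a setup that is common in practice and worth studying.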
Introduction and motivation

Batch normalization (often abbreviated as BN) is a popular method in modern neural networks because it often reduces training time and can improve generalization (although it remains somewhat controversial: 1, 2).
Today's state-of-the-art image classifiers, such as ResNets and DenseNets, incorporate batch normalization.
At runtime (test time, i.e., after training), batch normalization no longer computes batch statistics; the approximated per-channel mean $\mu$ and variance $\sigma^2$ are used instead. This restricted functionality can be implemented as a convolutional layer or, even better, merged with the preceding convolutional layer. This saves computational resources and simplifies the network architecture at the same time.
Let $x$ be a signal (activation) within the network that we want to normalize. Given a set of such signals $x_1, x_2, \ldots, x_n$ coming from processing different samples within a batch, each is normalized as follows:

$$\hat{x}_i = \gamma \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
The values $\mu$ and $\sigma^2$ are the mean and variance computed over a batch, $\epsilon$ is a small constant included for numerical stability, $\gamma$ is the scaling factor, and $\beta$ is the shift factor.
During training, $\mu$ and $\sigma^2$ are recomputed for each batch:

$$\mu = \frac{1}{n} \sum_{i=1}^{n} x_i, \qquad \sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2$$
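As a small illustration of the training-time formulas above, the following sketch normalizes one channel's activations in plain Python. The function name `batch_norm_train` and the sample values are hypothetical and chosen only to keep the arithmetic visible; this is not how a framework implements the layer.

```python
import math

def batch_norm_train(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one channel's activations using this batch's own statistics."""
    n = len(xs)
    mu = sum(xs) / n
    # Biased variance (divide by n), matching the formula in the text.
    var = sum((x - mu) ** 2 for x in xs) / n
    ys = [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]
    return ys, mu, var

# Hypothetical activations of one channel across a batch of four samples.
xs = [1.0, 2.0, 3.0, 4.0]
ys, mu, var = batch_norm_train(xs, gamma=2.0, beta=0.5)
```

After normalization the outputs are centred on `beta` and their spread is controlled by `gamma`, whatever the scale of the inputs was.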
The parameters $\gamma$ and $\beta$ are learned by gradient descent together with the other parameters of the network. At test time, we usually do not run the network on a batch of images, so the formulae above for $\mu$ and $\sigma^2$ cannot be used. Instead, we use estimates of them computed during training by an exponential moving average. Let us denote these approximations as $\hat{\mu}$ and $\hat{\sigma}^2$.
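The exponential moving average can be sketched as follows. The update convention `new = (1 - momentum) * old + momentum * batch_stat` and the value `momentum = 0.1` are assumptions made for this illustration; frameworks differ in how they define and name this factor, and `update_running` is a hypothetical helper.

```python
def update_running(running, batch_stat, momentum=0.1):
    """One exponential-moving-average step toward the latest batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat

# Start from zero and observe three batches whose mean is 2.0 each time.
running_mean = 0.0
for batch_mean in [2.0, 2.0, 2.0]:
    running_mean = update_running(running_mean, batch_mean)
```

The running estimate moves only part of the way toward each new batch statistic, which smooths out batch-to-batch noise.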
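Nowadays, batch normalization is mostly used in convolutional neural networks that process images. In this setting, there are mean and variance estimates as well as shift and scale parameters for each channel of the input feature map. For channel $c$, we denote these as $\hat{\mu}_c$, $\hat{\sigma}_c^2$, $\gamma_c$ and $\beta_c$.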
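Implementing frozen batch normalization as a 1×1 convolution

Given a feature map $F$ in $C \times H \times W$ order (channel, height, width), we can obtain its normalized version, $\hat{F}$, by computing the following matrix-vector operation for each spatial position $i, j$: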
$$
\left(\begin{array}{c}
\hat{F}_{1,i,j} \\ \hat{F}_{2,i,j} \\ \vdots \\ \hat{F}_{C-1,i,j} \\ \hat{F}_{C,i,j}
\end{array}\right)
=
\left(\begin{array}{ccccc}
\frac{\gamma_1}{\sqrt{\hat{\sigma}_1^2+\epsilon}} & 0 & \cdots & 0 & 0 \\
0 & \frac{\gamma_2}{\sqrt{\hat{\sigma}_2^2+\epsilon}} & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \frac{\gamma_{C-1}}{\sqrt{\hat{\sigma}_{C-1}^2+\epsilon}} & 0 \\
0 & 0 & \cdots & 0 & \frac{\gamma_C}{\sqrt{\hat{\sigma}_C^2+\epsilon}}
\end{array}\right)
\cdot
\left(\begin{array}{c}
F_{1,i,j} \\ F_{2,i,j} \\ \vdots \\ F_{C-1,i,j} \\ F_{C,i,j}
\end{array}\right)
+
\left(\begin{array}{c}
\beta_1-\gamma_1 \frac{\hat{\mu}_1}{\sqrt{\hat{\sigma}_1^2+\epsilon}} \\
\beta_2-\gamma_2 \frac{\hat{\mu}_2}{\sqrt{\hat{\sigma}_2^2+\epsilon}} \\
\vdots \\
\beta_{C-1}-\gamma_{C-1} \frac{\hat{\mu}_{C-1}}{\sqrt{\hat{\sigma}_{C-1}^2+\epsilon}} \\
\beta_C-\gamma_C \frac{\hat{\mu}_C}{\sqrt{\hat{\sigma}_C^2+\epsilon}}
\end{array}\right)
$$
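Because the matrix is diagonal, the operation is a per-channel scale and shift. The sketch below computes these two vectors in plain Python and checks them against the normalization formula at one spatial position; the helper `frozen_bn_affine` and all numeric values are hypothetical.

```python
import math

def frozen_bn_affine(gamma, beta, mean, var, eps=1e-5):
    """Return per-channel (scale, shift) so that frozen BN is y = scale * x + shift."""
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    shift = [b - s * m for b, s, m in zip(beta, scale, mean)]
    return scale, shift

# Hypothetical learned parameters and running statistics for C = 2 channels.
gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var = [0.3, -1.0], [4.0, 0.25]
scale, shift = frozen_bn_affine(gamma, beta, mean, var)

# One spatial position (i, j): one activation per channel.
pixel = [2.0, -3.0]
via_affine = [s * x + t for s, x, t in zip(scale, pixel, shift)]
via_formula = [g * (x - m) / math.sqrt(v + 1e-5) + b
               for g, x, m, v, b in zip(gamma, pixel, mean, var, beta)]
```

Here `scale` holds the diagonal of the matrix and `shift` the bias vector, so the same two vectors are reused at every spatial position.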
Let $\mathbf{W}_{BN} \in \mathbb{R}^{C \times C}$ and $\mathbf{b}_{BN} \in \mathbb{R}^{C}$ denote the diagonal matrix and the bias vector of the frozen batch normalization equation above, and let $\mathbf{W}_{conv} \in \mathbb{R}^{C \times (C_{prev} \cdot k^2)}$ and $\mathbf{b}_{conv} \in \mathbb{R}^{C}$ be the parameters of the convolutional layer that precedes batch normalization, where $C_{prev}$ is the number of channels of the feature map $F_{prev}$ fed into the convolutional layer and $k \times k$ is the filter size.
Given a $k \times k$ neighbourhood of $F_{prev}$ unrolled into a vector $\mathbf{f}_{i,j}$ of length $k^2 \cdot C_{prev}$, we can write the whole computation as:

$$\hat{\mathbf{f}}_{i,j} = \mathbf{W}_{BN} \cdot \left(\mathbf{W}_{conv} \cdot \mathbf{f}_{i,j} + \mathbf{b}_{conv}\right) + \mathbf{b}_{BN}$$
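Expanding the product separates the part that multiplies $\mathbf{f}_{i,j}$ from the constant part. Nothing is assumed here beyond distributivity and associativity of matrix multiplication:

$$
\hat{\mathbf{f}}_{i,j}
= \mathbf{W}_{BN} \cdot \mathbf{W}_{conv} \cdot \mathbf{f}_{i,j} + \mathbf{W}_{BN} \cdot \mathbf{b}_{conv} + \mathbf{b}_{BN}
= \left(\mathbf{W}_{BN} \cdot \mathbf{W}_{conv}\right) \cdot \mathbf{f}_{i,j} + \left(\mathbf{W}_{BN} \cdot \mathbf{b}_{conv} + \mathbf{b}_{BN}\right)
$$

The first bracket has the shape of a convolution weight matrix, $C \times (C_{prev} \cdot k^2)$, and the second is a vector in $\mathbb{R}^{C}$, so the result is again an ordinary convolution with a bias.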
Thus, we can replace these two layers by a single convolutional layer with the following parameters:
- filter weights: $\mathbf{W} = \mathbf{W}_{BN} \cdot \mathbf{W}_{conv}$;
- bias: $\mathbf{b} = \mathbf{W}_{BN} \cdot \mathbf{b}_{conv} + \mathbf{b}_{BN}$.
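A tiny numeric check of this identity, written in plain Python with nested lists so that it runs without any dependencies. All matrices and vectors are hypothetical values for a layer with $C = 2$ output channels and $k^2 \cdot C_{prev} = 3$ inputs, and the helpers `matvec`, `matmul` and `add` exist only for this sketch.

```python
def matvec(m, v):
    """Matrix-vector product for nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def matmul(a, b):
    """Matrix-matrix product for nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def add(u, v):
    """Element-wise vector sum."""
    return [a + b for a, b in zip(u, v)]

# Hypothetical convolution parameters (already unrolled to a matrix).
w_conv = [[1.0, -2.0, 0.5], [0.0, 3.0, 1.0]]
b_conv = [0.5, -1.0]
# Hypothetical frozen batch norm: diagonal scale matrix and bias vector.
w_bn = [[2.0, 0.0], [0.0, 0.25]]
b_bn = [0.1, -0.3]

# Fused parameters: W = W_BN * W_conv, b = W_BN * b_conv + b_BN.
w = matmul(w_bn, w_conv)
b = add(matvec(w_bn, b_conv), b_bn)

f = [1.0, 2.0, -1.0]  # one unrolled k x k neighbourhood
two_layers = add(matvec(w_bn, add(matvec(w_conv, f), b_conv)), b_bn)
one_layer = add(matvec(w, f), b)
```

Both `two_layers` and `one_layer` hold the same values. Using `b_conv + b_bn` as the bias, without the `w_bn` factor, would give a different result whenever `b_conv` is nonzero and the scale is not 1.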
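Implementation in PyTorch

In PyTorch, each convolutional layer conv has the following parameters:
- filter weights, $\mathbf{W}$: conv.weight;
- bias, $\mathbf{b}$: conv.bias;

and each batch normalization layer bn has the following parameters:
- scaling, $\gamma$: bn.weight;
- shift, $\beta$: bn.bias;
- mean estimate, $\hat{\mu}$: bn.running_mean;
- variance estimate, $\hat{\sigma}^2$: bn.running_var;
- $\epsilon$ (for numerical stability): bn.eps.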
The following function takes a PyTorch nn.Conv2d and nn.BatchNorm2d as arguments and fuses them into a single nn.Conv2d layer:
    import torch

    def fuse_conv_and_bn(conv, bn):
        # Create the fused layer with the same geometry as the original convolution.
        fusedconv = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=True,
        )
        # The parameters are overwritten in place, so no gradients are tracked.
        with torch.no_grad():
            # Fused filters: W = W_BN * W_conv
            w_conv = conv.weight.clone().view(conv.out_channels, -1)
            w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
            fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
            # Fused bias: b = W_BN * b_conv + b_BN
            if conv.bias is not None:
                b_conv = conv.bias
            else:
                b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
            b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                torch.sqrt(bn.running_var + bn.eps)
            )
            fusedconv.bias.copy_(torch.matmul(w_bn, b_conv) + b_bn)
        return fusedconv
The following code tests the function above on the first two layers of ResNet-18:
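    import torch
    import torchvision

    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    # Newer torchvision releases prefer the `weights=` argument over `pretrained=True`.
    rn18 = torchvision.models.resnet18(pretrained=True)
    rn18.eval()  # use the running statistics, i.e. frozen batch normalization
    net = torch.nn.Sequential(rn18.conv1, rn18.bn1)
    y1 = net.forward(x)
    fusedconv = fuse_conv_and_bn(net[0], net[1])
    y2 = fusedconv.forward(x)
    # Relative difference between the two-layer output and the fused output.
    d = (y1 - y2).norm().div(y1.norm()).item()
    print("error: %.8f" % d)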