Pytorch 剪枝操作实现

首先需要版本为 1.4 以上,

目前很多模型都取得了十分好的结果, 但是还是参数太多, 占得权重太大, 所以我们的目标是得到一个稀疏的子系数矩阵.

这个例子是基于 LeNet 的 Pytorch 实现的例子, 我们从 CNN 的角度来剪枝, 其实在全连接层与 RNN 的剪枝应该是类似, 首先导入一些必要的模块

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

然后是 LeNet 的网络结构, 不知道为什么这里的网络结构是这样的, 算出来输入的图像是 26x26 的,

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        # 第一个卷积层, 输出的向量维度是 6
        self.conv2 = nn.Conv2d(6, 16, 3)
        # 第二个卷积层, 输出的向量维度是 16
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        # 最后将二维向量变成一维
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 2*2 的池化层
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # relu 激活函数层
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        # 除以 batch_size 的大小, 将维度变成一
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

这时查看模型的参数:

module = model.conv1
print(list(module.named_parameters()))

此时参数包含矩阵的权值与偏置.

为了剪枝一个模型, 首先要在 torch.nn.utils.prune 中选择一种剪枝方法, 或者使用子类 BasePruningMethod 实现自己的剪枝方法, 然后确定模型以及需要减去的参数, 最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数. 在下面的例子中, 我们将要随机减去 conv1 层中的 30% 的权重参数, module 是函数的第一个参数, name 使用的是参数的字符串标识, amount 表示剪枝的百分比.

prune.random_unstructured(module, name="weight", amount=0.3)

剪枝行为将 weight 参数名称删除, 并将其替代为新的参数名称, weight_orig , weight_orig存储未修剪的张量版本. 也就是说 weight_orig 是原来的权重,

上述的剪枝方法会产生一个 mask 矩阵, 叫做 weight_mask , 存储为一个 module buffer , 相当于一个 mask矩阵, 他的维度与 weight 的维度相同, 不同的是 mask 矩阵是一个 0/1 矩阵. 可以通过下面的函数查看 mask 矩阵:

print(list(module.named_buffers()))

剪枝之后的权重属性 weight 不再是权重的集合, 而是 mask 矩阵与原始矩阵的结合, 所以不再是模型的一个 parameter, 而是一个 attribute.

最后,使用 PyTorch 的forward_pre_hooks在每次正向传递之前应用修剪。具体来说,如我们在此处所做的那样,在剪枝模块部分,它将为与之相关的每个要修剪的参数获取一个forward_pre_hook。目前为止我们只修剪了名为weight的原始参数,因此将只存在一个 forward_pre_hook, 相当于没有一个剪枝参数就有一个 forward_pre_hook.

除了对 weight 剪枝, 还可以对 bias 剪枝, 下面是通过 L1 范式剪去三个单元

prune.l1_unstructured(module, name="bias", amount=3)
# Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units with the lowest L1-norm.

Iterative Pruning

相同的参数在一个模型中可以被多次剪枝, 相当于把多个剪枝核序列化成一个剪枝核, 新的 mask 矩阵与旧的 mask 矩阵的结合使用 PruningContainer 中的 compute_mask 方法. 比如在上面的 module 的 weight 中, 我们除了随机剪枝外还可以通过范式剪枝, 下面是个例子:

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
# 这里的 n 表示剪枝的范式, dim = 0, 表示参数矩阵的维度, 这里卷积层的 dim= 0, 就是核的个数
print(module.weight)

剪完之后, 核的个数变成原来的一半. mask 矩阵也会自动叠加.

还可以通过下面的方法查看我们使用了哪些方法剪枝, hook 记录了某个 attribute 的剪枝方法:

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

Serializing a pruned model

所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的 state_dict 中,因此可以根据需要轻松地序列化和保存.

我们可以通过下面的方法查看模型中的权重参数:

>> print(model.state_dict().keys())
>> odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

Remove pruning re-parametrization

注意, 这里的删除剪枝的意思并不是真正的删除, 还原到未剪枝的状态. 举个例子, 剪枝之后, 我们的参数 parameters 中的 weight 会变成, 'weight_orig', 而 weight 变成一个属性, 他是 'weight_orig' 与 mask 矩阵结合后的结果, 那么

prune.remove(module, 'weight')

之后会发生什么呢?

print(list(module.named_parameters()))
('weight', Parameter containing:
tensor([[[[-0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],
.......

也就是说, weight 又变成了 parameters, 剪枝变成永久化.

Pruning multiple parameters in a model

多个参数, 多个网络结构的剪枝,

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
        # 将所有卷积层的权重减去 20%
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        # 将所有全连接层的权重减去 40%

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

Global pruning

之前的剪枝我们都是针对每一层每一层的剪枝, 减去某一层权重的百分比, 对于全局剪枝就是将模型的参数看成一个整体, 减去一部分参数, 对于每一层减去的比例可能不同.

剪枝的方法可以通过下面的方法:

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

使用自定义的方法剪枝

要实现自己的修剪功能,可以通过将 BasePruningMethod 基类作为子类来扩展 nn.utils.prune 模块,就像其他所有修剪方法一样. 基类以及完成了下面的方法:

__call__, apply_mask, apply, prune, and remove

除了一些特殊的情况, 你不需要重写这些方法以实现新的剪枝方法. 你需要实现的是:

  1. __init__ 构造器
  2. compute_mask 如何根据剪枝策略的逻辑为给定张量计算 mask
  3. 需要说明是全局剪枝, 还是结构剪枝, 或者是非结构剪枝, 这决定了在迭代剪枝是如何结合 mask 矩阵, 换句话说,当剪枝需要剪枝的参数时,当前的剪枝策略应作用于参数的未剪枝部分。指定 PRUNING_TYPE 将启用 PruningContainer 正确识别要修剪的参数的范围.

比如说, 当我们希望剪枝一个张量中除了某一参数外的所有其他参数的时候, 或者说这个张量已经被部分剪枝的时候, 我们就需要设置: PRUNING_TYPE='unstructured' 因为他只是单独作用与一层, 而不是一个单元或者通道(对应于'structured'), 也不是作用于整个参数(对应于'global')

class FooBarPruningMethod(prune.BasePruningMethod):
    # 继承自基类 BasePruningMethod
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'
    # 类型为 unstructured 类型

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        # 定义了 mask 矩阵的构成方法, 每两个数字一个 0
        return mask

然后给出一个调用的例子:

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

以上就是Pytorch 剪枝的主要方法, 其实对于复杂的剪枝方法, 只要在 compute_mask 设置特殊的 mask 构成方法就可以了.