Hook函数机制是不改变函数主体,实现额外功能,像一个挂件,挂钩。正是因为PyTorch计算图动态图的机制,所以才会有Hook函数。在动态图机制的运算,当运算结束后,一些中间变量就会被释放掉,例如,特征图,非leaf节点的梯度。但是有时候,我们需要这些中间变量,所以就出现了Hook函数。我们可以使用Hook函数获取这些中间变量。
文章目录
- Hook函数
- 1、torch.Tensor.register_hook
- 2、torch.nn.Module.register_forward_hook
- 3、torch.nn.Module.register_forward_pre_hook
- 4、torch.nn.Module.register_backward_hook
- Hook函数进行特征提取
Hook函数
PyTorch提供四种Hook函数:
1、torch.Tensor.register_hook(hook)
2、torch.nn.Module.register_forward_hook
3、torch.nn.Module.register_forward_pre_hook
4、torch.nn.Module.register_backward_hook
1、torch.Tensor.register_hook
功能:注册一个反向传播hook函数,Hook函数仅一个输入参数,为张量的梯度。Hook不应修改其参数梯度值,但可以选择返回一个新的梯度,该梯度将代替grad使用。
hook(grad) -> Tensor or None
结合代码进行讲解:
import torch
# x,y 为leaf节点,也就是说,在计算的时候,PyTorch只会保留此节点的梯度值
x = torch.tensor([3.], requires_grad=True)
y = torch.tensor([5.], requires_grad=True)
# a,b均为中间值,在计算梯度时,此部分会被释放掉
a = x + y
b = x * y
c = a * b
# 新建列表,用于存储Hook函数保存的中间梯度值
a_grad = []
def hook_grad(grad):
a_grad.append(grad)
# register_hook的参数为一个函数
handle = a.register_hook(hook_grad)
c.backward()
# 只有leaf节点才会有梯度值
print('gradient:',x.grad, y.grad, a.grad, b.grad, c.grad)
# Hook函数保留下来的中间节点a的梯度
print('a_grad:', a_grad[0])
# 移除Hook函数
handle.remove()
Out:
gradient: tensor([55.]) tensor([39.]) None None None
a_grad: tensor([15.])
2、torch.nn.Module.register_forward_hook
功能:注册module的前向传播Hook函数
参数:
- module:当前网络层
- input:当前网络层输入数据
- output:当前网络层输出数据
结合代码进行讲解:
import torch
import torch.nn as nn
# 构建网网络,一个卷积层一个池化层
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化网络
net = Net()
# detach将张量分离
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.detach().zero_()
# 构建两个列表用于保存信息
fmap_block = []
input_block = []
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
# 注册Hook
net.conv1.register_forward_hook(forward_hook)
# 输入数据
fake_img = torch.ones((1, 1, 4, 4))
output = net(fake_img)
# 观察结果
# 卷积神经网络输出维度和结果
print("output share:{}\noutput value:{}\n".format(output.size(),output))
# 卷积神经网络Hook函数返回的结果
print("feature map share:{}\noutput value:{}\n".format(fmap_block[0].shape,fmap_block[0]))
# 输入的信息
print("input share:{}\ninput value:{}\n".format(input_block[0][0].size(),input_block[0][0]))
3、torch.nn.Module.register_forward_pre_hook
功能:注册module前向传播前的hook函数。
参数:
- module:当前网络层
- input:当前网络层输入数据
4、torch.nn.Module.register_backward_hook
功能:注册module反向传播的hook函数。
参数:
- module:当前网络层
- grad_input:当前网络层输入梯度数据
- grad_output:当前网络层输出梯度数据
Hook函数进行特征提取
我们这里用一学就会 | LeNet在CIFAR10数据集上的应用训练好的模型来做实验。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def main():
img_path = './car.jpg'
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img = Image.open(img_path)
img = transform(img)
img.unsqueeze_(dim=0)
# 实例化
net = LeNet()
PATH = 'cifar_net_10.pth'
# 将训练好的参数导入
net.load_state_dict(torch.load(PATH))
fmap_block = []
input_block = []
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
# 注册Hook
net.conv1.register_forward_hook(forward_hook)
net.conv2.register_forward_hook(forward_hook)
with torch.no_grad():
outputs = net(img)
print("conv1 feature map share:{}".format(fmap_block[0].shape))
print("conv2 feature map share:{}".format(fmap_block[1].shape))
if __name__ == '__main__':
main()
Out:
conv1 feature map share:torch.Size([1, 6, 28, 28])
conv2 feature map share:torch.Size([1, 16, 10, 10])