在阅读以下内容前,请务必先大致了解计算图机制,特别是叶子节点:pytorch——计算图与动态图机制
with torch.no_grad()
在 PyTorch官网 中的定义为:
Context-manager that disables gradient calculation.
意思是with torch.no_grad()
是一个用于禁用梯度的上下文管理器。禁用梯度计算对于推理是很有用的,当我们确定不会调用 Tensor.backward()
时,它将减少计算的内存消耗。因为在此模式下,即使输入为 requires_grad=True
,每次计算的结果也将具有requires_grad=False
。
总的来说, with torch.no_grad()
可以理解为,在管理器外产生的与原参数有关联的参数requires_grad
属性都默认为True
,而在该管理器内新产生的参数的requires_grad
属性都将置为False
。
除此之外,with torch.no_grad()
还通常与原地操作(in-place operation)组合在一起。原地操作有明确定义:
对于 requires_grad=True 的叶子张量(leaf tensor)不能使用 inplace operation
因为原地操作会覆盖当前内存的值,但叶子节点所指向的内存块进行无法进行修改操作,否则会导致其中梯度信息与节点的值不再有计算上的对应关系。
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_() # 清空当前梯度
于是我们针对以上操作进行探究,以更好理解该情况下with torch.no_grad()
的作用。
- 不使用
with torch.no_grad()
进行原地操作
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_() # 清空当前梯度
运行上面的代码会报错,错误信息为RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
意思是在原地操作中使用了需要梯度的叶子节点。
如果你有意验证有无 with torch.no_grad() 进行原地操作的两种情况下 param 的 requires_grad 属性,你会发现其值都为True。那么可能有人会有疑问,影响原地操作的定义不就是 requires_grad 属性吗。那么你需要做的相信定义,并理解以下两层:
- lr * param.grad / batch_size 会创建一块临时内存,这块临时内存的 requires_grad 属性是 False
- param.grad 也会占用一块内存,其也具有 requires_grad 属性,且为 False
- 不使用
with torch.no_grad()
进行赋值操作
for param in params:
param = param - lr * param.grad / batch_size
print(param.is_leaf) # False
param.grad.zero_() # 清空当前梯度
运行上面的代码会报错,错误信息为AttributeError: 'NoneType' object has no attribute 'zero_'
。我们都知道赋值操作会新创建一块内存以存放数据,所以根据计算图理论,此时的param
是中间节点,不再是叶子节点,不具有grad
属性了。
- 使用
with torch.no_grad()
进行赋值操作
with torch.no_grad():
for param in params:
print(param.requires_grad) # True
param = param - lr * param.grad / batch_size
print(param.is_leaf) # True
print(param.requires_grad) # Flase
# param.requires_grad = True
param.grad.zero_() # 清空当前梯度
运行上面的代码会报错,错误信息为AttributeError: 'NoneType' object has no attribute 'zero_'
。
我们知道在 PyTorch 中,前向传播过程中构建计算图,而反向传播时销毁计算图以释放内存并计算叶子节点的梯度信息。当我们使用 with torch.no_grad() 上下文管理器时,我们指示 PyTorch 在此上下文中不跟踪梯度信息,因此不会构建用于反向传播的计算图。尽管如此,由于在 torch.no_grad() 上下文中创建的张量(如 param)不依赖于计算图中的其他节点,它们仍然被视为叶子节点。因此,这些张量的梯度信息仍然可以被访问,但是梯度计算不会在该上下文中进行,因此在此上下文内产生的张量不会保存任何梯度信息。
在 with torch.no_grad()
上下文中,param
仍然是叶子节点。但是赋值操作会创建一个新的张量,并且这个新的张量中的requires_grad = False
。理论上,我们可以将requires_grad
重新设置为True
,然后再进行反向传播,但这样做非常麻烦且没有意义,且会导致大量的内存占用。因此通常不建议这样做。如果想要尝试,可以修改函数结构,将操作集合到一个函数内共享参数,以确保梯度追踪和梯度计算的一致性。