import torch
from torch.cuda.amp import GradScaler

# model, criterion, optimizer and data are assumed to be defined elsewhere
scaler = GradScaler()
for features, target in data:
    # Clear gradients left over from the previous iteration
    optimizer.zero_grad()
    # Forward pass with mixed precision
    with torch.cuda.amp.autocast():  # autocast as a context manager
        output = model(features)
        loss = criterion(output, target)
    # Backward pass, called outside the autocast context.
    # Running backward() under autocast is not recommended: backward ops
    # already run in the dtype autocast chose for the matching forward ops.
    scaler.scale(loss).backward()
    # scaler.step() first unscales the gradients.
    # If these gradients contain infs or NaNs,
    # optimizer.step() is skipped.
    scaler.step(optimizer)
    # If optimizer.step() was skipped, the scaling factor is
    # reduced by the backoff_factor in GradScaler()
    scaler.update()
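autocast automatically chooses a precision for each operation. Because many operations, and therefore their gradients, are computed in float16, gradients that are too small can "underflow" and become zero.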
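To make the underflow concrete, here is a minimal sketch that uses only the Python standard library. It emulates float16 rounding with the `struct` module's half-precision format rather than calling PyTorch, so the helper `to_float16` and the sample gradient value are illustrative assumptions, not part of the training loop above. The scale of 65536 matches the default `init_scale` of `GradScaler`.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8        # hypothetical gradient, below float16's smallest subnormal
scale = 65536.0    # 2**16, the default init_scale of GradScaler

unscaled = to_float16(grad)          # stored directly in float16: underflows
scaled = to_float16(grad * scale)    # scaled first: representable in float16
recovered = scaled / scale           # unscale in higher precision

print("stored directly:  ", unscaled)
print("scaled, unscaled: ", recovered)
```

Stored directly, the value is flushed to zero; scaled before the conversion and unscaled afterwards in higher precision, it survives with only a small rounding error.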
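GradScaler prevents underflow by multiplying the loss by a scale factor, computing the gradients from the scaled loss, and then unscaling the gradients before the optimizer updates the weights. If the scale factor is too large, the scaled gradients overflow to inf or NaN; the optimizer step is then skipped and the scale factor is reduced on the next call to scaler.update(). If no inf or NaN appears for a number of consecutive iterations, the scale factor is increased again.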
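The scale-factor adjustment described above can be sketched as a small state machine. This is a simplified, hypothetical model (`ToyLossScaler` is not a PyTorch class, and it works on plain Python floats instead of tensors); the parameter names and default values are borrowed from `GradScaler`, but the real implementation differs in detail.

```python
import math


class ToyLossScaler:
    """Simplified model of dynamic loss scaling (not torch's GradScaler)."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive iterations without inf/NaN

    def update(self, grads):
        """Adjust the scale; return True if the optimizer step would run."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            # Overflow: skip the step and shrink the scale factor.
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # Long enough without overflow: try a larger scale factor.
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


toy = ToyLossScaler(growth_interval=2)       # short interval for the demo
stepped_1 = toy.update([float("inf"), 0.1])  # overflow: step skipped
scale_after_overflow = toy.scale
stepped_2 = toy.update([0.1, 0.2])           # clean iteration
stepped_3 = toy.update([0.3, 0.4])           # second clean iteration: grow
scale_after_growth = toy.scale
```

With `growth_interval=2`, one overflow halves the scale and two clean iterations double it again, which mirrors the back-off and growth behaviour on a much shorter timescale than the default interval of 2000 iterations.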