PyTorch を使っていると、はるか遠く離れた場所で計算した結果に nan や inf が含まれることで、思いもよらない場所から非直感的なエラーを生じることがある。 あるいは、自動微分したときにゼロ除算が生じるようなパターンでは、順伝搬の結果だけ見ていても原因にたどり着くことが難しい。 こういった問題は、デバッガなどを使って地道に原因を探ろうとすると多くの手間と時間がかかる。
そんな折、PyTorch にはそうした問題に対処する上で有益な機能があることを知った。 具体的には、以下の関数を使うと自動でバックプロパゲーションが上手くいかない箇所を見つけることができる。 今回は、この機能について書いてみる。
torch.autograd.set_detect_anomaly()
torch.autograd.detect_anomaly()
使った環境は次のとおり。
$ sw_vers ProductName: macOS ProductVersion: 12.6.2 BuildVersion: 21G320 $ python -V Python 3.10.9 $ pip list | grep -i torch torch 1.13.1
もくじ
- もくじ
- 下準備
- 入力によって逆伝搬が上手くいかないコード (1)
- 上手くいかない箇所を自動で見つける
- 問題を修正する (1)
- 入力によって逆伝搬が上手くいかないコード (2)
- 問題を修正する (2)
- 特定のスコープでチェックする
- まとめ
- 参考
下準備
あらかじめ PyTorch と NumPy をインストールしておく。
$ pip install torch numpy
入力によって逆伝搬が上手くいかないコード (1)
例として RMSLE (Root Mean Squared Logarithmic Error) を計算する場合について考える。
以下のサンプルコードでは、RMSLE の計算を RMSLELoss
というクラスで実装している。
このコードは入力によってはバックプロパゲーションが上手くいかない。
具体的には、入力されるモデルの予測 (y_pred
) に -1
以下の値が含まれたとき torch.log1p()
の返り値に inf
が含まれる。
import torch from torch import nn class RMSLELoss(nn.Module): """Root Mean Squared Logarithmic Error""" def __init__(self): super().__init__() self.mse_loss = nn.MSELoss() def forward(self, y_pred, y_true): # ここで y_pred に -1 以下の値が含まれると順伝搬の返り値が -inf になる log_y_pred = torch.log1p(y_pred) log_y_true = torch.log1p(y_true) # 入力に -inf が含まれることで返り値が inf になってしまう msle = self.mse_loss(log_y_pred, log_y_true) rmsle_loss = torch.sqrt(msle) return rmsle_loss def main(): # モデルの出力に -1 以下の値が含まれるとする y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True) # 順伝搬 loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) # 結果 print(out) # 逆伝搬 out.backward() # 勾配 print(y_pred.grad) print(y_true.grad) if __name__ == '__main__': main()
上記を実行してみよう。
入力に -1
以下の値が入ると、最終的な結果が inf
になっている。
そして y_pred
と y_true
の勾配に nan
が確認できる。
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>) tensor([nan, -0., -0.], dtype=torch.float64) tensor([nan, 0., 0.], dtype=torch.float64)
上手くいかない箇所を自動で見つける
では、今回の主題となる torch.autograd.set_detect_anomaly()
を使ってみよう。
この関数には、第一引数に真偽値のフラグを渡して機能の有効・無効を切り替える。
もちろんデフォルトでは機能は無効となっており、デバッグをするときだけ有効にすることが推奨されている。
これは、機能を有効にするとバックプロパゲーションにおいて値のチェックが逐一入ることによるオーバーヘッドが生じるため。
import torch from torch import nn class RMSLELoss(nn.Module): def __init__(self): super().__init__() self.mse_loss = nn.MSELoss() def forward(self, y_pred, y_true): log_y_pred = torch.log1p(y_pred) log_y_true = torch.log1p(y_true) msle = self.mse_loss(log_y_pred, log_y_true) rmsle_loss = torch.sqrt(msle) return rmsle_loss def main(): # 逆伝搬で nan になるケースを自動で見つける torch.autograd.set_detect_anomaly(True) y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True) loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) print(out) out.backward() print(y_pred.grad) print(y_true.grad) if __name__ == '__main__': main()
上記を実行してみよう。
すると、MseLossBackward0
の結果において値に nan が含まれることが示されている。
$ python anodet.py tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>) /Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error: File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module> main() File "/Users/amedama/Documents/temporary/anodet.py", line 27, in main out = loss_fn(y_pred, y_true) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, **kwargs) File "/Users/amedama/Documents/temporary/anodet.py", line 14, in forward msle = self.mse_loss(log_y_pred, log_y_true) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, **kwargs) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward return F.mse_loss(input, target, reduction=self.reduction) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack return traceback.format_stack() (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass Traceback (most recent call last): File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module> main() File "/Users/amedama/Documents/temporary/anodet.py", line 31, in main out.backward() File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward torch.autograd.backward( File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
このように、バックプロパゲーションが上手くいかない場所を自動で検知できた。
問題を修正する (1)
では、問題を修正するため試しに y_pred
の値の下限が 0
となるように torch.clamp()
の処理を挟んでみよう。
import torch from torch import nn class RMSLELoss(nn.Module): def __init__(self): super().__init__() self.mse_loss = nn.MSELoss() def forward(self, y_pred, y_true): # 入力の下限を 0. に制限する clamped_y_pred = torch.clamp(y_pred, min=0.) log_y_pred = torch.log1p(clamped_y_pred) log_y_true = torch.log1p(y_true) msle = self.mse_loss(log_y_pred, log_y_true) rmsle_loss = torch.sqrt(msle) return rmsle_loss def main(): torch.autograd.set_detect_anomaly(True) y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True) loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) print(out) out.backward() if __name__ == '__main__': main()
実行すると、今度は例外にならずに済んでいる。
y_pred
と y_true
の勾配にも nan
は登場しない。
$ python anodet.py tensor(1.1501, dtype=torch.float64, grad_fn=<SqrtBackward0>) tensor([0.0000, -0.4018, -0.1328], dtype=torch.float64) tensor([0.1061, 0.1004, 0.0531], dtype=torch.float64)
入力によって逆伝搬が上手くいかないコード (2)
さて、これで万事解決かと思いきや、実はまだ問題が残っている。
損失がゼロになるときを考えると torch.sqrt()
のバックプロパゲーションにおいてゼロ除算が生じるため。
これは順伝搬では値に nan や inf が登場しないことから問題に気づきにくそう。
import torch from torch import nn class RMSLELoss(nn.Module): def __init__(self): super().__init__() self.mse_loss = nn.MSELoss() def forward(self, y_pred, y_true): clamped_y_pred = torch.clamp(y_pred, min=0.) log_y_pred = torch.log1p(clamped_y_pred) log_y_true = torch.log1p(y_true) # 損失がゼロのときは...? msle = self.mse_loss(log_y_pred, log_y_true) rmsle_loss = torch.sqrt(msle) return rmsle_loss def main(): torch.autograd.set_detect_anomaly(True) # 損失がゼロになるパターンを考えてみると...? y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True) loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) print(out) out.backward() print(y_pred.grad) print(y_true.grad) if __name__ == '__main__': main()
実行すると、今度も MseLossBackward0
において結果に nan が含まれると指摘されている。
$ python anodet.py tensor(0., dtype=torch.float64, grad_fn=<SqrtBackward0>) /Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error: File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module> main() File "/Users/amedama/Documents/temporary/anodet.py", line 28, in main out = loss_fn(y_pred, y_true) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, **kwargs) File "/Users/amedama/Documents/temporary/anodet.py", line 15, in forward msle = self.mse_loss(log_y_pred, log_y_true) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl return forward_call(*input, **kwargs) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward return F.mse_loss(input, target, reduction=self.reduction) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack return traceback.format_stack() (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass Traceback (most recent call last): File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module> main() File "/Users/amedama/Documents/temporary/anodet.py", line 32, in main out.backward() File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward torch.autograd.backward( File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
問題を修正する (2)
では、先ほどの問題を修正するために torch.sqrt()
の計算に小さな値を足してみよう。
import torch from torch import nn class RMSLELoss(nn.Module): def __init__(self, epsilon=1e-5): super().__init__() self.mse_loss = nn.MSELoss() self.epsilon = epsilon def forward(self, y_pred, y_true): clamped_y_pred = torch.clamp(y_pred, min=0.) log_y_pred = torch.log1p(clamped_y_pred) log_y_true = torch.log1p(y_true) msle = self.mse_loss(log_y_pred, log_y_true) # ゼロ除算が生じないように小さな値を足す rmsle_loss = torch.sqrt(msle + self.epsilon) return rmsle_loss def main(): torch.autograd.set_detect_anomaly(True) y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True) loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) print(out) out.backward() print(y_pred.grad) print(y_true.grad) if __name__ == '__main__': main()
実行すると、今度は例外にならない。
$ python anodet.py tensor(0.0032, dtype=torch.float64, grad_fn=<SqrtBackward0>) tensor([0., 0., 0.], dtype=torch.float64) tensor([0., 0., 0.], dtype=torch.float64)
特定のスコープでチェックする
ちなみに、特定のスコープでだけ backward()
の結果をチェックしたいときは torch.autograd.detect_anomaly()
が使える。
これはコンテキストマネージャになっているため、チェックしたい部分にだけ入れて使うことができる。
import torch from torch import nn class RMSLELoss(nn.Module): def __init__(self): super().__init__() self.mse_loss = nn.MSELoss() def forward(self, y_pred, y_true): log_y_pred = torch.log1p(y_pred) log_y_true = torch.log1p(y_true) msle = self.mse_loss(log_y_pred, log_y_true) rmsle_loss = torch.sqrt(msle) return rmsle_loss def main(): y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True) y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True) loss_fn = RMSLELoss() out = loss_fn(y_pred, y_true) print(out) # 特定のスコープでチェックする with torch.autograd.detect_anomaly(): out.backward() if __name__ == '__main__': main()
とはいえ、そんなに出番は無さそうかな。
また、トレースバックに含まれる情報も torch.autograd.set_detect_anomaly()
より少なくなっているようだ。
$ python anodet.py tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>) /Users/amedama/Documents/temporary/anodet.py:29: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging. with torch.autograd.detect_anomaly(): /Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. No forward pass information available. Enable detect anomaly during forward pass for more information. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:97.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass Traceback (most recent call last): File "/Users/amedama/Documents/temporary/anodet.py", line 34, in <module> main() File "/Users/amedama/Documents/temporary/anodet.py", line 30, in main out.backward() File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward torch.autograd.backward( File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
まとめ
今回は PyTorch でバックプロパゲーションが上手くいかない場所を自動で見つけることのできる機能を試してみた。