Python: PyTorch でバックプロパゲーションが上手くいかない場所を自動で見つける - CUBE SUGAR CONTAINER

CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: PyTorch でバックプロパゲーションが上手くいかない場所を自動で見つける

PyTorch を使っていると、はるか遠く離れた場所で計算した結果に nan や inf が含まれることで、思いもよらない場所から非直感的なエラーを生じることがある。 あるいは、自動微分したときにゼロ除算が生じるようなパターンでは、順伝搬の結果だけ見ていても原因にたどり着くことが難しい。 こういった問題は、デバッガなどを使って地道に原因を探ろうとすると多くの手間と時間がかかる。

そんな折、PyTorch にはそうした問題に対処する上で有益な機能があることを知った。 具体的には、以下の関数を使うと自動でバックプロパゲーションが上手くいかない箇所を見つけることができる。 今回は、この機能について書いてみる。

  • torch.autograd.set_detect_anomaly()
  • torch.autograd.detect_anomaly()

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 12.6.2
BuildVersion:   21G320
$ python -V
Python 3.10.9
$ pip list | grep -i torch
torch                        1.13.1

もくじ

下準備

あらかじめ PyTorch と NumPy をインストールしておく。

$ pip install torch numpy

入力によって逆伝搬が上手くいかないコード (1)

例として RMSLE (Root Mean Squared Logarithmic Error) を計算する場合について考える。 以下のサンプルコードでは、RMSLE の計算を RMSLELoss というクラスで実装している。 このコードは入力によってはバックプロパゲーションが上手くいかない。 具体的には、入力されるモデルの予測 (y_pred) に -1 以下の値が含まれたとき torch.log1p() の返り値に inf が含まれる。

import torch
from torch import nn


class RMSLELoss(nn.Module):
    """Root Mean Squared Logarithmic Error"""

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_pred, y_true):
        # ここで y_pred に -1 以下の値が含まれると順伝搬の返り値が -inf になる
        log_y_pred = torch.log1p(y_pred)
        log_y_true = torch.log1p(y_true)
        # 入力に -inf が含まれることで返り値が inf になってしまう
        msle = self.mse_loss(log_y_pred, log_y_true)
        rmsle_loss = torch.sqrt(msle)
        return rmsle_loss


def main():
    # モデルの出力に -1 以下の値が含まれるとする
    y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)

    # 順伝搬
    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    # 結果
    print(out)

    # 逆伝搬
    out.backward()

    # 勾配
    print(y_pred.grad)
    print(y_true.grad)


if __name__ == '__main__':
    main()

上記を実行してみよう。 入力に -1 以下の値が入ると、最終的な結果が inf になっている。 そして y_predy_true の勾配に nan が確認できる。

tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([nan, -0., -0.], dtype=torch.float64)
tensor([nan, 0., 0.], dtype=torch.float64)

上手くいかない箇所を自動で見つける

では、今回の主題となる torch.autograd.set_detect_anomaly() を使ってみよう。 この関数には、第一引数に真偽値のフラグを渡して機能の有効・無効を切り替える。 もちろんデフォルトでは機能は無効となっており、デバッグをするときだけ有効にすることが推奨されている。 これは、機能を有効にするとバックプロパゲーションにおいて値のチェックが逐一入ることによるオーバーヘッドが生じるため。

import torch
from torch import nn


class RMSLELoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_pred, y_true):
        log_y_pred = torch.log1p(y_pred)
        log_y_true = torch.log1p(y_true)
        msle = self.mse_loss(log_y_pred, log_y_true)
        rmsle_loss = torch.sqrt(msle)
        return rmsle_loss


def main():
    # 逆伝搬で nan になるケースを自動で見つける
    torch.autograd.set_detect_anomaly(True)

    y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)

    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    print(out)

    out.backward()

    print(y_pred.grad)
    print(y_true.grad)


if __name__ == '__main__':
    main()

上記を実行してみよう。 すると、MseLossBackward0 の結果において値に nan が含まれることが示されている。

$ python anodet.py       
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error:
  File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module>
    main()
  File "/Users/amedama/Documents/temporary/anodet.py", line 27, in main
    out = loss_fn(y_pred, y_true)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/amedama/Documents/temporary/anodet.py", line 14, in forward
    msle = self.mse_loss(log_y_pred, log_y_true)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss
    return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/amedama/Documents/temporary/anodet.py", line 35, in <module>
    main()
  File "/Users/amedama/Documents/temporary/anodet.py", line 31, in main
    out.backward()
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

このように、バックプロパゲーションが上手くいかない場所を自動で検知できた。

問題を修正する (1)

では、問題を修正するため試しに y_pred の値の下限が 0 となるように torch.clamp() の処理を挟んでみよう。

import torch
from torch import nn


class RMSLELoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_pred, y_true):
        # 入力の下限を 0. に制限する
        clamped_y_pred = torch.clamp(y_pred, min=0.)
        log_y_pred = torch.log1p(clamped_y_pred)
        log_y_true = torch.log1p(y_true)
        msle = self.mse_loss(log_y_pred, log_y_true)
        rmsle_loss = torch.sqrt(msle)
        return rmsle_loss


def main():
    torch.autograd.set_detect_anomaly(True)

    y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)

    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    print(out)

    out.backward()


if __name__ == '__main__':
    main()

実行すると、今度は例外にならずに済んでいる。 y_predy_true の勾配にも nan は登場しない。

$ python anodet.py 
tensor(1.1501, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([0.0000, -0.4018, -0.1328], dtype=torch.float64)
tensor([0.1061, 0.1004, 0.0531], dtype=torch.float64)

入力によって逆伝搬が上手くいかないコード (2)

さて、これで万事解決かと思いきや、実はまだ問題が残っている。 損失がゼロになるときを考えると torch.sqrt() のバックプロパゲーションにおいてゼロ除算が生じるため。 これは順伝搬では値に nan や inf が登場しないことから問題に気づきにくそう。

import torch
from torch import nn


class RMSLELoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_pred, y_true):
        clamped_y_pred = torch.clamp(y_pred, min=0.)
        log_y_pred = torch.log1p(clamped_y_pred)
        log_y_true = torch.log1p(y_true)
        # 損失がゼロのときは...?
        msle = self.mse_loss(log_y_pred, log_y_true)
        rmsle_loss = torch.sqrt(msle)
        return rmsle_loss


def main():
    torch.autograd.set_detect_anomaly(True)

    # 損失がゼロになるパターンを考えてみると...?
    y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)

    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    print(out)

    out.backward()

    print(y_pred.grad)
    print(y_true.grad)


if __name__ == '__main__':
    main()

実行すると、今度も MseLossBackward0 において結果に nan が含まれると指摘されている。

$ python anodet.py 
tensor(0., dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error:
  File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module>
    main()
  File "/Users/amedama/Documents/temporary/anodet.py", line 28, in main
    out = loss_fn(y_pred, y_true)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/amedama/Documents/temporary/anodet.py", line 15, in forward
    msle = self.mse_loss(log_y_pred, log_y_true)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 536, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 3292, in mse_loss
    return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/amedama/Documents/temporary/anodet.py", line 36, in <module>
    main()
  File "/Users/amedama/Documents/temporary/anodet.py", line 32, in main
    out.backward()
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

問題を修正する (2)

では、先ほどの問題を修正するために torch.sqrt() の計算に小さな値を足してみよう。

import torch
from torch import nn


class RMSLELoss(nn.Module):

    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        clamped_y_pred = torch.clamp(y_pred, min=0.)
        log_y_pred = torch.log1p(clamped_y_pred)
        log_y_true = torch.log1p(y_true)
        msle = self.mse_loss(log_y_pred, log_y_true)
        # ゼロ除算が生じないように小さな値を足す
        rmsle_loss = torch.sqrt(msle + self.epsilon)
        return rmsle_loss


def main():
    torch.autograd.set_detect_anomaly(True)

    y_pred = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([1., 2., 3.], dtype=torch.float64, requires_grad=True)

    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    print(out)

    out.backward()

    print(y_pred.grad)
    print(y_true.grad)


if __name__ == '__main__':
    main()

実行すると、今度は例外にならない。

$ python anodet.py
tensor(0.0032, dtype=torch.float64, grad_fn=<SqrtBackward0>)
tensor([0., 0., 0.], dtype=torch.float64)
tensor([0., 0., 0.], dtype=torch.float64)

特定のスコープでチェックする

ちなみに、特定のスコープでだけ backward() の結果をチェックしたいときは torch.autograd.detect_anomaly() が使える。 これはコンテキストマネージャになっているため、チェックしたい部分にだけ入れて使うことができる。

import torch
from torch import nn


class RMSLELoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, y_pred, y_true):
        log_y_pred = torch.log1p(y_pred)
        log_y_true = torch.log1p(y_true)
        msle = self.mse_loss(log_y_pred, log_y_true)
        rmsle_loss = torch.sqrt(msle)
        return rmsle_loss


def main():
    y_pred = torch.tensor([-1., 0., 1.], dtype=torch.float64, requires_grad=True)
    y_true = torch.tensor([2., 3., 4.], dtype=torch.float64, requires_grad=True)

    loss_fn = RMSLELoss()
    out = loss_fn(y_pred, y_true)

    print(out)

    # 特定のスコープでチェックする
    with torch.autograd.detect_anomaly():
        out.backward()


if __name__ == '__main__':
    main()

とはいえ、そんなに出番は無さそうかな。 また、トレースバックに含まれる情報も torch.autograd.set_detect_anomaly() より少なくなっているようだ。

$ python anodet.py 
tensor(inf, dtype=torch.float64, grad_fn=<SqrtBackward0>)
/Users/amedama/Documents/temporary/anodet.py:29: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in MseLossBackward0. No forward pass information available. Enable detect anomaly during forward pass for more information. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:97.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/amedama/Documents/temporary/anodet.py", line 34, in <module>
    main()
  File "/Users/amedama/Documents/temporary/anodet.py", line 30, in main
    out.backward()
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

まとめ

今回は PyTorch でバックプロパゲーションが上手くいかない場所を自動で見つけることのできる機能を試してみた。

参考

pytorch.org

pytorch.org