Python: LIME (Local Interpretable Model Explanations) を LightGBM と使ってみる - CUBE SUGAR CONTAINER

CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: LIME (Local Interpretable Model Explanations) を LightGBM と使ってみる

今回は、機械学習モデルの解釈可能性を向上させる手法のひとつである LIME (Local Interpretable Model Explanations) を LightGBM と共に使ってみる。 LIME は、大局的には非線形なモデルを、局所的に線形なモデルを使って近似することで、予測の解釈を試みる手法となっている。

今回使った環境は次のとおり。

$ sw_vers                            
ProductName:    Mac OS X
ProductVersion: 10.15.7
BuildVersion:   19H2
$ python -V                                      
Python 3.8.5

もくじ

下準備

まずは、下準備として使うパッケージをインストールしておく。

$ pip install lime scikit-learn lightgbm jupyterlab

LIME は Jupyter の WebUI に可視化する API を提供している。 そのため、今回は Jupyter Lab 上でインタラクティブに試していくことにしよう。

$ jupyter lab

Boston データセットを LightGBM で学習させる

とりあえず、LIME を使うにしても学習済みモデルがないと話が始まらない。 そこで、まずは scikit-learn から Boston データセットを読み込む。

>>> from sklearn import datasets
>>> dataset = datasets.load_boston()
>>> train_x, train_y = dataset.data, dataset.target
>>> feature_names = dataset.feature_names

続いて LightGBM の Early Stopping で検証用データにするためにデータセットを分割しておく。

>>> from sklearn.model_selection import train_test_split
>>> tr_x, val_x, tr_y, val_y = train_test_split(train_x, train_y,
...                                             shuffle=True,
...                                             random_state=42,
...                                            )

それぞれを LightGBM のデータ表現にする。

>>> import lightgbm as lgb
>>> lgb_train = lgb.Dataset(tr_x, tr_y)
>>> lgb_val = lgb.Dataset(val_x, val_y, reference=lgb_train)

上記のデータセットを LightGBM で回帰タスクとして学習させる。

>>> lgbm_params = {
...     'objective': 'regression',
...     'metric': 'rmse',
...     'verbose': -1,
... }
>>> booster = lgb.train(lgbm_params,
...                     lgb_train,
...                     valid_sets=lgb_val,
...                     num_boost_round=1_000,
...                     early_stopping_rounds=100,
...                     verbose_eval=50,
...                     )
Training until validation scores don't improve for 100 rounds
[50]  valid_0's rmse: 3.41701
[100]  valid_0's rmse: 3.30722
[150] valid_0's rmse: 3.24115
[200]  valid_0's rmse: 3.22073
[250] valid_0's rmse: 3.2121
[300]  valid_0's rmse: 3.21216
[350] valid_0's rmse: 3.21811
Early stopping, best iteration is:
[271]  valid_0's rmse: 3.20437

これで LightGBM の学習済みモデルが手に入った。

LIME を使って局所的な解釈を得る

続いては学習モデルの予測を LIME で解釈してみよう。

今回使ったのは構造化されたテーブルデータなので LimeTabularExplainer を用いる。 このクラスに学習データやタスクの内容といった情報を渡してインスタンスを作る。 なお、LIME 自体はテーブルデータ以外にも自然言語処理や画像認識など幅広い応用が効くらしい。

>>> from lime.lime_tabular import LimeTabularExplainer
>>> explainer = LimeTabularExplainer(training_data=tr_x,
...                                  feature_names=feature_names,
...                                  training_labels=tr_y,
...                                  mode='regression',
...                                  verbose=True,
...                                  )

上記を使うには、ひとつの特徴ベクトルを受け取って予測値を返す関数が必要になる。 そこで、今回は次のように定義しておく。

>>> predict_func = lambda x: booster.predict(x,
...                                          num_iteration=booster.best_iteration)

上記を使って、試しに学習データの先頭要素に対する予測を解釈してみよう。 ここでは、線形モデルを使って近似させたときの切片と予測値、そして本来のモデルの予測値が出力される。

>>> explanation = explainer.explain_instance(tr_x[0], predict_func)
Intercept 20.92723014344181
Prediction_local [37.81548009]
Right: 37.949401123930926

次のようにすると、結果が Notebook 上で視覚的に確認できる。

>>> explanation.show_in_notebook(show_table=False)

結果からは、予測されるレンジの中で要素がどこにあるか、そして各特徴量が線形モデルでどのように作用しているかがわかる。 たとえば、近似した線形モデルでは特徴量の RM6.66 以上なので、予測値を 8.71 押し上げる効果があるようだ。

f:id:momijiame:20201009235425p:plain
LIME による予測の解釈

ちなみに、学習データの先頭要素の中身はこんな感じ。

>>> import prettyprint
>>> pprint(dict(zip(feature_names, tr_x[0])))
{'CRIM': 0.09103,
 'ZN': 0.0,
 'INDUS': 2.46,
 'CHAS': 0.0,
 'NOX': 0.488,
 'RM': 7.155,
 'AGE': 92.2,
 'DIS': 2.7006,
 'RAD': 3.0,
 'TAX': 193.0,
 'PTRATIO': 17.8,
 'B': 394.12,
 'LSTAT': 4.82}

条件と、予測値に対してどのように作用しているかは as_list() メソッドで得られる。

>>> pprint(explanation.as_list())
[('RM > 6.66', 8.711363613654798),
 ('LSTAT <= 6.87', 7.923627881873513),
 ('TAX <= 279.00', 1.018446806095536),
 ('78.10 < AGE <= 93.85', -0.5986522163835298),
 ('0.45 < NOX <= 0.54', 0.5314517960646176),
 ('2.08 < DIS <= 3.11', 0.4316641373927205),
 ('16.60 < PTRATIO <= 18.60', 0.29967804190197933),
 ('INDUS <= 5.13', 0.21697123227933096),
 ('0.08 < CRIM <= 0.27', -0.17672992076030344),
 ('RAD <= 4.00', -0.16478059886664842)]

たとえば切片と上記の要素をすべて足すと、最初に得られた線形モデルで近似した予測値になる。

>>> sum(value for key, value in explanation.as_list()) + explanation.intercept[0]
38.33991784140311

上記はあくまで線形モデルを使った近似なので、本来のモデルのアルゴリズムにもとづいて説明しているわけではないはず。 とはいえ、局所的に特徴量がどのように作用しているか確認できるのはなかなか面白い。

所感としては、大局的な解釈にも応用が効く SHAP の方が使い勝手は良さそうかな。

blog.amedama.jp

参考

github.com

arxiv.org