LoRA レイヤー別学習率の実装、state_dict読み込みの際のdevice指定削除、typo修正 by u-haru · Pull Request #355 · kohya-ss/sd-scripts · GitHub
Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LoRA レイヤー別学習率の実装、state_dict読み込みの際のdevice指定削除、typo修正 #355

Merged
merged 11 commits into from
Apr 2, 2023

Conversation

u-haru
Copy link
Contributor

@u-haru u-haru commented Mar 30, 2023

LoRA レイヤー別学習率の実装

network_argsdown_lr_weight, mid_lr_weight, up_lr_weightの3つのオプションを用意します。指定された値と学習率の積を学習率とすることで、レイヤー毎に異なる学習率を適用します。また、weightが0のレイヤーはapplyしないように変更しています。現在実装してる引数は以下のようになります。
・down_lr_weight=0,0,0,0,0,1,1,1,1,1 (conv2d_3x3使用時)
・down_lr_weight=0,0.5,1 (conv2d_3x3未使用時)
・mid_lr_weight=0.5
・up_lr_weight=linear (0から1へと変化する。他にはsine, cosine, reverse_linear, zeros等)
・stratified_zero_threshold=0.1 (0.1以下のweightは0として扱う)

レイヤー構造は私自身があまり理解できていなかった為、以下のように階層化して適用しています(prefix等一部省略)。conv2d_3x3無効時はupとdownでそれぞれ3層、有効時は10層に分割しています。midは常に1層として扱います。

conv2d_3x3無効時

(1)down_blocks_0_attentions_{0,1}
(2)down_blocks_1_attentions_{0,1}
(3)down_blocks_2_attentions_{0,1}

mid_block_attentions_0
mid_block_resnets_{0,1}

(1)up_blocks_1_attentions_{0,1,2}
(2)up_blocks_2_attentions_{0,1,2}
(3)up_blocks_3_attentions_{0,1,2}

conv2d_3x3有効時

(1)down_blocks_0_resnets_{0,1}
(2)down_blocks_0_downsamplers
(3)down_blocks_0_attentions_{0,1}
(4)down_blocks_1_resnets_{0,1}
(5)down_blocks_1_downsamplers
(6)down_blocks_1_attentions_{0,1}
(7)down_blocks_2_resnets_{0,1}
(8)down_blocks_2_downsamplers
(9)down_blocks_2_attentions_{0,1}
(10)down_blocks_3_resnets_{0,1}

mid_block_attentions_0
mid_block_resnets_{0,1}

(1)up_blocks_0_resnets_{0,1,2}
(2)up_blocks_0_upsamplers
(3)up_blocks_1_attentions_{0,1,2}
(4)up_blocks_1_resnets_{0,1,2}
(5)up_blocks_1_upsamplers
(6)up_blocks_2_attentions_{0,1,2}
(7)up_blocks_2_resnets_{0,1,2}
(8)up_blocks_2_upsamplers
(9)up_blocks_3_attentions_{0,1,2}
(10)up_blocks_3_resnets_{0,1,2}

これによって、学習したい範囲(細かいディテールの部分、全体の色の塗り等)を限定して学習することが出来るようになります。
例えば、down_lr_weight=cosine mid_lr_weight=0 up_lr_weight=sine とすればモデルの浅い層をメインに学習させることができます。
逆に、down_lr_weight=sine mid_lr_weight=1 up_lr_weight=cosine とすればモデルの深い層をメインに学習させられます。

以下に私が試した結果を示します。制服を着せたキャラクターをCM3D2で用意して、"test"というトークンに紐づくように800step程度学習させたLoRA(dim: 64, conv_dim:1, lr: 1e-4, text_encoder_lr: 5e-5)です。
xyz_grid-0018-1204636062
(down_lr_weight:mid_lr_weight:up_lr_weight)という表記で書くと左から順に
デフォルト(1:1:1)、sine:1:cosine、1:0:1、cosine:0:sine、LoRA無しです。

誤差レベルかもしれませんが、デフォルトでは3Dゲームっぽい画像が生成されるようになる(ゲームのスクショを使ってるため)のに対し、深い層のlrを下げた方はデフォルトに比べて色の塗り方が維持されているように思えます。

state_dict読み込みの際のdevice指定削除

更新した際にsafetensors.torch.load_fileException: device cuda is invalidというエラーを返すようになっていたのでその修正です。

他の環境では試せていませんが、私の環境で学習ができなくなっていたことと「state_dict読み込み時はdeviceを指定する必要が無いだろう」という判断からstate_dict読み込みの際のdeviceの指定を外しています。モデルにstate_dictを読み込む際の.to(device)等はそのままなので、特に問題ないかと思われます(私の環境では問題なく動作しています)。

typo修正

繰り返し回数のないディレクトリを無視しますの表示がバグっていたので修正です。

@kohya-ss
Copy link
Owner

ありがとうございます。時間でき次第確認します。

またstate_dictを読み込む際のエラーについてもご報告ありがとうございます。私のところでは起きていないのですが、影響範囲が大きそうですので、先にmainで修正させていただくかもしれません。

@u-haru
Copy link
Contributor Author

u-haru commented Mar 31, 2023

お忙しい中ありがとうございます。

影響範囲を考えるとPullRequest分けたほうが良かったですね…
以降気をつけます。

Copy link
Owner

@kohya-ss kohya-ss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

内容について概要を確認しました。大変有意義で素晴らしいと思います。

影響範囲を考えるとPullRequest分けたほうが良かったですね…

いえ、たまたま影響が大きかったので、こちらこそ申し訳ありません。

全体を通してですが、層別学習率を適用しない場合に、 optimizer への trainable_params の指定が、以前と同じようになるようにしていただけませんでしょうか。というのも D-Adaptation optimizer を指定した場合に、lr 指定が有効になるのが最初の param group だけだからです。

またもしなにか予期せぬ不具合があった場合に、影響範囲を最小限に抑えたいという思いもあります。

また層の指定ですが、すでに広く使われている LoRA Block Weight に合わせるとわかりやすいかと思いますが、いかがでしょうか。
https://github.com/hako-mikan/sd-webui-lora-block-weight

なお、このまま適用すると生成用スクリプトが動作しなくなりますが(保存される重みが学習率が適用されたLoRAモジュールだけになるものの、そのような重みファイルを想定していないため)、これはマージ後に私の方で直します。

お忙しいようでしたら私の方でマージ後に修正しますが、適用まで少しお時間をいただくかもしれません。

よろしくお願いいたします。

Comment on lines +396 to +399
if len(skipped)>0:
print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
for name in skipped:
print(f"\t{name}")
Copy link
Owner

@kohya-ss kohya-ss Apr 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ここで text_encoder_loras および unet_loras から該当の LoRA を削除しておいても良さそうですね。

いろいろ考慮が必要そうです。

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 1, 2023

よく考えてみると、重みファイルは現在すべてのLoRAモジュールのものが含まれることが前提のため、他のスクリプトへの影響を考えると、学習率が0のLoRAモジュールについてもapplyしておき(state_dictに含まれるようになる)、重みはup/downとも0で初期化(マージ時を考慮して)、optimizerのパラメータに追加しない、という取り扱いの方が良いかもしれない、という気がしてきました。

@u-haru
Copy link
Contributor Author

u-haru commented Apr 1, 2023

ありがとうございます。層の指定については、リンク先の方を参考にして12層に分けるようにしました。
IN: 0 ~ 11
MID: 0
OUT: 0 ~ 11
(BASEはテキストエンコーダなので不要?)

  • 指定方法に関しては以前のnetwork_argsの設定をそのまま引き継いでいます。
  • レイヤー数の条件分岐が大変になるので、どのような学習設定でもフルモデル相当であると仮定して12層で処理するようにしています。
  • LoRAではIN00が使用されていない(?)ようなのでこの層の設定は効果が無いですが、わかりやすさのためにそのままにしてます。

次にD-Adaptation optimizer使用時の動作に関してですが、レイヤー別学習率を使用しない際にはparam_dataにlrを付与しないように変更しました。恐らくこれで正常な動作になっているはずです。

最後にLoRAモジュールの扱いに関してです。
削除した際の利点として学習が早くなるという点があり、可能であれば不要なパラメータを削除したほうがいいのではないかと考えています。なるべくパラメータを残したまま不要な学習をしないようにしようと色々試してみたのですが、

  • LoRANetWorkのself.add_module()が実行された時点で(apply_to()をしていなくても)学習が遅くなる
  • self.add_module()せずにapply_to()だけ実行すると学習時にエラーが発生する
  • optimizerのtrainable_paramsに渡さなくても変化しない
  • 使わないモジュールでrequires_grad_(False)しても変化しない
    といった感じでした。

学習スピードに関しては、私の環境では深い層を除いた際に1.6it/s -> 1.9it/s程度まで上昇しました。
効果が結構大きいので、とりあえず現時点ではパラメータを削除する実装のままになっています。

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 2, 2023

更新ありがとうございます。かなり良さそうですね。

レイヤーの指定については、お書きいただいた通り12層+1層+12層で良さそうです。LoRAのconv2d_3x3適用有無に関わらず共通なのも良さそうです。

次にD-Adaptation optimizer使用時の動作に関してですが、レイヤー別学習率を使用しない際にはparam_dataにlrを付与しないように変更しました。恐らくこれで正常な動作になっているはずです。

D-Adaptationについては良さそうですね。ただ(あとから気づいたので申し訳ないのですが)ログへの学習率の出力が正しく動作しなくなってしまうようです。

層別学習率を指定したときにはそれぞれの学習率を表示し、指定しないときには今まで通りの学習率を表示したいところですが(学習率のグラフが大量に出てくるのはさすがに避けたいので)、なかなか悩ましいですね……。ちょっと私の方でも検討してみますが、なにかアイデアありましたらコメントまたは更新いただけると幸いです。

LoRAモジュールの扱いについて詳細に検討していただきありがとうございます。なるほど、学習時間に影響するのは大きいですね……。となると学習しないパラメータは削除する方向が良さそうですね。幸い、Web UIのbuilt-inのLoRA機能は、一部のLoRAモジュールが欠けていても動作するようですし。

@u-haru
Copy link
Contributor Author

u-haru commented Apr 2, 2023

あまり複雑なことをしたくなかったので、とりあえず条件分岐で分けました。
コードが冗長になってしまった気もしますが、層別学習率を使わない場合の動作は以前と同じになったと思います。

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 2, 2023

ありがとうございます。結局のところ、それが一番妥当な気がしますね……。確認後マージいたします。

@kohya-ss kohya-ss merged commit 36c8a4a into kohya-ss:dev Apr 2, 2023
@u-haru
Copy link
Contributor Author

u-haru commented Apr 2, 2023

マージ確認しました。お忙しい中対応していただき、ありがとうございました。

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 2, 2023

こちらこそ、なかなか自分では手が付けられないので、ありがとうございました。stratifyがちょっとわかりにくく思いましたので、blockに変えさせていただきました(若干意味が変わってしまいますがわかりやすさを優先して……)。

テストおよびログ出力のところで苦戦しているのでmainへのマージは明日になりそうです。お時間をいただき申し訳ありませんが、よろしくお願いいたします。

@bmaltais bmaltais mentioned this pull request Apr 7, 2023
@u-haru u-haru deleted the feature/stratified_lr branch July 8, 2023 14:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants