Megatron-LMとGKEで作るMixtral 8x7Bを語彙拡張継続事前学習 Part1 ~学習コードとモデルの先行公開~ - ABEJA Tech Blog

ABEJA Tech Blog

中の人の興味のある情報を発信していきます

Megatron-LMとGKEで作るMixtral 8x7Bを語彙拡張継続事前学習 Part1 ~学習コードとモデルの先行公開~

こんにちは!ABEJAでデータサイエンティストをしている大谷です。

ABEJAは国立研究開発法人新エネルギー・産業技術総合開発機構(以下「NEDO」)が公募した「ポスト5G情報通信システム基盤強化研究開発事業/ポスト5G情報通信システムの開発」に当社提案の「LLMの社会実装に向けた特化型モデルの元となる汎化的LLM」が採択されたことを受け、LLMの事前学習を実施しました。 以降、本LLMプロジェクトをGENIAC(Generative AI Accelerator Challenge)と表記します。

開発内容は表題の通り、Mistral社のMIxtral 8x7Bをベースにした日本語の語彙拡張版継続事前学習です。弊社が調べた限り、Megatron-LMでMixtralモデルを継続事前学習するソースコードは2024年4月12日時点(執筆時)では存在していません。 GENIACの計算資源提供の支援を最初に受けた2月開始の事業者として、後続の事業者への貢献(後続の事業者から直接ご要望をいただいた)と、LLMに関わる皆さまのお役に少しでも立てればと思い、事後学習前ではありますが、先行してソースコードと取り組みの一部を公開することにしました。実はGENIACコミュニティの皆様には本ブログからさらに先行してソースコードは公開しております。

また、2024年4月10日に突如としてMixtral 8x22Bが公開されました。今回公開する学習コードは元モデルがHuggingfaceのMixtralForCausalLMである限り流用可能ですので、GPUリソースをお持ちの方にはぜひご活用いただければと思います!

ソースコードと学習途中のモデル重み

コードとモデルだけ欲しいよという方のために最初にリンクを貼ります。

(WIP)学習用コード

  1. mistralai/Mixtral-8x7B-v0.1から継続学習した重み(約250B tokenを学習して停止したもの)

  2. mistralai/Mixtral-8x7B-Instruct-v0.1から継続学習したモデル(学習途中ですが、ベンチマーク評価を実施したもの)

  3. 学習途中の2から差分マージしたモデル

  4. mistralai/Mixtral-8x7B-Instruct-v0.1から継続学習したモデル

公開コードとモデルの備考

  • ソースコードは先行公開扱いのため、あまり綺麗にできていないことをご了承ください。
  • Mixtralモデルの実装に関して、Rolling Buffer Cache 未実装、Sliding Window Attentionは適用したものの学習速度が低下する事象に遭遇したためオフとなっています。
  • GKEのセットアップに関しては別記事にて公開します。
  • 1に関しては、学習途中のチェックポイントを調査していったところMixtral-8x7B-Instruct-v0.1から継続事前学習した方が精度が良いと判断したため学習を途中停止させました。
  • 2はまだ学習途中のモデルになりますが、ベンチマークを実施したタイミングのモデルになります。
  • 3に関して、まだpost trainingを実施できていない中、インストラクト性能を確保するために実施しています。差分マージをしたモデルはmistralai/Mixtral-8x7B-v0.1mistralai/Mixtral-8x7B-Instruct-v0.1です。マージ詳細はこちらで確認できます。

概要

ABEJAはGENIACプロジェクトを通じて、日本語LLMおよびRAG、Agentといった周辺技術の研究開発を行っています。また、LLMの利活用の推進や社会全体におけるAI技術革新の加速、次世代の研究者や技術者の育成に貢献できるよう、開発したLLMおよびソースコードや開発ノウハウなどを適宜公開していく予定です。参考プレスリリース

本ブログを皮切りに、学習用ソースコードや学習済みのモデルだけでなく、前処理済みの学習用データや、前処理コード、インフラ構築のノウハウなど、可能な限りオープンにしていきますので、ぜひお楽しみにしていてください!

学習データ

学習データは以下のような構成となっており、前処理実施後の語彙拡張tokenizerで合計450B tokenの学習データを使用しています。

言語 データソース トークン数 備考
日本語 Common Crawl 約260B 2019~2023のtimestampを利用 。語彙拡張前のtokenizerでは約400B
日本語 独自収集データ 約30B -
英語 Redpajama 約126B -
英語 (Redpajama内)ソースコード 34B -

Model

Mixtral 8x7Bを選んだのはいくつかの理由があります。まず一つ目に、LLMの社会実装を目指す以上、推論時の速度を少しでも高速化させたかったのでMixture of Expertモデルを使いたかったこと。二つ目に、OSSとしての元のモデル性能が高かったこと。三つ目にGENIACを始めた時はまだ誰もMixtral 8x7Bの日本語継続事前学習を行なっておらずこれが実現できたらかっこいいと思ったからです。(2024年3月11日にSwallow MXが発表され、悔しかった非常に参考にさせていただきました。)

Megatron-LMでMixtral 8x7Bを実装する上でポイントとなった点は以下の4点でした。

  1. Huggingfaceモデル→Megatronの重みの変換(継続事前学習のため)

  2. Megatronの重み→Huggingfaceへの変換(評価のため)

  3. Mixtralモデルをそのものの実装(こちらのPull Requestを参考にしています)

  4. load balancing lossの実装。パイプライン間でデータ通信に載せる必要がありましたが、工夫で最小限の変更に抑えています。

モデルの詳細やMegatron-LM、実は実装していたMegatron-Deepspeedへの実装の工夫点・つらみ、学習ロスなどの詳細は別のブログにて公開する予定です。

モデルマージ

計算資源の支援期間はInstructチューニングなどのpost trainingではなく、基盤モデルの事前学習にリソースを割きたかったため、モデルマージによってインストラクトの性能を引き上げる取り組みを実施しました。 語彙追加しているため、vocab数が元のMixtralモデルと異なっていますが、元のvocabサイズ(32k)だけマージしても性能が出ることが判明したのでその手法を使っています。

pretrained_model_dict = {k: v for k, v in pretrained_model.named_parameters()}
mixtral_model_dict = {k: v for k, v in mixtral_model.named_parameters()}
mixtral_inst_model_dict = {k: v for k, v in mixtral_inst_model.named_parameters()}

with torch.no_grad():
    for k, src in pretrained_model_dict.items():
        print(k)
        diff = mixtral_inst_model_dict[k] - mixtral_model_dict[k] 
        shared_shape = tuple(min(a, b) for a, b in zip(src.shape, diff.shape))
        slices = tuple(slice(0, dim) for dim in shared_shape)
        src[slices] += diff[slices]

print("Merged model saving...")
pretrained_model.save_pretrained("merged_model") 

ベンチマーク結果

GENIACではWeights & Biases社のご支援のもと、llm-leaderboard (g-leaderboardブランチ)でベンチマークを実施しています。総合の平均点(AVG)は以下の式を利用し、1が満点になるように計算しています。
Overall average = ((llm-jp-eval_ja_0shot + llm-jp-eval_ja_4shot) / 2 + (MMLU_0shot + MMLU_4shot) / 2 + MT-bench_ja / 10 + MT-bench_en / 10) / 4
また、参考までに日本語のみの平均点(AVG ja)を弊社で計算し記載しております。計算式は以下の通りです。
Japanese average = ((llm-jp-eval_ja_0shot + llm-jp-eval_ja_4shot) / 2 + MT-bench_ja / 10) / 2

モデル AVG AVG ja llm-jp-eval_ja_0shot llm-jp-eval_ja_4shot MMLU_0shot MMLU_4shot MT-bench_ja MT-bench_en 備考
mistralai/Mixtral-8x7B-Instruct-v0.1 0.5674 0.6276 0.5091 0.6063 0 0.4246 6.975 8.022 Weights & Biases社での評価
stabilityai/StableBeluga2 0.6048 0.6355 0.572 0.6711 0.3544 0.5018 6.494 7.203 Weights & Biases社での評価
tokyotech-llm/Swallow-MX-8x7b-NVE-v0.1 0.5795 0.5464 0.518 0.6386 0.4421 0.6526 5.144 6.781 Weights & Biases社での評価
tokyotech-llm/Swallow-70b-instruct-hf 0.5318 0.5318 0.5755 0.6755 0.5158 0.6386 4.381 4.866 Weights & Biases社での評価
1. abeja/Mixtral-8x7B-v0.1 語彙拡張継続学習 0.3371 0.3373 0.2144 0.6735 0.2596 0.5895 2.306 2.494 自社での評価(同ライブラリ)
2. abeja/Mixtral-8x7B-Instruct-v0.1 語彙拡張継続学習(学習途中モデル) 0.4101 0.4731 0.4672 0.6850 0.0211 0.5789 3.700 3.944
3. abeja/Mixtral-8x7B-Instruct-v0.1 語彙拡張継続学習差分マージ(学習途中モデル) 0.5300 0.6137 0.5318 0.7080 0 0.5860 6.075 5.997

ベンチマークが全てではないと思いますが、オープンな日本語モデルの中では高い性能が出ていることを確認しています。引き続きpost trainingで性能向上活動を実施していきます!

さいごに

大規模な事前学習を実施する機会を与えてくださったすべての関係者の皆様に感謝申し上げます。開発をするにあたって、Mixtralモデルの理解はもちろんのこと、データの前処理、前処理の分散処理、並列学習の知見、Google Cloudのインフラ知見、評価観点など、非常に多くの学びがありました。この学びは可能な限り公開し、日本の生成AIコミュニティを盛り上げていければと思います!

この成果は、NEDOの助成事業(JPNP20017)の結果得られたものです。