LayoutLMの特徴と事前学習タスクについて - LayerX エンジニアブログ

LayerX エンジニアブログ

LayerX の エンジニアブログです。

LayoutLMの特徴と事前学習タスクについて

LayerXで機械学習エンジニアを担当している @yoppiblog です。今回はOCRチームで検証したLayoutLMについて簡単に紹介します。

LayoutLMとは

LayoutLMとは昨今注目されているマルチモーダルなDocument Understanding領域の1実装です。 様々な文書(LayerXだとバクラクではお客様の多種多様な請求書といった帳票を扱っています)から情報を抽出(支払金額、支払期日や取引先名など)するために考案されたものになります。

BERT(LayoutLMv3はRoBERTa)ベースのencoder層を用いレイアウト情報や、文書そのものを画像特徴量としてembedding層で扱っているところが既存のモデルより、より文書解析に特化している点です。 v1〜v3まで提唱されており、v3が一番精度が高いモデルです。

もともと、LayoutLMv2では多言語対応されたモデル(LayoutXLM)の事前学習済みモデルが公開されていたため、それを元に検証しはじめたのですが、事前学習済みモデル自体は商用利用不可なため、自分たちで事前学習モデルから作り検証に至ったというお話です。

LayoutLMv1

(huggingface実装でもv1とはついておらず、無印LayoutLMなのですが区別しやすくここではv1と表現しています)

LayoutLMv1は文書のテキスト情報とその情報の位置関係(BoundingBox)をembeddingとして扱い、BERTベースのTransformer(encoder層のみ)で表現しています。 シンプルな実装で、input layer(embedding層)とheadが異なるだけで、encoder層はBERTそのものです。

LaoutLMv1アーキテクチャ

LayoutLMv2

LayoutLMv2はv1の特徴量に加え、文書自身を画像として扱いimage embeddingとして加えていることです。 Visual EmbeddingはCNNベースで表現されており、文書画像を224x224にリサイズし、ResNextを通して特徴マップを作成し、線形射影することでtext embeddingと次元を統一させてconcatして扱っています。

LayoutLMv2アーキテクチャ

LayoutLMv3

LayoutLMv3は、v2では画像をCNNベースでimage embeddingを作成していましたが、ViTにして画像特徴量を学習している部分が大きな変更点です。 結果、v2で複雑になってしまったアーキテクチャも見直されて全体的にすっきりもしました。 テキストのembeddingもBERTから、RoBERTaに移行されました。

LayoutLMv3アーキテクチャ

事前学習モデルの実装詳細

事前学習モデルはよく知られているように、大量のデータセットで事前学習タスクを解くことで、特定のタスク(今回だと請求書に含まれている金額などを抽出する)に共通な汎用的なモデルにすることです。 LayoutLMの研究で扱っている帳票などは、商用利用不可なものも含まれているため、バクラクで扱っている請求書を用いて事前学習を行いました。

また、今回検証するにあたり、LayoutLMv1とLayoutLMv3の事前学習タスクも実装しました(v2より簡易なものとより精度が高いとされているもので精度比較したいため) 転移学習(fine-tuning)の実装はhuggingfaceにされているのですが、事前学習タスクは実装されていない状態だったため、自分たちで実装する必要があったためです。

LayotuLMv1の事前学習タスク

LayoutLMv1では2つ提唱されていますが、1つはオプショナルとされていて、実質1つの事前学習タスクを解くだけとなっています。

  • MVLM(Masked Visual-Language Model)
    • 基本的にはBERTの事前学習タスクのmasked modelと同等だが、layout情報(BoundingBox)をembeddingに加えているところが特徴的な点

LayoutLMv3の事前学習タスク

  • MLM(Masked Language Model)

    • 基本的にはmasked modelだが、マスクするアルゴリズムとして、span maskingという戦略を採用してマスクして解くタスクと定義されている
    • SpanBERTが主な実装例
  • MIM(Masked Image Modeling)

    • v3からは画像をViTで扱うようになり、そこを踏まえてViTで一般的な事前学習タスクであるMIMが採用されている
    • 画像を224x224にresizeしたあと14x14のパッチに分割しそのパッチをランダムにマスクしたのち、そこからImage Patchを生成し(BERTのembeddingと同等)
    • また元画像をVisual Token(離散的なトークン。テキストにおけるsubwordのidのようなもの)に変換し
    • Image Pathを元画像のVisual Tokenに戻していくことを学習するタスク
  • WPA(Word Patch Alignment)

    • MLMとMIMはそれぞれ独立にテキストと画像の特徴を学習するが、それぞれの関係は学習できていないため、
    • テキストと画像の関係を学習するために、ある単語(subword)と対応(同じ領域に含まれているかどうか)するImage Patchがマスクされているかどうか、を予測する2値分類のタスクを解く(対応するImage Patchがマスクされていない場合「整列」されている、マスクされている場合「非整列」というラベルを付与する)
    • ただし、単語がマスクされている場合は損失計算時に除去する

事前学習タスクの実装について

基本的に、事前学習タスクは下流で汎用的なタスクを解くための基盤モデルとして位置づけられており、headを付け替えることで任意のタスクに適用可能にすることを目的としているため、embedding層とエンコーダー層は転移学習時と同じものが使えます。 以下に、huggingface実装をもとに実装可能です。 LayoutLMv3の事前学習タスクの実装としてこのように書けます。

class LayoutLMv3ForPretrain(LayoutLMv3PreTrainedModel):
    def __init__(self, config: LayoutLMv3ConfigForPretrain):
        super().__init__(config)

        self.layoutlmv3 = LayoutLMv3Model(config)
        self.mlm = LayoutLMv3MLMHead(config)
        self.mim = LayoutLMv3MIMHead(config)
        self.wpa = LayoutLMv3WPAHead(config)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mlm_labels=None,  # MLMの正解ラベル。input_idsに対応する元のtoken_id
        mim_labels=None,  # MIMの正解ラベル。pixel_values(Image Patch)のVisual Token
        wpa_labels=None,  # WPAの正解ラベル。input_idsに対応するpixel_valuesがmaskされているかどうか
        bool_masked_pos=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pixel_values=None,
    ) -> MaskedLMOutput:
        outputs = self.layoutlmv3(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            pixel_values=pixel_values,
        )

        input_shape = input_ids.size()
        sequence_length = input_shape[1]

        # mlm
        sequence_output = outputs[0][:, :sequence_length]
        mlm_output = self.mlm(sequence_output, mlm_labels)

        # mim
        image_sequence_output = outputs[0][:, sequence_length + 1 :]  # cls tokenがcatされているので1つずらす
        mim_output = self.mim(image_sequence_output, mim_labels, bool_masked_pos)

        # wpa
        wpa_output = self.wpa(sequence_output, wpa_labels)

        return MaskedLMOutput(
            loss=mlm_output.loss + mim_output.loss + wpa_output.loss,
        )

embedding層とencoder層は変わらないため、元のLayoutLMv3実装であるモデルをfowardします(内部でembedding層・encoder層を通る)。 その結果をもとに、それぞれの事前学習タスクのheadを通して損失を求めて足すことで全体の損失を計算します。

また、特徴的なのは、テキストと画像をconcatして一つのベクトルにしてネットワークを通すことでマルチモーダルを実現している点でもあります。 ただ、それぞれの事前学習タスクでは独立して計算するため、損失を求めるときにはテキスト・画像のシーケンスをそれぞれ取り出してheadに通して損失を計算しています。

各事前学習タスクのheadの実装は、単純にMLPを実装しているだけです。 たとえば、MLMのheadはこのように実装しています。

class LayoutLMv3MLMPredictionHead(nn.Module):
    def __init__(self, config: LayoutLMv3ConfigForPretrain):
        super().__init__()
        decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        bias = nn.Parameter(torch.zeros(config.vocab_size))
        decoder.bias = bias
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            activations.ACT2FN[config.hidden_act],
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            decoder,
        )

    def forward(self, sequence_output: Any) -> Any:
        hidden_states = self.mlp(sequence_output)
        return hidden_states


class LayoutLMv3MLMHead(nn.Module):
    def __init__(self, config: LayoutLMv3ConfigForPretrain):
        super().__init__()
        self.config = config
        self.predictions = LayoutLMv3MLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor, mlm_labels: torch.Tensor) -> MaskedLMOutput:
        prediction_scores = self.predictions(sequence_output)

        if mlm_labels is not None:
            loss_fct = CrossEntropyLoss()
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                mlm_labels.view(-1),
            )

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
        )

実装内容はシンプルで、encoder層の出力結果に対してMLPを通し損失をCrossEntropyで求めています。 MIM、WPAのheadも同様にMLPを通し、CrossEntropyで損失を求めています。

事前学習モデルの学習

データセットはバクラクで扱っている請求書、数十万件を用いて学習させました。 LayoutLMはほぼBERT(RoBERTa)なので、パラメータサイズは、baseとlargeがあり、今回はbaseで検証しています。 ハイパーパラメータは論文に示されているものをそのまま用いて、warmupのみ追加でいれています(だいぶ安定します)

転移学習と評価

転移学習には、トレーニング用の請求書数万件用いて学習し、評価にテストデータセット数万件を使い請求書から読み取りたい項目のAccuracyを指標にしています。 ルールベース(バクラクのOCRはもとはルールベースで動いている)よりいくつかの項目で精度が向上しました。

まとめ

簡単にですが、LayoutLMの概要と事前学習タスクについて簡単に紹介しました。 事前学習モデルを作ることは意外と難しくなく、データセットが大量にあれば(こっちが難しい)自前で作れるという認識を持てたためとてもよい取り組みだったと思います。 今後も、データセットの改善とともに基盤モデルをLayerX内で作っていければと考えています。