ニューラルネットのレコメンドをメモ化して高速にする - エムスリーテックブログ

エムスリーテックブログ

エムスリー(m3)のエンジニア・開発メンバーによる技術ブログです

ニューラルネットのレコメンドをメモ化して高速にする

こんにちは、AI・機械学習チーム(AIチーム)の農見(@rookzeno)です。最近作ったニューラルネットのレコメンドが遅くて困ってました。その時ふと推論してるデータを見ると、これ同じユーザーとアイテムが多発してるなと気づいたので、メモ化をして高速化しました。メモ化して高速化は基礎の基礎ですが、ニューラルネットでやってるのはあまり見ないかなと思ったので、今回はそのやり方について記載します。

DALL-Eでサムネを作成

この記事はエムスリーAI・機械学習チームで2週間連続で行われるブログリレー2日目の記事です。昨日の記事もよろしくお願いします。

www.m3tech.blog

使っているモデルについて

よくあるユーザーベクトルとアイテムベクトルを作成して、concatして推論するモデルです。

よくあるレコメンドモデル

ここで大事なのは最後のclick prediction部分では重い操作はしてないという点です。一方でItem vector作成には言語モデルを使用していて遅い。そこでメモ化の出番ということです。(userとitem全部feature storeに入れて近傍探索したらいいんじゃないって言われそうですが、全部ニューラルネットで完結したいじゃん)

コード

例えばこんなモデルがあるとします。

class ItemEncoder(nn.Module):
    def __init__(
        self,
        pretrained: str = 'ku-nlp/deberta-v2-base-japanese',
    ):
        super().__init__()
        self.model = AutoModel.from_pretrained(pretrained)

    def forward(self, input) -> torch.Tensor:
        return self.model(**input).last_hidden_state[:, 0] 

class UserEncoder(nn.Module):
    def __init__(
        self,
        hidden_size: int = 256,
        user_feature_shape: int = 4,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(user_feature_shape, hidden_size)

    def forward(self, user_features: torch.Tensor) -> torch.Tensor:
        user_features = F.gelu(self.fc1(user_features))
        return user_features

class Recommender(nn.Module):
    def __init__(
        self,
        hidden_size: int = 1024,
    ) -> None:
        super().__init__()
        self.item_encoder = ItemEncoder()
        self.user_encoder = UserEncoder()
        self.fc1 = nn.Linear(hidden_size, 1)

    def forward(self, item: dict, user_feature: torch.Tensor) -> torch.Tensor:
        item_encoded = self.item_encoder(item)
        user_encoded = self.user_encoder(user_feature)
        item_user = torch.concat([item_encoded, user_encoded], dim=1) 
        output = self.fc1(item_user)
        return output

これはよくあるレコメンドモデルの図のまんまのモデルです。ItemEncoderとUserEncoderでベクトルを作成して、Recommenderでconcatしてpredictionするモデルです。このモデルで遅いところはItemEncoderで言語モデルを使用するところです。

このモデルをMacで100回回すと41秒かかりました。

model = Recommender()
tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-base-japanese')
item = tokenizer('こんにちは', return_tensors='pt', padding=True)
for i in range(100):
    model(item, torch.tensor([[1., 2., 3., 4.]]))

早速メモ化しましょう。

メモ化したコード

class Recommender(nn.Module):
    def __init__(
        self,
        hidden_size: int = 1024,
    ) -> None:
        super().__init__()
        self.item_encoder = ItemEncoder()
        self.user_encoder = UserEncoder()
        self.fc1 = nn.Linear(hidden_size, 1)
        self.memo = {}
        self.enable_memo = False

    def forward(self, item: dict, user_feature: torch.Tensor) -> torch.Tensor:
        if self.enable_memo:
            if item['input_ids'] in self.memo:
                item_encoded = self.memo[item['input_ids']]
            else:
                item_encoded = self.item_encoder(item)
                self.memo[item['input_ids']] = item_encoded
        else:
            item_encoded = self.item_encoder(item)
        user_encoded = self.user_encoder(user_feature)
        item_user = torch.concat([item_encoded, user_encoded], dim=1) 
        output = self.fc1(item_user)
        return output

enable_memoとmemoという変数を加えました。enable_memoがTrueの時には、itemとitem_vectorの辞書をどんどん作っていきます。ただ、メモするのは推論の時だけにしましょう。そうでないとitem_encoderが学習できなくなってしまいます。では早速メモ化の力を発揮してもらいましょう。

model = Recommender()
model.enable_memo = True
tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-base-japanese')
item = tokenizer('こんにちは', return_tensors='pt', padding=True)
for i in range(100):
    model(item, torch.tensor([[1., 2., 3., 4.]]))

なんと3秒で終わりました。早い!

これは極端ですが、登場するitemの数より推論する数がずっと多い場合にはとても早くなります。レコメンドではそういう場合が多いと思うので、メモ化が役立つのではないでしょうか。

感想

至極当たり前の話でしたが、まあ基礎を怠らないことが大事ということですね。難しく考えなくても、簡単にできる高速化もあります。

We're hiring!

AI・機械学習チームでは、レコメンドの高速化や改善などのタスクが豊富にあります。医療ニュースや医療トピックの最適化等、レコメンドするものは沢山あるので、興味を持った方は、次のリンクからご応募お待ちしています! インターンも通年募集中です!

jobs.m3.com