Mask机制

虽然Mask机制在NLP领域是一个十分常见的操作,但是过去并没有仔细思考它的意义。最近参加了阿里天池的一个关于医学影像报告异常检测的数据竞赛。本质上是一个关于文本的多标签分类任务。在这个任务中,我尝试使用Transformer的Encoder结构作为基础来构建分类模型。为了巩固以及加深理解,没有使用PyTorch自带的Transformer模型,而是选择手动搭建。

Encoder中的Mask

在Encoder部分,涉及到的mask主要指self-attention过程中,在计算每个token的query与key的相似度时,需要考虑一个重要的问题就是padding。因为我们的每条语料数据的长度一般是不同的,因此为了保证输入模型的input的size完全一致,我们会在末尾添加padding部分来使得每个输入的长度完全一样。但是这部分内容实际上是没有意义的。因此在attention时,注意力不应该放在这部分,应该将这部分mask起来。也就是说我们需要将query与padding部分对应的key的相似度度量统一转化为一个很小的数,比如1e-9或1e-10。这样经过softmax之后,这部分的权重会接近于0。那么具体该怎么做呢?额,直接看代码吧。

首先是Encoder部分的最底层实现,MultiHeadAttention以及之后的全连接层。
假设我们的输入尺寸为[B,L], B代表batch_size, L代表seq_len,也就是序列长度,那么我们再经过self-attention之后得到的输出,也就是下面的scaled_attn,尺寸为[B,H,L,L],其中H代表Head个数,具体的转化过程见代码,我就不展开了。然后就到我们需要做mask的时候了,这个时候我们会在模型中传入一个尺寸为[B,1,1,L]的mask,其中pad_idx对应位置被设置为0,也就是需要mask的位置,其余为1。借助PyTorch的broadcasting机制,我们可以顺利地实现对scaled_attn的mask任务,然后将经过mask之后的scaled_attn(相似度度量矩阵)去与value相乘,得到每个token最终的向量表示。因为本文重点在与解释mask 机制,对于后面的全连接以及residual+layer_norm的操作就不展开细说了,大家可以参考下面的代码。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleDotProductAttention(nn.Module):
    def __init__(self,scale,atten_dropout=0.1):
        super(ScaleDotProductAttention,self).__init__()
        self.dropout=nn.Dropout(atten_dropout)
        self.scale=scale
    
    def forward(self,q,k,v,mask=None): #shape=[B,H,L,D]
        attn=torch.matmul(q,k.transpose(-2,-1)) #这里q:[B,H,L,D] k:[B,H,D,L]
        scaled_attn=attn/self.scale #attn 的output:[B,H,L,L]

        if mask is not None: #传入的mask:[B,1,1,L]
            scaled_attn.masked_fill(mask==0,-1e9)
        
        scaled_attn=self.dropout(F.softmax(scaled_attn,dim=-1))
        output=torch.matmul(scaled_attn,v) 
        return output,scaled_attn

class MultiHeadAttention(nn.Module):
    def __init__(self,n_head,dim_model,dim_k,dim_v,dropout=0.2):
        super(MultiHeadAttention,self).__init__()
        self.dim_model=dim_model
        self.n_head=n_head
        self.dim_k=dim_k #query 和 key的维度相同所以这里只定义一个
        self.dim_q=dim_k
        self.dim_v=dim_v

        self.w_q=nn.Linear(dim_model,n_head*dim_k,bias=False)
        self.w_k=nn.Linear(dim_model,n_head*dim_k,bias=False)
        self.w_v=nn.Linear(dim_model,n_head*dim_v,bias=False)
        self.fc=nn.Linear(n_head*dim_v,dim_model,bias=False)

        self.attention=ScaleDotProductAttention(scale=dim_k**0.5)
        self.dropout=nn.Dropout(dropout)
        self.layer_norm=nn.LayerNorm(dim_model,eps=1e-6)
    
    def forward(self,q,k,v,mask=None):
        d_k,d_v,n_head=self.dim_k,self.dim_v,self.n_head
        batch_size,len_q,len_k,len_v=q.size(0),q.size(1),k.size(1),v.size(1)

        residual=q
        q=self.w_q(q).view(batch_size,len_q,n_head,d_k) #将head单独取出作为一维
        k=self.w_k(k).view(batch_size,len_k,n_head,d_k)
        v=self.w_v(v).view(batch_size,len_v,n_head,d_v)

        #在attention前将len_ 与 head维度互换
        q,k,v=q.transpose(1,2),k.transpose(1,2),v.transpose(1,2) #shape=[B,H,L,D]

        if mask is not None: #传入的mask:[B,1,L]
            mask = mask.unsqueeze(1)   # For head axis broadcasting--->mask:[B,1,1,L]
        
        #attention
        output,attn=self.attention(q,k,v,mask=mask)
        output=output.transpose(1,2).contiguous().view(batch_size,len_q,-1)#合并heads
        output=self.dropout(self.fc(output))
        #print(output.shape,q.shape)
        output+=residual #+residual
        output=self.layer_norm(output) #layer normalization
        return output

class PositionwiseFeedForward(nn.Module):
    '''two feed forward layers'''
    def __init__(self,dim_in,dim_hid,dropout=0.2):
        super(PositionwiseFeedForward,self).__init__()
        self.w1=nn.Linear(dim_in,dim_hid)
        self.w2=nn.Linear(dim_hid,dim_in) #输出维度不变
        self.layer_norm=nn.LayerNorm(dim_in,eps=1e-6)
        self.dropout=nn.Dropout(dropout)
    
    def forward(self,x):
        residual=x
        x=self.w2(self.dropout(F.relu(self.w1(x))))
        x+=residual
        return x

在完成来上述基础组件的搭建之后,我们就可以实现单个encoder_layer以及由任意多个encoder_layer搭建的完整Encoder了,下面是代码,为了看起来清晰,我将encoder_layer单独写在一个脚本上了。

import torch.nn as nn
import torch
from transformer_sublayers import ScaleDotProductAttention,MultiHeadAttention, PositionwiseFeedForward

class EncoderLayer(nn.Module):
    def __init__(self,dim_model,dim_hid,n_head,dim_k,dim_v,dropout=0.2):
        super(EncoderLayer,self).__init__()
        self.slf_attn=MultiHeadAttention(n_head,dim_model,dim_k,dim_v)
        self.ffn=PositionwiseFeedForward(dim_model,dim_hid,dropout=dropout)
    

    def forward(self,enc_input,slf_attn_mask=None):
        attn_output=self.slf_attn(enc_input,enc_input,enc_input,mask=slf_attn_mask) #mask:Boolean构成的[B,1,L]
        output=self.ffn(attn_output)
        return output

下面是完整Encoder(其中包括设计mask的函数定义)

import torch
import torch.nn as nn
import numpy as np
from transformer_layers import EncoderLayer

def get_pad_mask(seq,pad_idx):
    return (seq!=pad_idx).unsqueeze(-2)

#定义位置信息
class PositionalEncoding(nn.Module):

    def __init__(self, dim_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        #缓存在内存中,常量
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, dim_hid))

    def _get_sinusoid_encoding_table(self, n_position, dim_hid):
        ''' Sinusoid position encoding table '''
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / dim_hid) for hid_j in range(dim_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class Encoder(nn.Module):

    def __init__(self,vocab_size,dim_word_vec,n_layers,n_head,dim_k,dim_v,dim_model,dim_hid,pad_idx,dropout=0.2,n_position=200):
        super(Encoder,self).__init__()
        self.embedding_layer=nn.Embedding(num_embeddings=vocab_size,
                                          embedding_dim=dim_word_vec,
                                          padding_idx=pad_idx)
        self.positionencode=PositionalEncoding(dim_hid=dim_word_vec,n_position=200)
        self.dropout=nn.Dropout(dropout)
        self.layer_stacks=nn.ModuleList([
            EncoderLayer(dim_model=dim_model,dim_hid=dim_hid,n_head=n_head,dim_k=dim_k,dim_v=dim_v)
          for _ in range(n_layers)])
        self.layer_norm=nn.LayerNorm(dim_model,eps=1e-6)
        self.dim_model=dim_model
        self.pad_idx=pad_idx
    
    def forward(self,x):
        
        token_embedd=self.embedding_layer(x)
        token_position_embedd=self.dropout(self.positionencode(token_embedd))
        encode_output=self.layer_norm(token_position_embedd) #shape=[B,L,E]--->(batch_size,seq_len,embed_dim)

        mask=get_pad_mask(x,self.pad_idx) #shape=[B,1,L]
        for encode_layer in self.layer_stacks:
            encode_output=encode_layer(encode_output,slf_attn_mask=mask)
        
        return encode_output

参考代码:
https://github.com/jadore801120/attention-is-all-you-need-pytorch