关于Transformer架构和原理解析的优秀文章有很多,这里列出一些供大家参考学习。本篇也就不对Transformer的结构和原理进行解读了(肯定没他们解读的好)。本篇主要从代码实现的层面,试图讲一下Transformer的Encoder和Decoder阶段各个模块的输入、输出形状,以及他们之间的关系。

详解Transformer (Attention Is All You Need) - 知乎

搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

https://pytorch.org/tutorials/beginner/transformer_tutorial.html


        上述参考文献足以让一个初学者完全了解transformer的架构及原理了。本文从代码实现的角度来看一下transformer的核心架构中的每个模块的具体实现方式,希望以此能对transformer架构有更深入的了解。本文参考的实现代码来自于这里

        首先,咱们还是先给出transformer的架构图:

pytorch unet 多输出网络_自注意力

        根据上图以及开源代码实现,可以把transformer划分为输入模块、Encoder模块、Decoder模块三部分。这些模块中包含的核心部分如下:

        1、单词向量编码word encoding

        2、位置编码positional encoding

        3、多头自注意力Multi-Head Self-Attention

        4、掩码Mask多头自注意力Masked Multi-Head Self-Attention

        5、多头交叉注意力Multi-Head Cross-Attention

        6、前馈网络(多层MLP)Feed Forward Network

        其中:

        输入模块:Encoder和Decoder模块都需要使用的,包括单词向量编码word encoding和位置编码positional encoding

        Encoder模块:包括多头自注意力Multi-Head Self-Attention和前馈网络Feed Forward Network

        Decoder模块:包括掩码Mask多头自注意力Multi-Head Self-Attention、多头交叉注意力Multi-Head Cross-Attention、前馈网络Feed Forward Network

        从上面的描述可以看出来,Transformer的核心模块就这几个,并且Encoder和Decoder中的好多模块是可以直接复用的。下面分别看一下transformer中的这些核心模块的代码实现,以及各个模块的输入输出是什么样子。 


任务描述        

        假设我们在做中英文翻译任务,将中文翻译成英文:        

                中文:猫吃鱼

                英文:Cats eat fish

         为了表示一个句子的开始和结束,需要在句子的开始和结束位置分别添加上标记符。

                开始标记符:<s>

                结束标记符:</s>

        在开源代码preprocess.py实现中,使用torchtext来实现上述操作:

SRC = torchtext.data.Field(
        tokenize=tokenize_src, lower=not opt.keep_case,
        pad_token=Constants.PAD_WORD, init_token=Constants.BOS_WORD, eos_token=Constants.EOS_WORD)

    TRG = torchtext.data.Field(
        tokenize=tokenize_trg, lower=not opt.keep_case,
        pad_token=Constants.PAD_WORD, init_token=Constants.BOS_WORD, eos_token=Constants.EOS_WORD)
PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

        现在输入序列就变成了"<s>猫吃鱼</s>",序列长度是5,在后文中我们将使用time_len_src表示原始序列的长度,使用time_len_tgt表示目标序列的长度。


输入模块

        transformer的输入模块是Encoder和Decoder模块所共有的,主要包含两个部分:单词向量编码word encoding和位置编码positional encoding。

1、单词向量化word embedding

        参数解释:

                num_embeddings表示输入序列中单词表的大小,包括开始、结束符,PAD占位符

                embedding_dim表示每个单词的特征维度大小

                padding_idx表示[0, num_embeddings - 1]中的哪一个索引下标用来表示占位符,占位符对应的embedding为全0

        输入:一个单词序列,也就是一个句子"<s>猫吃鱼</s>"

        输出:一个单词向量矩阵,矩阵的大小是[5, 512],5是序列的总长度,512是每个词向量的维度

        假如我们现在训练语料中总共包含10000个词语,我们需要把每个词语数值化(也叫向量化),经过数值化之后的每个词语才能方便的被计算机处理。word embedding就是预先定义一个可学习的embedding矩阵,把每个词语转换成一个512维(维度可调)的向量。那么就预先创建一个大小为[10000, 512]的可学习的权重矩阵,矩阵的行索引对应单词表中每个单词的索引编号,这样根据每个单词的索引编号就能获取到每个单词512维的embedding向量。在pytorch框架中,使用nn.Embedding来实现。这里给一个Embedding的示例:

import torch
from torch import nn

# num_embeddings表示输入序列中单词表的大小,包括开始、结束符,PAD占位符
# embedding_dim表示每个单词的特征维度大小
# padding_idx表示[0, num_embeddings - 1]中的哪一个用来表示占位符,占位符对应的embedding为全0
model = nn.Embedding(num_embeddings=101, embedding_dim=16, padding_idx=100)
print(model.weight.shape)
text_data = torch.tensor([0, 1, 100])
emb_1 = model(text_data)
print(emb_1.shape)

emb_2 = model(torch.tensor([[0, 1, 100], [1, 2, 3]]))
print(emb_2.shape)

        其中:

        num_embeddings表示输入序列中单词表的大小,包括开始、结束符,PAD占位符
        embedding_dim表示每个单词的特征维度大小
        padding_idx表示[0, num_embeddings - 1]中的哪一个用来表示占位符,占位符对应的embedding为全0

2、单词位置编码positional encoding

        参数解释:

                n_position:序列的长度大小,也就是所有position的位置数量

                d_hid:位置编码的维度大小,要等于词向量编码的维度大小

        输入:序列的长度大小5,一个位置向量的维度大小512

        输出:一个位置向量矩阵,矩阵的大小是[5, 512],5是序列的总长度,512是每个位置向量的维度

        在transformer出现之前,处理序列任务通常采用循环神经网络RNN及其变体LSTM、GRU等,循环神经网络经过专门的设计可以用来捕获长距离的依赖关系, t 时刻输出依赖 t - 1时刻的输出,所以循环神经网络只能顺序的处理序列数据,计算并行度为1,捕获长距离的依赖难度系数正比于序列的长度N。

        而transformer相比于RNN循环神经网络有两个主要优势:

        a)、捕获长距离依赖的难度系数从序列的长度N降为常数1

        b)、计算并行度由1提升到序列长度大小N

        看过transformer原理的一定对其中的Q、K、V印象深刻,通过Q和K矩阵运算,直接计算序列中任意两个时刻之间的依赖关系,在这种矩阵计算模式下,序列中的所有输入在位置层面上是等价的,丢失了序列之间的前后关系,所以要显式的把这种序列关系给弥补回来,这就是位置编码。位置编码通过给序列的每个位置添加一个能够唯一标识该位置的信息,transformer中使用周期性的三角函数来计算每个位置时刻的唯一的位置编码,具体计算如下:

pytorch unet 多输出网络_pytorch_02

        给每个时刻生成一个和词向量维度相同的位置向量,这个位置向量通过sin、cos周期函数进行表示,这样可以保证每个时刻的位置标识都是唯一的。

class PositionalEncoding(nn.Module):
    """
        对序列中的每个输入embedding使用sin、cos函数生成一个唯一的位置编码
    """

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        # 保存在模型中的非训练参数
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """
        Sinusoid position encoding table
        :param n_position: 位置的数量,也就是序列的长度
        :param d_hid: 每个位置生成的位置编码向量的大小
        :return:
        """

        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)  # (1, time_len, d_hid)

    def forward(self, x):
        """
        将位置编码与word embedding相加
        :param x: x.shape -> [batch_size, time_len, d_hid]
        :return:
        """
        # 给序列中的每个输入加上位置编码信息
        return x + self.pos_table[:, :x.size(1)].clone().detach()


def positional_embedding_demo():
    posEmb = PositionalEncoding(d_hid=128, n_position=100)
    pos = posEmb.pos_table.squeeze(dim=0).detach().cpu().numpy()
    print(pos.shape)

    plt.imshow(pos)
    plt.show()

    input_emb = torch.randn((1, 100, 128))
    pos_embedding = posEmb(input_emb)
    print(pos_embedding.shape)

    plt.imshow(pos_embedding[0])
    plt.show()


if __name__ == '__main__':
    positional_embedding_demo()

        在位置编码PositionalEncoding的实现中,n_position参数表示位置的数量,也就是序列的长度,d_hid参数表示每个位置生成的位置编码向量的大小,d_hid的大小要与word embedding的维度保持一致。假如n_position=100, d_hid=512(等于词向量的维度),那么位置编码中的pos_table的大小为[100, 512]。

        位置编码的可视化结果如下图所示,每一行代表一个位置向量,可以看到每个位置向量都是唯一的。

pytorch unet 多输出网络_transformer_03

        回到开头讲的中英文翻译任务,输入的中文序列是"<s>猫吃鱼</s>",将这个序列输入到单词向量化word embedding模块中,得到序列中每个词的Embedding向量,这个Embedding向量是可学习的,假设我们设置向量的维度是512,那么对于"<s>猫吃鱼</s>"这个序列我们就得到了一个大小为[5, 512]的词向量矩阵。

        同样针对"<s>猫吃鱼</s>"这个输入序列,序列的长度是5,所以位置编码同样会计算得到一个大小为[5, 512]的位置向量矩阵。将两个大小为[5, 512]的矩阵按位相加,就得到了这个输入序列的向量化表示。


Encoder模块

        Encoder模块主要包含三个部分:多头自注意力Multi-Head Self-Attention和前馈网络(多层MLP)Feed Forward Network。

1、多头自注意力Multi-Head Self-Attention

       参数解释:

                temperature:温度系数,也就是缩放系数

                attn_dropout:在attention上使用dropout的丢弃率

                batch_size:批量大小

                n_heads:注意力头的数量

                time_len:序列的时序长度

                dim:输入序列的维度大小

                mask:掩码矩阵,在Encoder的Self-Attention阶段和Decoder的预测阶段不需要使用,在Decoder训练阶段的Self-Attention部分使用
    

        Self-Attention的计算方式有多种方法,具体可以参考Non-Local论文里面的attention计算部分。transformer的self-attention模块使用的是Scaled Dot Product计算方法。Scaled Dot Product模块的输入输出格式如下:

        输入:

 

q:q.shape -> [batch_size, n_heads, time_len, dim]
                k:k.shape -> [batch_size, n_heads, time_len, dim]
                v:v.shape -> [batch_size, n_heads, time_len, dim]
                mask.shape -> [time_len ,time_len] or None

        输出:

                shape -> [batch_size, n_heads, time_len, dim]

        此处的time_len既可以是源序列的长度time_len_src,也可以是目标序列的长度time_len_tgt,所以不做特别说明。

class ScaledDotProductAttention(nn.Module):
    """
        Scaled Dot-Product Attention,这是self-attention的核心部分
        在原始Transformer架构中,用于解决NLP问题,模型的输入维度是(batch_size, seq_len, dim)
        同时,在Transformer架构中,使用了multi-head多头注意力机制,为了提高并行计算能力,同时计算n个注意力头的结果,
        self-Attention的输入维度变成(batch_size, n_head, seq_len, dim)

        Mask矩阵的缘由:
        1、在transformer模型的训练阶段,为了提升训练时的并行计算能力,一次性将整个训练样本输入到模型中,
        对序列中的所有输入同时做self-Attention操作,为了防止模型偷窥到未来的信息,也就是在计算
        self-Attention的过程中t时刻的输入偷窥到t+1时刻的输入信息,在模型训练阶段,在计算self-Attention
        权重矩阵的时候使用Mask将t时刻之后的attention score置为0,这样在基于attention score和V计算加权平
        均时,就只会计算t时刻之前内容的加权平均。
        2、在transformer模型的预测推理阶段,由于未来的信息本来就不存在,所以此时就不需要使用Mask矩阵
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """
        :param q: q.shape -> [batch_size, n_heads, time_len, dim]
        :param k: k.shape -> [batch_size, n_heads, time_len, dim]
        :param v: v.shape -> [batch_size, n_heads, time_len, dim]
        :param mask:
        :return:
        """
        # 使用scaled dot product公式计算Q和K之间的attention score
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # 在Transformer模型训练阶段,防止偷窥到后面时刻的输入信息,相加docstring部分的mask矩阵解释
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        # 将Q和K计算得到的attention score使用softmax归一化之后,对V做weighted sum
        output = torch.matmul(attn, v)

        return output, attn


def scaled_dot_product_demo():
    batch_size = 8
    n_heads = 2
    time_step = 16
    dim = 32

    # 下面shape为1的维度表示此时只有一个attention head
    q = torch.randn((batch_size, n_heads, time_step, dim))
    k = torch.randn((batch_size, n_heads, time_step, dim))
    v = torch.randn((batch_size, n_heads, time_step, dim))
    # print(torch.matmul(q, k.transpose(2, 3)).shape)

    dotProductAttn = ScaledDotProductAttention(temperature=dim, attn_dropout=0.0)
    # diagonal表示考虑对角线之外的几个位置
    mask = torch.tril(torch.ones(time_step, time_step), diagonal=0)
    print(mask)
    output, attn = dotProductAttn(q, k, v, mask=mask)
    print(output.shape, attn.shape)

    for a in attn:
        plt.imshow(a[0].detach().cpu().numpy())
        plt.show()
        break


if __name__ == '__main__':
    scaled_dot_product_demo()

        在Self-Attention计算的基础上,为了提升特征提取的多样性,组合使用多个self-attention组件,并将多个self-attention的结果进行合并。同时,为了提升多头注意力Multi-Head Self-Attention模块的计算速度和并行度,将多个注意力头通过一次计算得到。MultiHeadAttention模块的输入输出格式如下: 

        输入:

   

q:q.shape -> [batch_size, time_len_src, dim]
                k:k.shape -> [batch_size, time_len_src, dim]
                v:v.shape -> [batch_size, time_len_src, dim]
                mask.shape -> [time_len_src ,time_len_src] or None

        输出:

                shape -> [batch_size, time_len_src, dim]

     

class MultiHeadAttention(nn.Module):
    """
        Multi-Head Attention module
        根据输入的序列embedding内容,执行下面步骤:
        1、执行multi-head多头注意力
        2、执行residual残差连接
        3、执行LayerNorm归一化
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # 将输入embedding一次性转换为能被n_head个attention使用的形状,后面一次性输入到n_head个attention里面
        # 实现n_head个attention的并行计算
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """
        :param q: q.shape -> [batch_size, time_len, dim]
        :param k: k.shape -> [batch_size, time_len, dim]
        :param v: v.shape -> [batch_size, time_len, dim]
        :param mask:
        :return:
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # 用于计算残差
        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        # 将Q、K、V经过一个转换,维度提升为(batch_size, seq_len, n_head, dim)
        # q.shape -> [batch_size, time_len, n_head, dim]
        # k.shape -> [batch_size, time_len, n_head, dim]
        # v.shape -> [batch_size, time_len, n_head, dim]
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        # 交换Q、K、V的seq_len,和n_head两个维度
        # q.shape -> [batch_size, n_head, time_len, dim]
        # k.shape -> [batch_size, n_head, time_len, dim]
        # v.shape -> [batch_size, n_head, time_len, dim]
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            # mask本来是2D的(seq_len, seq_len),现在给mask增加一个维度,变成(seq_len, 1, seq_len),能够在n_head维度进行广播
            mask = mask.unsqueeze(0)   # For head axis broadcasting.

        # 执行self-attention操作
        # q.shape -> [batch_size, n_head, time_len, dim]
        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        # 把attention输出的结果的形状变成(batch_size, seq_len, n_head * dim)
        # q.shape -> [batch_size, time_len, n_head * dim]
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        # q.shape -> [batch_size, time_len, n_head * dim] 转换为[batch_size, time_len, d_model]
        q = self.dropout(self.fc(q))
        q += residual

        # 执行LayerNorm归一化
        q = self.layer_norm(q)

        return q, attn


def multi_head_attn_demo():
    batch_size = 8
    time_step = 16
    dim = 32
    n_head = 8

    q = torch.randn((batch_size, time_step, dim))
    k = torch.randn((batch_size, time_step, dim))
    v = torch.randn((batch_size, time_step, dim))

    multi_head_attn = MultiHeadAttention(n_head=n_head, d_model=dim, d_k=dim, d_v=dim, dropout=0.1)
    mask = torch.tril(torch.ones(time_step, time_step), diagonal=0)
    out, attn = multi_head_attn(q, k, v, mask=mask)
    print(out.shape, attn.shape)

    for a in attn:
        plt.imshow(a[0].detach().cpu().numpy())
        plt.show()
        break


if __name__ == '__main__':
    multi_head_attn_demo()

        multi-head attention就是用多组不同的Q、K、V权重矩阵来提取特征,然后把提取到的特征进行concat,为了提升计算的并行度,在multi-head attention中把多组self-attention操作合并在一次进行处理。下面通过一段拆解的代码来看一下多头注意力模块中做了哪些操作(这里省略了残差和LayerNorm的计算),代码如下:

batch_size = 4
    n_heads = 8
    len_seq = 16
    dim = 32

    # 词向量 + 位置编码之后得到的输入矩阵word_vec
    word_vec = torch.rand(batch_size, len_seq, dim)

    # 计算多头注意力的转换矩阵Q、K、V,将n_heads个注意力头使用一次计算完成
    weight_q = torch.rand(dim, dim * n_heads)
    weight_k = torch.rand(dim, dim * n_heads)
    weight_v = torch.rand(dim, dim * n_heads)

    # 将输入的word_vec与Q、K、V做矩阵乘法
    q = torch.matmul(word_vec, weight_q).view(batch_size, len_seq, n_heads, -1).transpose(1, 2)
    k = torch.matmul(word_vec, weight_k).view(batch_size, len_seq, n_heads, -1).transpose(1, 2)
    v = torch.matmul(word_vec, weight_v).view(batch_size, len_seq, n_heads, -1).transpose(1, 2)
    print(q.shape, k.shape, v.shape)

    # Scaled Dot Product的缩放系数
    temperature=dim ** 0.5

    attn = torch.matmul(q / temperature, k.transpose(2, 3))
    print(attn.shape)

    # 计算每个注意力头对于batch中每个输入样本的注意力权重矩阵
    attn_weight = F.softmax(attn, dim=-1)
    print(attn_weight.shape)
    
    # 将注意力权重矩阵和v做加权得到最终输出
    output = torch.matmul(attn_weight, v)
    print(output.shape)

2、前馈网络(多层MLP)Feed Forward Network

        参数解释:

                d_in:输入向量的特征维度

                d_hid:MLP隐层大小

                dropout:dropout丢弃率

        输入:

                shape -> [batch_size, time_len_src, d_in]

        输出:

                shape -> [batch_size, time_len_src, d_in]

        在Multi-Head Self-Attention之后,接一个多层的MLP网络进行特征融合,具体实现代码如下:

class PositionwiseFeedForward(nn.Module):
    """
        A two-feed-forward-layer module
        在self-attention block之后接一个FFN层,执行一下操作:
        1、使用两层MLP做一次线性变换
        2、执行residual残差连接
        3、执行LayerNorm归一化
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: x.shape -> [batch_size, time_len, d_in]
        :return:
        """
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x

        下面看一下整合了上面几个部分的一个完整的Encoder Transformer Block模块的代码:

class EncoderLayer(nn.Module):
    """
        Compose with two layers
        EncoderLayer是Encoder里面的一层,里面包含一个完整的Transformer Encoder模块,有如下内容:
        1、multi-head多头注意力
        2、残差连接
        3、LayerNorm
        4、多层MLP,也就是FFN
        在transformer的Encoder部分可以连续堆叠使用多层EncoderLayer,就像多层LSTM一样
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        # encoder阶段的self-attention的输入是整个embedding序列
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

Auto-Encoding与Auto-Regressive

        在讲Transformer的Decoder部分之前,先介绍一下Auto-Encoding与Auto-Regressive的相关概念。

        双向语言模型(通常指Auto-Encoding):通常采用随机掩码的方式遮盖部分内容,然后利用上下文预测被掩码的部分。如BERT及其后续改进版本。Auto Encoding比较适合做语言理解类NLU(Natural Language Unstanding)的任务。

        单向语言模型(通常指Auto-Regressive):单向语言模型可以是从左向右也可以是从右向左,在这种预测方式中,模型在预测第 t 时间片的内容时只能看到第 t 时间片之前的内容,因此单向语言模型一次只能预测输出一个单词或一个字。如GPT系列、ChatGPT、GPT-4等。Auto Regressive比较适合做语言生成类NLG(Natural Language Generative)的任务。

        Encoder-Decoder语言模型:Encoder-Decoder(如Transformer、Google T5)将两者结合,Encoder部分使用Auto Encoding,Decoder部分使用Auto Regressive,Encoder与Decoder分别是不同的底层Transformer网络。

pytorch unet 多输出网络_encoder-decoder_04

pytorch unet 多输出网络_transformer_05

         经过上述介绍我们知道了Transformer的Encoder部分采用的是Auto-Encoding模式,Decoder部分采用的是Auto-Regressive模式。


Decoder模块

        Decoder模块主要包含三个部分:掩码Mask多头自注意力Masked Multi-Head Self-Attention、多头交叉注意力Multi-Head Cross-Attention和前馈网络(多层MLP)Feed Forward Network

1、掩码Mask多头自注意力Masked Multi-Head Self-Attention

        在上面的介绍中我们知道了Decoder部分属于Auto-Regressive模式,也就是单向的语言模型。与Encoder中的多头自注意力模块的计算几乎一致,区别在于Decoder阶段的Self-Attention需要用到mask掩码来实现单向语言模型, Self-Attention模块 t 时刻的输出只能依赖于其之前时刻的输出内容来计算,没办法依赖未来没有发生的内容。

        既然 t 时刻的输出内容只能依赖其之前时刻的输出,就像循环神经网络RNN那样,那为什么需要用到Mask掩码呢?掩码在这里起到什么作用呢?

        谈到掩码Mask的具体作用,就要简单说一下Decoder部分再训练和预测阶段的不同。先看一下预测阶段的Decoder

        Decoder预测阶段:

        在预测阶段,我们看一下每个时间步Decoder的输入和输出内容:

        时刻1:

                输入:开始标记符"<s>"

                输出:Cats

                真实标签:None

        时刻2:

                输入:开始标记符 “<s> Cats”

                输出:eat

                真实标签:None

        时刻3:

                输入:开始标记符 “<s> Cats eat”

                输出:fish

                真实标签:None

        时刻4:

                输入:开始标记符 “<s> Cats eat fish”

                输出:</s>

                真实标签:None

        这就是Decoder在预测时候的一次完整输出过程,在预测阶段,我们根据当前时刻之前的所有内容来预测输出当前时刻的内容,因为预测阶段我们是不知道标准答案的,所以预测阶段每一步的真实标签(Ground Truth)都是空的,所以在Decoder进行预测时是不需要使用掩码Mask的。下面我们看一下Decoder的训练阶段有什么不同。

        Decoder训练阶段:

        在Decoder的训练阶段,需要根据每个时间步的输出结果与真实标签计算目标损失,从而完成模型训练。所以预测阶段每个时间步的输入输出如下:

        时刻1:

                输入:正确的翻译结果"<s> Cats eat fish </s>"

                输出:Cats

                真实标签:Cats

        时刻2:

                输入:正确的翻译结果"<s> Cats eat fish </s>"

                输出:eat

                真实标签:eat

        时刻3:

                输入:正确的翻译结果"<s> Cats eat fish </s>"

                输出:fish

                真实标签:fish

        时刻4:

                输入:正确的翻译结果"<s> Cats eat fish </s>"

                输出:</s>

                真实标签:</s>

        我们都知道Transformer相比于循环神经网络的主要优势就是捕获长距离依赖和计算并行度高,所以为了提升训练阶段的计算并行度,提升训练速度,在训练阶段的每个时间步的输入都是完整的正确的翻译结果"<s> Cats eat fish </s>",但是在每个时间步我们是不能提前知道之后的时间发生的事情的,所以对于时间步 t ,只能让他看到他之前包括他自身这些已经发生的内容,不能让他看到未发生的内容(泄露天机),这就需要使用到掩码Mask了。掩码Mask主要作用于Scaled Dot Product阶段计算得到的attention权重矩阵上,将被掩码的部分的attention设置为 0 或者极小的数值。

pytorch unet 多输出网络_pytorch unet 多输出网络_06

        上图中左边的矩阵从上往下看分别是时刻1、时刻2、……。在时刻1只能看到自己,时刻2可以看到时刻1和时刻2,依次类推。灰色的方块内容表示被掩码的部分,其代表的权重值几乎都等于 0 。以第二行时刻2为例,计算时刻2的注意力内容时,由于时刻2之后的所有attention权重几乎都是0,不会对时刻2的计算内容产生影响,只有时刻1和时刻2的内容会对时刻2的计算产生影响。

        掩码Mask主要在Self-Attention模块发挥作用,下面我们看一下带掩码的多头注意力是啥样的,所以还是要看一下ScaledDotProductAttention部分的代码实现:

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """
        q.shape -> [batch_size, n_heads, time_len, dim]
        k.shape -> [batch_size, n_heads, time_len, dim]
        v.shape -> [batch_size, n_heads, time_len, dim]
        :param q:
        :param k:
        :param v:
        :param mask:
        :return:
        """
        # attn shape -> [batch_size, n_heads, time_len, time_len]
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # 将mask的位置设置为一个特别小的数,这样在softmax时输出的权重接近于0
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

        首先,mask是一个大小为[seq_len, seq_len]的对角矩阵,如果语言模型是从左向右的,mask矩阵就是一个下三角矩阵,对角线和下半部分的值都是1,其余部分是0。前面说过掩码Mask主要作用于Self-Attention的注意力权重计算阶段,注意力权重的计算使用softmax得到一个权重的概率分布,将被掩码的部分的权重设置为0,在softmax计算中如何让一个数值计算出来的概率为0或者几乎为 0 呢?答案就是让这个原始数值足够小,这样这个数值经过softmax之后对应的权重就几乎为0。Mask部分的核心代码就是下面两行:

if mask is not None:
            # 将mask的位置设置为一个特别小的数,这样在softmax时输出的权重接近于0
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))

        下面我们来看一下Decoder阶段的掩码Mask多头自注意力Masked Multi-Head Self-Attention的输入、输出是啥样的。我们前面讲了,Decoder在训练和预测阶段的处理是不一样的,所以我们就分别看一下训练和预测阶段的输入、输出是啥样的。

        与Encoder阶段使用的多头自注意力Multi-Head Self-Attention模块的参数解释完全一样,其实就是同一个代码模块。

        参数解释:

                temperature:温度系数,也就是缩放系数

                attn_dropout:在attention上使用dropout的丢弃率

                batch_size:批量大小

                n_heads:注意力头的数量

                time_len:序列的时序长度

                dim:输入序列的维度大小

                mask:掩码矩阵,在Encoder的Self-Attention阶段和Decoder的预测阶段不需要使用,在Decoder训练阶段的Self-Attention部分使用


        Decoder预测阶段:

        输入:

                q:q.shape -> [batch_size, t, dim]

                k:k.shape -> [batch_size, t, dim]

                v:v.shape -> [batch_size, t, dim]

                mask.shape -> None

        输出:

                shape -> [batch_size, t, dim]

       其中,t 表示时刻 t ,也就是在时刻1输入q的形状是[batch_size, 1, dim],时刻2输入q的形状是[batch_size, 2, dim],k和v同理。也就是在预测阶段的不同时刻,输入进掩码Mask多头自注意力Masked Multi-Head Self-Attention的特征形状是动态变化的。在Decoder的预测阶段,不需要使用掩码,所以Mask是None。

        Decoder训练阶段:

        输入:

                q:q.shape -> [batch_size, time_len_tgt, dim]

                k:k.shape -> [batch_size, time_len_tgt, dim]

                v:v.shape -> [batch_size, time_len_tgt, dim]

                mask.shape -> [time_len_tgt, time_len_tgt]

        输出:

                shape -> [batch_size, time_len_tgt, dim]

       上面讲了,为了提升Decoder阶段的并行度,训练阶段每次输入到Decoder的都是一个完整的标准答案,所以输入的序列长度time_len_tgt是整个目标句子的长度,然后通过掩码Mask来进行时序控制。在Decoder的训练阶段,掩码矩阵是一个[time_len_tgt, time_len_tgt]的三角矩阵。

2、多头交叉注意力Multi-Head Cross-Attention

        前面我们已经见到了Encoder阶段的多头自注意力Multi-Head Self-Attention,Decoder阶段的掩码Mask多头自注意力Masked Multi-Head Self-Attention,现在又出来一个多头交叉注意力Multi-Head Cross-Attention,他们之间有哪些不同点呢?下面我们还是从输入、输出的角度来对比一下这几个模块。

        首先,我们知道Transformer中使用的是Scaled Dot Product进行注意力计算,这个模块一共有四个输入参数:q、k、v和mask。那么,上面说到的多头自注意力Multi-Head Self-Attention、掩码Mask多头自注意力Masked Multi-Head Self-Attention和多头交叉注意力Multi-Head Cross-Attention的主要区别就在于他们的q、k、v的来源问题,以及是否使用mask。这里我们的分别以Encoder和Decoder中的第一层Transformer Block为例,其他曾的Block的输入分别来自于他们的上一层Transformer Block的输出。

        Encoder多头自注意力Multi-Head Self-Attention:

                q:来自于Encoder的输入序列,也就是源语言的 词向量 + 位置编码向量

                k:来自于Encoder的输入序列,也就是源语言的 词向量 + 位置编码向量

                v:来自于Encoder的输入序列,也就是源语言的 词向量 + 位置编码向量

                mask:不使用Mask

        Decoder掩码Mask多头自注意力Masked Multi-Head Self-Attention:

                q:来自于Decoder的输入序列,也就是目标语言的 词向量 + 位置编码向量

                k:来自于Decoder的输入序列,也就是目标语言的 词向量 + 位置编码向量

                v:来自于Decoder的输入序列,也就是目标语言的 词向量 + 位置编码向量

                mask:训练阶段的mask大小为[time_len_tgt, time_len_tgt],预测阶段不使用mask

        Decoder多头交叉注意力Multi-Head Cross-Attention:

                q:来自于Decoder的输入序列,也就是目标语言的 词向量 + 位置编码向量

                k:来自于Encoder最后一层Transformer Block的输出序列

                v:来自于Encoder最后一层Transformer Block的输出序列

                mask:不使用mask

        通过上面的描述大概也能看出来不同之处了,不管是Encoder,还是Decoder阶段只要是Self-Attention,那么q、k、v的来源是一致的,无非就是用不用mask掩码的问题。对于多头交叉注意力Multi-Head Cross-Attention模块,q来自Decoder部分的输入,k和v来自于Encoder模块的输出,所以多头交叉注意力Multi-Head Cross-Attention是跨Encoder和Decoder进行信息传递的,所以叫Cross。所以多头交叉注意力Multi-Head Cross-Attention模块的代码实现和之前的多头自注意力是一样的,只不过就是在调用的时候传入的q、k、v的来源不一样。

        同样,由于Decoder需要区分是训练阶段还是预测阶段,所以在不同的阶段,多头交叉注意力Multi-Head Cross-Attention模块的输入、输出形状也是不一样的。与Encoder阶段使用的多头自注意力Multi-Head Self-Attention模块的参数解释完全一样,其实就是同一个代码模块。

        参数解释:

                temperature:温度系数,也就是缩放系数

                attn_dropout:在attention上使用dropout的丢弃率

                batch_size:批量大小

                n_heads:注意力头的数量

                time_len:序列的时序长度

                dim:输入序列的维度大小

                mask:掩码矩阵,在Encoder的Self-Attention阶段和Decoder的预测阶段不需要使用,在Decoder训练阶段的Self-Attention部分使用


        Decoder预测阶段:

        输入:

                q:q.shape -> [batch_size, t, dim]

                k:k.shape -> [batch_size, time_len_src, dim]

                v:v.shape -> [batch_size, time_len_src, dim]

                mask.shape -> None

        输出:

                shape -> [batch_size, t, dim]

       其中,t 表示时刻 t ,也就是在时刻1输入q的形状是[batch_size, 1, dim],时刻2输入q的形状是[batch_size, 2, dim]。不同的是k和v是来自于Encoder阶段的输出,所以k和v的长度等于Encoder的输入序列长度。也就是在预测阶段的不同时刻,输入交叉多头自注意力Masked Multi-Head Cross-Attention的 q 的特征形状是动态变化的,k和v的形状是固定不变的。同时不需要使用掩码,所以Mask是None。

        Decoder训练阶段:

        输入:

                q:q.shape -> [batch_size, time_len_tgt, dim]

                k:k.shape -> [batch_size, time_len_src, dim]

                v:v.shape -> [batch_size, time_len_src, dim]

                mask.shape -> None

        输出:

                shape -> [batch_size, time_len_tgt, dim]

       由于q和k,v的来源不一样,不会导致未来信息的泄露,所以不需要使用mask,并且在上一层Self-Attention计算之后,已经使用mask计算出了完整的目标序列,所以序列长度是time_len_tgt。

3、前馈网络(多层MLP)Feed Forward Network

        Decoder部分的前馈网络(多层MLP)Feed Forward Network和Decoder部分是一样的,就不介绍了。这里我们主要讲一下前馈网络(多层MLP)Feed Forward Network在训练和预测阶段的输入、输出形状。

        参数解释:

                d_in:输入向量的特征维度

                d_hid:MLP隐层大小

                dropout:dropout丢弃率

        Decoder预测阶段:

        输入:

                shape -> [batch_size, t, d_in]

        输出:

                shape -> [batch_size, t, d_in]

        Decoder训练阶段:

        输入:

                shape -> [batch_size, time_len_tgt, d_in]

        输出:

                shape -> [batch_size, time_len_tgt, d_in]

        下面看一下整合了上面几个部分的一个完整的Decoder Transformer Block模块的代码:

class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
            self, dec_input, enc_output,
            slf_attn_mask=None, dec_enc_attn_mask=None):
        # decoder阶段的self-attention的Q/K/V均来自于decoder部分
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # decoder阶段的cross-attention的Q来自于decoder上一时刻的输出,K/V均来自于encoder部分的输出
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn