引言
模型是团队在年月由等人在论文《》所提出,当前它已经成为领域中的首选模型。抛弃了的顺序结构,采用了-机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入的模型在的各个任务上都有了显著的提升。本文做了大量的图示目的是能够更加清晰地讲解的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。
注意力机制
中的核心机制就是-。-机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。
Self-Attention
-工作原理如上图所示,给定输入向量,然后对于输入向量通过矩阵进行线性变换得到向量,向量,以及向量,即如果令矩阵,,,,则此时则有接着再利用得到的向量和向量计算注意力得分,论文中采用的注意力计算公式为点积缩放公式论文中假定向量的元素和向量的元素独立同分布,且令均值为,方差为,则此时注意力向量的第个分量的均值为,方差具体的计算公式如下令注意力分数矩阵,则有注意分数向量经过层得到归一化后的注意力分布,即为最后利用得到的注意力分布向量和矩阵获得最后的输出,则有令输出矩阵,则有
Multi-Head Attention
-的工作原理与-的工作原理非常类似。为了方便图解可视化将-设置为-,如果-设置为-,则上图的的下一步的分支数为。给定输入向量,然后对于输入向量通过矩阵进行第一次线性变换得到向量,向量,以及向量。然后再对向量通过矩阵和进行第二次线性变换得到和,同理对向量通过矩阵和进行第二次线性变换得到和,对向量通过矩阵和进行第二次线性变换得到和,具体的计算公式如下所示:令矩阵此时则有对于每个利用得到对于向量和向量计算对应的注意力得分,其中注意力向量的第个分量的计算公式为令注意力分数矩阵,,则有注意分数向量经过层得到归一化后的注意力分布,即为对于每一个利用得到的注意力分布向量和矩阵获得最后的输出,则有两个的的向量按照如下方式拼接在一起,则有给定参数矩阵,则输出矩阵为综上所述则有
Mask Self-Attention
如下图左半部分所示,-的输出向量综合了输入向量的全部信息,由此可见,-在实际编程中支持并行运算。如下图右半部分所示,-的输出向量只利用了已知部分输入的向量的信息。例如,只是与有关;与和有关;与,和有关;与,,和有关。-在中被用到过两次。
- 的中如果输入一句话的长度小于指定的长度,为了能够让长度一致往往会用进行填充,此时则需要用-来计算注意力分布。
- 的的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到-。
Transformer模型
以上对中的核心内容即自注意力机制进行了详细解剖,接下来会对模型架构进行介绍。模型是由和两个模块组成,具体的示意图如下所示,为了能够对内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对的原理进行讲解。
模块操作的具体流程如下所示:
- 的输入由两部分组成分别是词编码矩阵和位置编码矩阵,其中表示句子数目,表示一句话单词的最大数目,表示的是词向量的维度。位置编码矩阵表示的是每个单词在一句里的所有位置信息,因为-计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个-的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示其中表示的是位置编码向量,表示词在句子中的位置,表示编码向量的位置索引。
- 输入矩阵通过线性变换生成矩阵,,。在实际编程中是将输入直接赋值给,,。如果输入单词长度小于最大长度并来填充的时候,还要相应引入矩阵。
- 将矩阵,,输入到-模块中进行注意分布的计算得到矩阵,计算公式为具体的计算细节参考上文关于-原理的讲解不在这里赘述。然后将原始输入与注意力分布进行残差计算得到输出矩阵。
- 对矩阵进行层归一化操作得到,具体的计算公式为
- 将输入到全连接神经网络中得到 ,然后再让全连接神经网络的输入与输出进行残差计算得到,接着对进行层归一化操作。
- 以上是一个的操作原理,将个进行堆叠就组成了的模块,得到的最后输出为。这里需要注意的是模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。
模块操作的具体流程如下所示:
- 的输入也由两部分组成分别是词编码矩阵和位置编码矩阵。因为的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入矩阵以便计算注意力分布。
- 输入矩阵通过线性变换生成矩阵,,。在实际编程中是将输入直接赋值给,,。如果输入单词长度小于最大长度并来填充的时候,还要相应引入矩阵。
- 将矩阵,,以及矩阵输入到-模块中进行注意分布的计算得到矩阵,计算公式为具体的计算细节参考上文关于-的讲解不在这里赘述。然后将原始输入与注意力分布进行残差计算得到输出矩阵。接着再对矩阵进行层归一化操作得到。
- 的输出通过线性变换得到和,进行线性变换得到,利用矩阵和和进行交叉注意力分布的计算得到,计算公式为这里的交叉注意力分布综合输出结果和中间结果的信息。实际编程编程中将直接赋值给和,直接赋值给。然后将与注意力分布进行残差计算得到输出矩阵。
- 接着对进行层归一操作得到,再将输入到全连接神经网络中得到,接着再做一步残差操作得到,最后再进行一层归一化操作。
- 以上是一个的操作原理,将个进行堆叠就组成了的模块,得到的输出为。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“”字符为止。
代码示例
具体的代码示例如下所示为一个国外博主视频里的代码,并根据上文对代码的一些细节进行了探讨。根据上文中-原理示例图可知,严格来看-在求注意分布的时候中间其实是有两步线性变换。给定输入向量 第一步线性变换直接让向量赋值给,,,这一过程以下程序中有所体现,在这里并不会产生歧义。第二步线性变换产生多,假设的时候,按理说要与个矩阵进行线性变换得到个,同理要与个矩阵进行线性变换得到个,要与个矩阵进行线性变换得到个,如果按照这个方式在程序实现则需要定义24个权重矩阵,非常的麻烦。以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多的线性变换,以向量为例:
- 首先将向量进行截断分成个向量,即为其中是的第个截断向量,是单位矩阵,是零矩阵。
- 然后对用相同的权重矩阵进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵即可,而且可以进行多线性变换,个权重矩阵可以表示为:其中权重矩阵。
import torch
import torch.nn as nn
import os
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask):
N =query.shape[0]
value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]
# split embedding into self.heads pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
# queries shape: (N, query_len, heads, heads_dim)
# keys shape : (N, key_len, heads, heads_dim)
# energy shape: (N, heads, query_len, key_len)
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
# attention shape: (N, heads, query_len, key_len)
# values shape: (N, value_len, heads, heads_dim)
# (N, query_len, heads, head_dim)
out = self.fc_out(out)
return out
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super(TransformerBlock, self).__init__()
self.attention = SelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion*embed_size),
nn.ReLU(),
nn.Linear(forward_expansion*embed_size, embed_size)
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask):
attention = self.attention(value, key, query, mask)
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
class Encoder(nn.Module):
def __init__(
self,
src_vocab_size,
embed_size,
num_layers,
heads,
device,
forward_expansion,
dropout,
max_length,
):
super(Encoder, self).__init__()
self.embed_size = embed_size
self.device = device
self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
self.position_embedding = nn.Embedding(max_length, embed_size)
self.layers = nn.ModuleList(
[
TransformerBlock(
embed_size,
heads,
dropout=dropout,
forward_expansion=forward_expansion,
)
for _ in range(num_layers)]
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
N, seq_length = x.shape
positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
for layer in self.layers:
out = layer(out, out, out, mask)
return out
class DecoderBlock(nn.Module):
def __init__(self, embed_size, heads, forward_expansion, dropout, device):
super(DecoderBlock, self).__init__()
self.attention = SelfAttention(embed_size, heads)
self.norm = nn.LayerNorm(embed_size)
self.transformer_block = TransformerBlock(
embed_size, heads, dropout, forward_expansion
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, value, key, src_mask, trg_mask):
attention = self.attention(x, x, x, trg_mask)
query = self.dropout(self.norm(attention + x))
out = self.transformer_block(value, key, query, src_mask)
return out
class Decoder(nn.Module):
def __init__(
self,
trg_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
device,
max_length,
):
super(Decoder, self).__init__()
self.device = device
self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
self.position_embedding = nn.Embedding(max_length, embed_size)
self.layers = nn.ModuleList(
[DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
for _ in range(num_layers)]
)
self.fc_out = nn.Linear(embed_size, trg_vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x ,enc_out , src_mask, trg_mask):
N, seq_length = x.shape
positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))
for layer in self.layers:
x = layer(x, enc_out, enc_out, src_mask, trg_mask)
out =self.fc_out(x)
return out
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size,
trg_vocab_size,
src_pad_idx,
trg_pad_idx,
embed_size = 256,
num_layers = 6,
forward_expansion = 4,
heads = 8,
dropout = 0,
device="cuda",
max_length=100
):
super(Transformer, self).__init__()
self.encoder = Encoder(
src_vocab_size,
embed_size,
num_layers,
heads,
device,
forward_expansion,
dropout,
max_length
)
self.decoder = Decoder(
trg_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
device,
max_length
)
self.src_pad_idx = src_pad_idx
self.trg_pad_idx = trg_pad_idx
self.device = device
def make_src_mask(self, src):
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
# (N, 1, 1, src_len)
return src_mask.to(self.device)
def make_trg_mask(self, trg):
N, trg_len = trg.shape
trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
N, 1, trg_len, trg_len
)
return trg_mask.to(self.device)
def forward(self, src, trg):
src_mask = self.make_src_mask(src)
trg_mask = self.make_trg_mask(trg)
enc_src = self.encoder(src, src_mask)
out = self.decoder(trg, enc_src, src_mask, trg_mask)
return out
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)
trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)
src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
out = model(x, trg[:, : -1])
print(out.shape)