引言

transformer库有没有java的_transformer模型是transformer库有没有java的_transformer_02团队在transformer库有没有java的_深度学习_03transformer库有没有java的_自然语言处理_04月由transformer库有没有java的_自然语言处理_05等人在论文《transformer库有没有java的_深度学习_06》所提出,当前它已经成为transformer库有没有java的_transformer_07领域中的首选模型。transformer库有没有java的_transformer抛弃了transformer库有没有java的_深度学习_09的顺序结构,采用了transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入transformer库有没有java的_transformertransformer库有没有java的_transformer_13模型在transformer库有没有java的_transformer_07的各个任务上都有了显著的提升。本文做了大量的图示目的是能够更加清晰地讲解transformer库有没有java的_transformer的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。

注意力机制

transformer库有没有java的_transformer中的核心机制就是transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。

Self-Attention


transformer库有没有java的_transformer_21

transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11工作原理如上图所示,给定输入transformer库有没有java的_深度学习_24向量transformer库有没有java的_Self_25,然后对于输入向量transformer库有没有java的_Google_26通过矩阵transformer库有没有java的_transformer_27进行线性变换得到transformer库有没有java的_transformer_28向量transformer库有没有java的_Self_29transformer库有没有java的_深度学习_30向量transformer库有没有java的_自然语言处理_31,以及transformer库有没有java的_Google_32向量transformer库有没有java的_transformer_33,即transformer库有没有java的_transformer_34如果令矩阵transformer库有没有java的_transformer_35transformer库有没有java的_Self_36transformer库有没有java的_Google_37transformer库有没有java的_自然语言处理_38,则此时则有transformer库有没有java的_自然语言处理_39接着再利用得到的transformer库有没有java的_transformer_28向量和transformer库有没有java的_深度学习_30向量计算注意力得分,论文中采用的注意力计算公式为点积缩放公式transformer库有没有java的_Google_42论文中假定transformer库有没有java的_深度学习_30向量transformer库有没有java的_transformer_44的元素和transformer库有没有java的_transformer_28向量transformer库有没有java的_Self_46的元素独立同分布,且令均值为transformer库有没有java的_深度学习_47,方差为transformer库有没有java的_Self_48,则此时注意力向量transformer库有没有java的_transformer_49的第transformer库有没有java的_深度学习_50个分量transformer库有没有java的_深度学习_51的均值为transformer库有没有java的_深度学习_47,方差transformer库有没有java的_Self_48具体的计算公式如下transformer库有没有java的_Google_54令注意力分数矩阵transformer库有没有java的_深度学习_55,则有transformer库有没有java的_自然语言处理_56注意分数向量transformer库有没有java的_Self_57经过transformer库有没有java的_自然语言处理_58层得到归一化后的注意力分布transformer库有没有java的_Google_59,即为transformer库有没有java的_transformer_60最后利用得到的注意力分布向量transformer库有没有java的_Google_59transformer库有没有java的_Google_32矩阵transformer库有没有java的_transformer_63获得最后的输出transformer库有没有java的_Self_64,则有transformer库有没有java的_Self_65令输出矩阵transformer库有没有java的_transformer_66,则有transformer库有没有java的_深度学习_67

Multi-Head Attention


transformer库有没有java的_transformer_68

transformer库有没有java的_Self_69-transformer库有没有java的_transformer_70的工作原理与transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11的工作原理非常类似。为了方便图解可视化将transformer库有没有java的_Self_69-transformer库有没有java的_transformer_74设置为transformer库有没有java的_Google_75-transformer库有没有java的_transformer_74,如果transformer库有没有java的_Self_69-transformer库有没有java的_transformer_74设置为transformer库有没有java的_自然语言处理_79-transformer库有没有java的_transformer_74,则上图的transformer库有没有java的_transformer_81的下一步的分支数为transformer库有没有java的_自然语言处理_79。给定输入transformer库有没有java的_深度学习_24向量transformer库有没有java的_Self_25,然后对于输入向量transformer库有没有java的_Google_26通过矩阵transformer库有没有java的_transformer_27进行第一次线性变换得到transformer库有没有java的_transformer_28向量transformer库有没有java的_Self_29transformer库有没有java的_深度学习_30向量transformer库有没有java的_Google_90,以及transformer库有没有java的_Google_32向量transformer库有没有java的_自然语言处理_92。然后再对transformer库有没有java的_transformer_28向量transformer库有没有java的_自然语言处理_94通过矩阵transformer库有没有java的_自然语言处理_95transformer库有没有java的_Self_96进行第二次线性变换得到transformer库有没有java的_自然语言处理_97transformer库有没有java的_深度学习_98,同理对transformer库有没有java的_深度学习_30向量transformer库有没有java的_Google_100通过矩阵transformer库有没有java的_transformer_101transformer库有没有java的_transformer_102进行第二次线性变换得到transformer库有没有java的_Self_103transformer库有没有java的_Self_104,对transformer库有没有java的_Google_32向量transformer库有没有java的_transformer_106通过矩阵transformer库有没有java的_深度学习_107transformer库有没有java的_深度学习_108进行第二次线性变换得到transformer库有没有java的_Self_109transformer库有没有java的_transformer_110,具体的计算公式如下所示:transformer库有没有java的_自然语言处理_111令矩阵transformer库有没有java的_Self_112此时则有transformer库有没有java的_自然语言处理_113对于每个transformer库有没有java的_transformer_74利用得到对于transformer库有没有java的_transformer_28向量和transformer库有没有java的_深度学习_30向量计算对应的注意力得分,其中注意力向量transformer库有没有java的_transformer_117的第transformer库有没有java的_深度学习_50个分量的计算公式为transformer库有没有java的_Google_119令注意力分数矩阵transformer库有没有java的_深度学习_120transformer库有没有java的_Self_121,则有transformer库有没有java的_深度学习_122注意分数向量transformer库有没有java的_transformer_117经过transformer库有没有java的_自然语言处理_58层得到归一化后的注意力分布transformer库有没有java的_深度学习_125,即为transformer库有没有java的_transformer_126对于每一个transformer库有没有java的_transformer_74利用得到的注意力分布向量transformer库有没有java的_深度学习_125transformer库有没有java的_Google_32矩阵transformer库有没有java的_Self_130获得最后的输出transformer库有没有java的_Google_131,则有transformer库有没有java的_自然语言处理_132两个transformer库有没有java的_transformer_74transformer库有没有java的_自然语言处理_134的向量按照如下方式拼接在一起,则有transformer库有没有java的_自然语言处理_135给定参数矩阵transformer库有没有java的_Google_136,则输出矩阵为transformer库有没有java的_Google_137综上所述则有transformer库有没有java的_深度学习_138

Mask Self-Attention

如下图左半部分所示,transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11的输出向量transformer库有没有java的_自然语言处理_141综合了输入向量transformer库有没有java的_深度学习_142的全部信息,由此可见,transformer库有没有java的_自然语言处理_10-transformer库有没有java的_Google_11在实际编程中支持并行运算。如下图右半部分所示,transformer库有没有java的_transformer_145-transformer库有没有java的_Google_11的输出向量transformer库有没有java的_Self_147只利用了已知部分输入的向量transformer库有没有java的_transformer_148的信息。例如,transformer库有没有java的_Google_149只是与transformer库有没有java的_transformer_150有关;transformer库有没有java的_自然语言处理_151transformer库有没有java的_transformer_150transformer库有没有java的_深度学习_153有关;transformer库有没有java的_transformer_154transformer库有没有java的_transformer_150transformer库有没有java的_深度学习_153transformer库有没有java的_transformer_157有关;transformer库有没有java的_深度学习_158transformer库有没有java的_transformer_150transformer库有没有java的_深度学习_153transformer库有没有java的_transformer_157transformer库有没有java的_深度学习_162有关。transformer库有没有java的_transformer_145-transformer库有没有java的_Google_11transformer库有没有java的_transformer中被用到过两次。

  • transformer库有没有java的_Google_166transformer库有没有java的_Self_167中如果输入一句话的transformer库有没有java的_transformer_168长度小于指定的长度,为了能够让长度一致往往会用transformer库有没有java的_自然语言处理_169进行填充,此时则需要用transformer库有没有java的_Google_170-transformer库有没有java的_Self_171来计算注意力分布。
  • transformer库有没有java的_Google_166transformer库有没有java的_transformer_173的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到transformer库有没有java的_Google_170-transformer库有没有java的_Self_171

transformer库有没有java的_深度学习_176

Transformer模型

以上对transformer库有没有java的_transformer中的核心内容即自注意力机制进行了详细解剖,接下来会对transformer库有没有java的_transformer模型架构进行介绍。transformer库有没有java的_transformer模型是由transformer库有没有java的_Google_180transformer库有没有java的_深度学习_181两个模块组成,具体的示意图如下所示,为了能够对transformer库有没有java的_transformer内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对transformer库有没有java的_transformer的原理进行讲解。
transformer库有没有java的_Self_167模块操作的具体流程如下所示:

  • transformer库有没有java的_Self_167的输入由两部分组成分别是词编码矩阵transformer库有没有java的_深度学习_186和位置编码矩阵transformer库有没有java的_深度学习_187,其中transformer库有没有java的_transformer_188表示句子数目,transformer库有没有java的_Google_189表示一句话单词的最大数目,transformer库有没有java的_深度学习_190表示的是词向量的维度。位置编码矩阵transformer库有没有java的_Google_191表示的是每个单词在一句里的所有位置信息,因为transformer库有没有java的_Google_192-transformer库有没有java的_Self_171计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵transformer库有没有java的_Google_191。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个transformer库有没有java的_Google_195-transformer库有没有java的_Self_196的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示transformer库有没有java的_深度学习_197其中transformer库有没有java的_深度学习_198表示的是位置编码向量,transformer库有没有java的_Google_199表示词在句子中的位置,transformer库有没有java的_Self_200表示编码向量的位置索引。
  • 输入矩阵transformer库有没有java的_Google_201通过线性变换生成矩阵transformer库有没有java的_Google_202transformer库有没有java的_Google_203transformer库有没有java的_深度学习_204。在实际编程中是将输入transformer库有没有java的_Google_201直接赋值给transformer库有没有java的_Google_202transformer库有没有java的_Google_203transformer库有没有java的_深度学习_204。如果输入单词长度小于最大长度并transformer库有没有java的_自然语言处理_169来填充的时候,还要相应引入transformer库有没有java的_Self_210矩阵。
  • 将矩阵transformer库有没有java的_Google_202transformer库有没有java的_Google_203transformer库有没有java的_深度学习_204输入到transformer库有没有java的_深度学习_214-transformer库有没有java的_自然语言处理_215模块中进行注意分布的计算得到矩阵transformer库有没有java的_Self_216,计算公式为transformer库有没有java的_transformer_217具体的计算细节参考上文关于transformer库有没有java的_深度学习_214-transformer库有没有java的_自然语言处理_215原理的讲解不在这里赘述。然后将原始输入transformer库有没有java的_Google_201与注意力分布transformer库有没有java的_transformer_221进行残差计算得到输出矩阵transformer库有没有java的_Google_222
  • 对矩阵transformer库有没有java的_Self_223进行层归一化操作得到transformer库有没有java的_Google_224,具体的计算公式为transformer库有没有java的_深度学习_225
  • transformer库有没有java的_transformer_226输入到全连接神经网络中得到transformer库有没有java的_深度学习_227 ,然后再让全连接神经网络的输入transformer库有没有java的_transformer_226与输出transformer库有没有java的_Self_229进行残差计算得到transformer库有没有java的_深度学习_230,接着对transformer库有没有java的_深度学习_230进行层归一化操作。
  • 以上是一个transformer库有没有java的_Self_232的操作原理,将transformer库有没有java的_transformer_233transformer库有没有java的_Self_232进行堆叠就组成了transformer库有没有java的_Self_167的模块,得到的最后输出为transformer库有没有java的_Google_236。这里需要注意的是transformer库有没有java的_Self_167模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。

transformer库有没有java的_深度学习_181模块操作的具体流程如下所示:

  • transformer库有没有java的_transformer_173的输入也由两部分组成分别是词编码矩阵transformer库有没有java的_深度学习_240和位置编码矩阵transformer库有没有java的_深度学习_241。因为transformer库有没有java的_transformer_173的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入transformer库有没有java的_Self_210矩阵transformer库有没有java的_Self_244以便计算注意力分布。
  • 输入矩阵transformer库有没有java的_Google_245通过线性变换生成矩阵transformer库有没有java的_Self_246transformer库有没有java的_transformer_247transformer库有没有java的_自然语言处理_248。在实际编程中是将输入transformer库有没有java的_Google_245直接赋值给transformer库有没有java的_Self_246transformer库有没有java的_transformer_247transformer库有没有java的_自然语言处理_248。如果输入单词长度小于最大长度并transformer库有没有java的_自然语言处理_169来填充的时候,还要相应引入transformer库有没有java的_Self_210矩阵。
  • 将矩阵transformer库有没有java的_Self_246transformer库有没有java的_transformer_247transformer库有没有java的_自然语言处理_248以及transformer库有没有java的_Self_210矩阵transformer库有没有java的_Self_244输入到transformer库有没有java的_Self_260-transformer库有没有java的_自然语言处理_215模块中进行注意分布的计算得到矩阵transformer库有没有java的_Google_262,计算公式为transformer库有没有java的_自然语言处理_263具体的计算细节参考上文关于transformer库有没有java的_深度学习_264-transformer库有没有java的_Self_171的讲解不在这里赘述。然后将原始输入transformer库有没有java的_Google_245与注意力分布transformer库有没有java的_transformer_267进行残差计算得到输出矩阵transformer库有没有java的_transformer_268。接着再对矩阵transformer库有没有java的_深度学习_269进行层归一化操作得到transformer库有没有java的_自然语言处理_270
  • transformer库有没有java的_Self_167的输出transformer库有没有java的_深度学习_272通过线性变换得到transformer库有没有java的_Google_273transformer库有没有java的_Google_274transformer库有没有java的_transformer_267进行线性变换得到transformer库有没有java的_Google_276,利用矩阵transformer库有没有java的_Google_273transformer库有没有java的_Google_274transformer库有没有java的_Google_276进行交叉注意力分布的计算得到transformer库有没有java的_Google_280,计算公式为transformer库有没有java的_自然语言处理_281这里的交叉注意力分布综合transformer库有没有java的_Self_167输出结果和transformer库有没有java的_transformer_173中间结果的信息。实际编程编程中将transformer库有没有java的_深度学习_272直接赋值给transformer库有没有java的_Self_246transformer库有没有java的_transformer_247transformer库有没有java的_transformer_267直接赋值给transformer库有没有java的_Google_276。然后将transformer库有没有java的_Self_289与注意力分布transformer库有没有java的_Google_280进行残差计算得到输出矩阵transformer库有没有java的_自然语言处理_291
  • 接着对transformer库有没有java的_自然语言处理_291进行层归一操作得到transformer库有没有java的_Google_293,再将transformer库有没有java的_Google_293输入到全连接神经网络中得到transformer库有没有java的_自然语言处理_295,接着再做一步残差操作得到transformer库有没有java的_Google_296,最后再进行一层归一化操作。
  • 以上是一个transformer库有没有java的_Self_232的操作原理,将transformer库有没有java的_transformer_233transformer库有没有java的_Self_232进行堆叠就组成了transformer库有没有java的_transformer_173的模块,得到的输出为transformer库有没有java的_Self_301。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“transformer库有没有java的_transformer_302”字符为止。

transformer库有没有java的_深度学习_303

代码示例

transformer库有没有java的_transformer具体的代码示例如下所示为一个国外博主视频里的代码,并根据上文对代码的一些细节进行了探讨。根据上文中transformer库有没有java的_Self_69-transformer库有没有java的_transformer_70原理示例图可知,严格来看transformer库有没有java的_Self_69-transformer库有没有java的_transformer_70在求注意分布的时候中间其实是有两步线性变换。给定输入向量transformer库有没有java的_深度学习_309 第一步线性变换直接让向量transformer库有没有java的_深度学习_310赋值给transformer库有没有java的_transformer_311transformer库有没有java的_自然语言处理_312transformer库有没有java的_深度学习_313,这一过程以下程序中有所体现,在这里并不会产生歧义。第二步线性变换产生多transformer库有没有java的_transformer_74,假设transformer库有没有java的_Self_315的时候,按理说transformer库有没有java的_transformer_311要与transformer库有没有java的_自然语言处理_79个矩阵transformer库有没有java的_自然语言处理_318进行线性变换得到transformer库有没有java的_自然语言处理_79transformer库有没有java的_Google_320,同理transformer库有没有java的_自然语言处理_312要与transformer库有没有java的_自然语言处理_79个矩阵transformer库有没有java的_Google_323进行线性变换得到transformer库有没有java的_自然语言处理_79transformer库有没有java的_transformer_325transformer库有没有java的_深度学习_313要与transformer库有没有java的_自然语言处理_79个矩阵transformer库有没有java的_自然语言处理_328进行线性变换得到transformer库有没有java的_自然语言处理_79transformer库有没有java的_深度学习_330,如果按照这个方式在程序实现则需要定义24个权重矩阵,非常的麻烦。以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多transformer库有没有java的_transformer_74的线性变换,以向量transformer库有没有java的_Self_332为例:

  • 首先将向量transformer库有没有java的_Self_333进行截断分成transformer库有没有java的_深度学习_334个向量,即为transformer库有没有java的_深度学习_335其中transformer库有没有java的_深度学习_336transformer库有没有java的_Self_333的第transformer库有没有java的_Self_200个截断向量,transformer库有没有java的_Google_339是单位矩阵,transformer库有没有java的_transformer_340是零矩阵。
  • 然后对transformer库有没有java的_Self_341用相同的权重矩阵transformer库有没有java的_深度学习_342进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵transformer库有没有java的_自然语言处理_343即可,而且可以进行多transformer库有没有java的_深度学习_344线性变换,transformer库有没有java的_transformer_345个权重矩阵可以表示为:transformer库有没有java的_Self_346其中权重矩阵transformer库有没有java的_Google_347
import torch
import torch.nn as nn
import os

class SelfAttention(nn.Module):
	def __init__(self, embed_size, heads):
		super(SelfAttention, self).__init__()
		self.embed_size = embed_size
		self.heads = heads
		self.head_dim = embed_size // heads

		assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

		self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

	def forward(self, values, keys, query, mask):
		N =query.shape[0]
		value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

		# split embedding into self.heads pieces
		values = values.reshape(N, value_len, self.heads, self.head_dim)
		keys = keys.reshape(N, key_len, self.heads, self.head_dim)
		queries = query.reshape(N, query_len, self.heads, self.head_dim)
		
		values = self.values(values)
		keys = self.keys(keys)
		queries = self.queries(queries)

		energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
		# queries shape: (N, query_len, heads, heads_dim)
		# keys shape : (N, key_len, heads, heads_dim)
		# energy shape: (N, heads, query_len, key_len)

		if mask is not None:
			energy = energy.masked_fill(mask == 0, float("-1e20"))

		attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

		out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
		# attention shape: (N, heads, query_len, key_len)
		# values shape: (N, value_len, heads, heads_dim)
		# (N, query_len, heads, head_dim)

		out = self.fc_out(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, embed_size, heads, dropout, forward_expansion):
		super(TransformerBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm1 = nn.LayerNorm(embed_size)
		self.norm2 = nn.LayerNorm(embed_size)

		self.feed_forward = nn.Sequential(
			nn.Linear(embed_size, forward_expansion*embed_size),
			nn.ReLU(),
			nn.Linear(forward_expansion*embed_size, embed_size)
		)
		self.dropout = nn.Dropout(dropout)

	def forward(self, value, key, query, mask):
		attention = self.attention(value, key, query, mask)

		x = self.dropout(self.norm1(attention + query))
		forward = self.feed_forward(x)
		out = self.dropout(self.norm2(forward + x))
		return out


class Encoder(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length,
		):
		super(Encoder, self).__init__()
		self.embed_size = embed_size
		self.device = device
		self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)

		self.layers = nn.ModuleList(
			[
				TransformerBlock(
					embed_size,
					heads,
					dropout=dropout,
					forward_expansion=forward_expansion,
					)
				for _ in range(num_layers)]
		)
		self.dropout = nn.Dropout(dropout)


	def forward(self, x, mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
		for layer in self.layers:
			out = layer(out, out, out, mask)

		return out


class DecoderBlock(nn.Module):
	def __init__(self, embed_size, heads, forward_expansion, dropout, device):
		super(DecoderBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm = nn.LayerNorm(embed_size)
		self.transformer_block = TransformerBlock(
			embed_size, heads, dropout, forward_expansion
		)

		self.dropout = nn.Dropout(dropout)

	def forward(self, x, value, key, src_mask, trg_mask):
		attention = self.attention(x, x, x, trg_mask)
		query = self.dropout(self.norm(attention + x))
		out = self.transformer_block(value, key, query, src_mask)
		return out

class Decoder(nn.Module):
	def __init__(
			self,
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length,
	):
		super(Decoder, self).__init__()
		self.device = device
		self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)
		self.layers = nn.ModuleList(
			[DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
			for _ in range(num_layers)]
			)
		self.fc_out = nn.Linear(embed_size, trg_vocab_size)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x ,enc_out , src_mask, trg_mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

		for layer in self.layers:
			x = layer(x, enc_out, enc_out, src_mask, trg_mask)

		out =self.fc_out(x)
		return out


class Transformer(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			trg_vocab_size,
			src_pad_idx,
			trg_pad_idx,
			embed_size = 256,
			num_layers = 6,
			forward_expansion = 4,
			heads = 8,
			dropout = 0,
			device="cuda",
			max_length=100
		):
		super(Transformer, self).__init__()
		self.encoder = Encoder(
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length
			)
		self.decoder = Decoder(
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length
			)


		self.src_pad_idx = src_pad_idx
		self.trg_pad_idx = trg_pad_idx
		self.device = device


	def make_src_mask(self, src):
		src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
		# (N, 1, 1, src_len)
		return src_mask.to(self.device)

	def make_trg_mask(self, trg):
		N, trg_len = trg.shape
		trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
			N, 1, trg_len, trg_len
		)
		return trg_mask.to(self.device)

	def forward(self, src, trg):
		src_mask = self.make_src_mask(src)
		trg_mask = self.make_trg_mask(trg)
		enc_src = self.encoder(src, src_mask)
		out = self.decoder(trg, enc_src, src_mask, trg_mask)
		return out


if __name__ == '__main__':
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	print(device)
	x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)
	trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)

	src_pad_idx = 0
	trg_pad_idx = 0
	src_vocab_size = 10
	trg_vocab_size = 10
	model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
	out = model(x, trg[:, : -1])
	print(out.shape)