1、位置编码的意义
对于序列数据,目前存在三种主流的建模方式:卷积操作、循环操作和自注意力。其中,卷积和循环操作都具有局部性,即只作用目标元素的若干邻居上,而自注意力则是一种全局操作。具有局部性的操作,可以天然地注意到了元素间的相对位置;而注意力机制则是位置不敏感的·,即使调换序列中两个元素的位置对编码后的结果也不会产生影响。
因此,有必要将元素对应的位置信息添加到表示中,或者在计算注意力得分时考虑两个元素之间的相对位置。这些方法统称为位置编码,可以分为绝对位置编码和相对位置编码。
2、绝对位置编码
最为经典的位置编码莫过于 BERT[1] 模型所使用的,即直接将位置的表示加到token的表示上,而每个位置的表示则为一个可学习的向量。这种编码方式,据我所知最早是由ConvS2S[2]提出,被BERT、GPT2[3]、ERNIE[4]、ALBERT[5]、electra[6] 等模型所采用。
以上的位置编码被称为learnable绝对位置编码,存在着两个问题:(1) 位置编码本身通过大量数据才能学习到;(2) 位置向量之间的相对关系没有被利用到,如位置1和位置2之间的相似性应比位置1和位置9之间的相似性高。当然这些问题都可以通过大规模语料上的预训练来缓解。与learnable绝对位置编码相对的则是fixed绝对位置编码,以三角式位置编码[7]为代表。
3、相对位置编码
绝对位置编码是将位置编码直接嵌入到序列的表示中,而相对位置编码则是指在计算注意力分数的时候,直接考虑两个token之间的相对位置,即
4、绝对位置 v.s. 相对位置
绝对位置编码具有实施简单、计算速度快的优点。而其缺点也是明显的,因为真正重要的往往不是绝对位置,而是token之间的相对位置。在下面三个句子中,东西的含义和东西与鱼的相对位置有关,而与东西本身的绝对位置无关。
有个东西在吃鱼
小明放眼望去,看到有个东西在吃鱼
有条鱼在吃东西
虽然三角式位置编码,作为一种绝对位置编码,包含了一定相对位置信息,但这种相对位置信息仅仅包含在位置编码内部。当添加位置编码的表示在计算自注意力的时候,表示中的相对位置信息是否仍然保留就是个未知数了。
此外,对于线性attention而言,相对位置编码无法直接得到应用。因此,沿着三角式位置编码的思路,进一步发展绝对位置编码是有必要的。
5、旋转式位置编码 RoPE
5、 代码及实验部分
绝对位置编码
可学习的绝对位置编码
class LearnableAbsolutePositionEmbedding(nn.Module):
def __init__(self, max_position_embeddings, hidden_size):
super().__init__()
self.is_absolute = True
self.embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.register_buffer('position_ids', torch.arange(max_position_embeddings))
def forward(self, x):
"""
return (b l d) / (b h l d)
"""
position_ids = self.position_ids[:x.size(-2)]
if x.dim() == 3:
return x + self.embeddings(position_ids)[None, :, :]
elif x.dim() == 4:
h = x.size(1)
x = rearrange(x, 'b h l d -> b l (h d)')
x = x + self.embeddings(position_ids)[None, :, :]
x = rearrange(x, 'b l (h d) -> b h l d', h=h)
return x
三角式绝对位置编码
class FixedAbsolutePositionEmbedding(nn.Module):
def __init__(self, max_position_embeddings, hidden_size, position_embedding_type):
super().__init__()
self.position_embedding_type = position_embedding_type
self.is_absolute = True
inv_freq = 1. / (10000 ** (torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size))
position = torch.arange(max_position_embeddings, dtype=torch.float)
sinusoid_inp = torch.einsum('i,j -> ij', position, inv_freq)
embeddings = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
self.register_buffer('embeddings', embeddings)
def forward_fixed(self, x):
"""
return (b l d)
"""
return x + self.embeddings[None, :x.size(1), :]
def forward_rope(self, x):
"""
return (b l d)
"""
embeddings = self.embeddings[None, :x.size(1), :] # b l d
embeddings = rearrange(embeddings, 'b l (j d) -> b l j d', j=2)
sin, cos = embeddings.unbind(dim=-2) # b l d//2
sin, cos = map(lambda t: repeat(t, '... d -> ... (d 2)'), (sin, cos)) # b l d
return x * cos + self.rotate_every_two(x) * sin
@staticmethod
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def _forward(self, x):
if self.position_embedding_type == 'fixed':
return self.forward_fixed(x)
elif self.position_embedding_type == 'rope':
return self.forward_rope(x)
def forward(self, x):
if x.dim() == 3:
return self._forward(x)
elif x.dim() == 4:
h = x.size(1)
x = rearrange(x, 'b h l d -> (b h) l d')
x = self._forward(x)
x = rearrange(x, '(b h) l d -> b h l d', h=h)
return x
相对位置编码
class RelativePositionEmbedding(nn.Module):
def __init__(self,
relative_attention_num_buckets, num_attention_heads,
hidden_size, position_embedding_type):
super().__init__()
self.relative_attention_num_buckets = relative_attention_num_buckets
self.position_embedding_type = position_embedding_type
self.num_attention_heads = num_attention_heads
self.is_absolute = False
if position_embedding_type == 'bias':
self.embeddings = nn.Embedding(relative_attention_num_buckets, num_attention_heads)
elif position_embedding_type == 'contextual(1)':
self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
self.to_r = nn.Linear(hidden_size, hidden_size, bias=False)
elif position_embedding_type == 'contextual(2)':
self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
def compute_bias(self, q, k, to_q=None, to_k=None):
"""
q, k: [b h l d]
return [b h l l]
"""
h = self.num_attention_heads
query_position = torch.arange(q.size(2), dtype=torch.long, device=self.embeddings.weight.device)[:, None]
key_position = torch.arange(k.size(2), dtype=torch.long, device=self.embeddings.weight.device)[None, :]
relative_position = query_position - key_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
num_buckets=self.relative_attention_num_buckets
)
if self.position_embedding_type == 'bias':
bias = self.embeddings(relative_position_bucket)
bias = rearrange(bias, 'm n h -> 1 h m n')
elif self.position_embedding_type == 'contextual(1)':
r = self.embeddings(relative_position_bucket)
r = self.to_r(r)
r = rearrange(r, 'm n (h d) -> h m n d', h=h)
bias = torch.einsum('b h m d, h m n d -> b h m n', q, r)
elif self.position_embedding_type == 'contextual(2)':
r = self.embeddings(relative_position_bucket)
kr = to_k(r)
qr = to_q(r)
kr = rearrange(kr, 'm n (h d) -> h m n d', h=h)
qr = rearrange(qr, 'm n (h d) -> h m n d', h=h)
bias1 = torch.einsum('b h m d, h m n d -> b h m n', q, kr)
bias2 = torch.einsum('b h n d, h m n d -> b h m n', k, qr)
bias = bias1 + bias2
return bias
@staticmethod
def _relative_position_bucket(relative_position, num_buckets, max_distance=128):
"""
relative_position: [m n]
"""
num_buckets //= 2
relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
Embedding层
class Embedddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = StableDropout(config.hidden_dropout_prob)
self.dense = nn.Linear(config.embedding_size, config.hidden_size)
if config.position_embedding_type == 'learnable':
self.position_embeddings = LearnableAbsolutePositionEmbedding(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size
)
elif config.position_embedding_type in ('fixed', 'rope'):
self.position_embeddings = FixedAbsolutePositionEmbedding(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
position_embedding_type=config.position_embedding_type
)
def forward(self, input_ids):
embeds = self.word_embeddings(input_ids)
embeds = self.dropout(embeds)
embeds = self.dense(embeds)
if hasattr(self, 'position_embeddings'):
embeds = self.position_embeddings(embeds)
return embeds
注意力
class Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.n_heads = config.num_attention_heads
dim_heads = config.hidden_size // config.num_attention_heads
self.to_q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.to_k = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.to_v = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.to_out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.dropout = StableDropout(config.hidden_dropout_prob)
if config.encoder_layer == 'transformer':
self.attn_fn = TransformerAttention(config)
elif config.encoder_layer == 'performer':
self.attn_fn = PerformerAttention(config)
else:
raise NotImplementedError
def forward(self, x, mask, pos_emb):
h = self.n_heads
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q, k, v = map(lambda t: rearrange(t, 'b l (h d) -> b h l d', h=h), (q, k, v))
context = self.attn_fn(q, k, v, mask, pos_emb, to_q=self.to_q, to_k=self.to_k)
out = self.to_out(context)
out = self.dropout(out)
return out
自注意力
class TransformerAttention(nn.Module):
def __init__(self, config):
super().__init__()
attention_head_size = config.hidden_size // config.num_attention_heads
self.scale = attention_head_size ** -0.5
# self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.dropout = StableDropout(config.attention_probs_dropout_prob)
def forward(self, q, k, v, mask, pos_emb, to_q, to_k):
"""
q, k, v: [b h l d]
mask: [b l]
"""
if pos_emb is not None and pos_emb.is_absolute is True:
q = pos_emb(q)
k = pos_emb(k)
dots = torch.einsum('b h m d, b h n d -> b h m n', q, k)
if pos_emb is not None and pos_emb.is_absolute is False:
bias = pos_emb.compute_bias(q, k, to_q, to_k)
dots = dots + bias
# assert mask is not None
# if mask is not None:
mask = mask[:, None, None, :] & mask[:, None, :, None]
# dots = dots.masked_fill(~mask, -10000.)
# probs = dots.softmax(dim=-1)
probs = XSoftmax.apply(dots, mask, -1)
probs = self.dropout(probs)
context = torch.einsum('b h m n, b h n d -> b h m d', probs, v)
context = rearrange(context, 'b h m d -> b m (h d)')
return context
线性注意力
class PerformerAttention(nn.Module):
def __init__(self, config):
super().__init__()
attention_head_size = config.hidden_size // config.num_attention_heads
self.attn = FastAttention(dim_heads=attention_head_size, causal=False)
def forward(self, q, k, v, mask, pos_emb, **kwargs):
"""
q, k, v: [b h l d]
mask: [b l]
"""
if pos_emb is not None:
assert pos_emb.is_absolute is True
q = pos_emb(q)
k = pos_emb(k)
mask = mask[:, None, :, None]
v = v.masked_fill(~mask, 0.)
context = self.attn(q, k, v)
context = rearrange(context, 'b h l d -> b l (h d)')
return context
Transformer Encoder
class Encoder(nn.Module):
def __init__(self, config):
super().__init__()
dim_heads = config.hidden_size // config.num_attention_heads
if config.position_embedding_type == 'layerwise_learnable':
self.position_embeddings = LearnableAbsolutePositionEmbedding(
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size
)
elif config.position_embedding_type in ('layerwise_fixed', 'layerwise_rope'):
self.position_embeddings = FixedAbsolutePositionEmbedding(
max_position_embeddings=config.max_position_embeddings,
hidden_size=dim_heads,
position_embedding_type=config.position_embedding_type.split('_')[-1],
)
elif config.position_embedding_type in ('layerwise_bias', 'layerwise_contextual(1)', 'layerwise_contextual(2)'):
self.position_embeddings = RelativePositionEmbedding(
config.relative_attention_num_buckets,
config.num_attention_heads,
config.hidden_size,
position_embedding_type=config.position_embedding_type.split('_')[-1]
)
else:
self.position_embeddings = None
self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, x, mask):
for layer_module in self.layer:
x = layer_module(x, mask, self.position_embeddings)
return x
完整模型
class TDLMModel(TDLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = Embedddings(config)
self.encoder = Encoder(config)
self.init_weights()
def forward(self, input_ids, attention_mask):
"""
input_ids: [b, l]
attention_mask: [b, l]
"""
attention_mask = attention_mask.bool()
x = self.embeddings(input_ids)
x = self.encoder(x, attention_mask)
return x
实验及结果
实验设置
- 参数量 由于GPU资源的限制,实验中所使用的Transformer模型(29M)在参数量上,要比BERT-base(110M)小很多。
- 训练语料 模型在英语维基百科语料(13G文本)上训练,batch_size通过梯度累计的方式设置为2048,一共训练了20K步,这相当于全部语料的1/3左右。
- 评价指标 选择bpd(Bits Per Dimension)作为语言模型的评价指标,bpd=loss/ln(2)。
TDLMConfig { [0/1855]
"attention_probs_dropout_prob": 0.1,
"embedding_size": 128,
"encoder_layer": "transformer",
"glu": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 512,
"initializer_range": 0.02,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"num_attention_heads": 8,
"num_hidden_layers": 8,
"pad_token_id": 0,
"position_embedding_type": "layerwise_rope",
"pre_norm": true,
"relative_attention_num_buckets": 32,
"transformers_version": "4.6.1",
"vocab_size": 30522
}
首先,进行了简单的超参数调节,最后发现初始学习率设为2e-4比较好(使用了linear_schedule_with_warmup)。
在embedding层应用绝对位置编码,如下图,可以发现RoPE优于三角式位置编码和可学习的位置编码,bqd最低为3.05。
将RoPE与相对位置编码进行比较时,可以发现contextual模式的相对位置编码还是优于RoPE的。但是相对于相对位置编码,RoPE仍然有以下优势
- RoPE本身不会给模型引入额外的参数
- RoPE是直接作用于q,k上,因而无需修改注意力的计算过程。进一步说,RoPE可以直接且方便的作用在Transformer变体上,如Performer[14]、Reformer[15]等
Pre-Norm v.s. Post-Norm
此外,这里对比了transformer模型中的pre-norm和post-norm。如下图所示,对于语言模型而言,post-norm还是更好。也许pre-norm只适合CV领域上的任务吧。