1、位置编码的意义

对于序列数据,目前存在三种主流的建模方式:卷积操作、循环操作和自注意力。其中,卷积和循环操作都具有局部性,即只作用目标元素的若干邻居上,而自注意力则是一种全局操作。具有局部性的操作,可以天然地注意到了元素间的相对位置;而注意力机制则是位置不敏感的·,即使调换序列中两个元素的位置对编码后的结果也不会产生影响。

embedding后进行位置编码实现代码 transformer中位置编码_代码实现

因此,有必要将元素对应的位置信息添加到表示中,或者在计算注意力得分时考虑两个元素之间的相对位置。这些方法统称为位置编码,可以分为绝对位置编码和相对位置编码。

2、绝对位置编码

最为经典的位置编码莫过于 BERT[1] 模型所使用的,即直接将位置的表示加到token的表示上,而每个位置的表示则为一个可学习的向量。这种编码方式,据我所知最早是由ConvS2S[2]提出,被BERT、GPT2[3]、ERNIE[4]、ALBERT[5]、electra[6] 等模型所采用。

以上的位置编码被称为learnable绝对位置编码,存在着两个问题:(1) 位置编码本身通过大量数据才能学习到;(2) 位置向量之间的相对关系没有被利用到,如位置1和位置2之间的相似性应比位置1和位置9之间的相似性高。当然这些问题都可以通过大规模语料上的预训练来缓解。与learnable绝对位置编码相对的则是fixed绝对位置编码,以三角式位置编码[7]为代表。

embedding后进行位置编码实现代码 transformer中位置编码_深度学习_02

 

embedding后进行位置编码实现代码 transformer中位置编码_代码实现_03

 3、相对位置编码

绝对位置编码是将位置编码直接嵌入到序列的表示中,而相对位置编码则是指在计算注意力分数的时候,直接考虑两个token之间的相对位置,即

embedding后进行位置编码实现代码 transformer中位置编码_代码实现_04

 4、绝对位置 v.s. 相对位置

绝对位置编码具有实施简单、计算速度快的优点。而其缺点也是明显的,因为真正重要的往往不是绝对位置,而是token之间的相对位置。在下面三个句子中,东西的含义和东西与鱼的相对位置有关,而与东西本身的绝对位置无关。

有个东西在吃鱼
小明放眼望去,看到有个东西在吃鱼
有条鱼在吃东西

虽然三角式位置编码,作为一种绝对位置编码,包含了一定相对位置信息,但这种相对位置信息仅仅包含在位置编码内部。当添加位置编码的表示在计算自注意力的时候,表示中的相对位置信息是否仍然保留就是个未知数了。

此外,对于线性attention而言,相对位置编码无法直接得到应用。因此,沿着三角式位置编码的思路,进一步发展绝对位置编码是有必要的。

5、旋转式位置编码 RoPE

embedding后进行位置编码实现代码 transformer中位置编码_代码实现_05

 

embedding后进行位置编码实现代码 transformer中位置编码_深度学习_06

 

embedding后进行位置编码实现代码 transformer中位置编码_卷积_07

embedding后进行位置编码实现代码 transformer中位置编码_机器学习_08

 

 5、 代码及实验部分

绝对位置编码

可学习的绝对位置编码

class LearnableAbsolutePositionEmbedding(nn.Module):
    def __init__(self, max_position_embeddings, hidden_size):
        super().__init__()
        self.is_absolute = True
        self.embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.register_buffer('position_ids', torch.arange(max_position_embeddings))

    def forward(self, x):
        """
        return (b l d) / (b h l d)
        """
        position_ids = self.position_ids[:x.size(-2)]

        if x.dim() == 3:
            return x + self.embeddings(position_ids)[None, :, :]

        elif x.dim() == 4:
            h = x.size(1)
            x = rearrange(x, 'b h l d -> b l (h d)')
            x = x + self.embeddings(position_ids)[None, :, :]
            x = rearrange(x, 'b l (h d) -> b h l d', h=h)
            return x

三角式绝对位置编码

class FixedAbsolutePositionEmbedding(nn.Module):
    def __init__(self, max_position_embeddings, hidden_size, position_embedding_type):
        super().__init__()

        self.position_embedding_type = position_embedding_type
        self.is_absolute = True

        inv_freq = 1. / (10000 ** (torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size))
        position = torch.arange(max_position_embeddings, dtype=torch.float)
        sinusoid_inp = torch.einsum('i,j -> ij', position, inv_freq)
        embeddings = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        self.register_buffer('embeddings', embeddings)

    def forward_fixed(self, x):
        """
        return (b l d)
        """
        return x + self.embeddings[None, :x.size(1), :]

    def forward_rope(self, x):
        """
        return (b l d)
        """
        embeddings = self.embeddings[None, :x.size(1), :] # b l d
        embeddings = rearrange(embeddings, 'b l (j d) -> b l j d', j=2)
        sin, cos = embeddings.unbind(dim=-2) # b l d//2
        sin, cos = map(lambda t: repeat(t, '... d -> ... (d 2)'), (sin, cos)) # b l d
        return x * cos + self.rotate_every_two(x) * sin

    @staticmethod
    def rotate_every_two(x):
        x = rearrange(x, '... (d j) -> ... d j', j=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return rearrange(x, '... d j -> ... (d j)')

    def _forward(self, x):
        if self.position_embedding_type == 'fixed':
            return self.forward_fixed(x)

        elif self.position_embedding_type == 'rope':
            return self.forward_rope(x)

    def forward(self, x):
        if x.dim() == 3:
            return self._forward(x)

        elif x.dim() == 4:
            h = x.size(1)
            x = rearrange(x, 'b h l d -> (b h) l d')
            x = self._forward(x)
            x = rearrange(x, '(b h) l d -> b h l d', h=h)
            return x

相对位置编码

class RelativePositionEmbedding(nn.Module):
    def __init__(self, 
                 relative_attention_num_buckets, num_attention_heads, 
                 hidden_size, position_embedding_type):

        super().__init__()

        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.position_embedding_type = position_embedding_type
        self.num_attention_heads = num_attention_heads
        self.is_absolute = False

        if position_embedding_type == 'bias':
            self.embeddings = nn.Embedding(relative_attention_num_buckets, num_attention_heads)

        elif position_embedding_type == 'contextual(1)':
            self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
            self.to_r = nn.Linear(hidden_size, hidden_size, bias=False)

        elif position_embedding_type == 'contextual(2)':
            self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)

    def compute_bias(self, q, k, to_q=None, to_k=None):
        """
        q, k: [b h l d]
        return [b h l l]
        """
        h = self.num_attention_heads
        query_position = torch.arange(q.size(2), dtype=torch.long, device=self.embeddings.weight.device)[:, None]
        key_position   = torch.arange(k.size(2), dtype=torch.long, device=self.embeddings.weight.device)[None, :]

        relative_position = query_position - key_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            num_buckets=self.relative_attention_num_buckets
        )

        if self.position_embedding_type == 'bias':
            bias = self.embeddings(relative_position_bucket)
            bias = rearrange(bias, 'm n h -> 1 h m n')

        elif self.position_embedding_type == 'contextual(1)':
            r = self.embeddings(relative_position_bucket)
            r = self.to_r(r)
            r = rearrange(r, 'm n (h d) -> h m n d', h=h)

            bias = torch.einsum('b h m d, h m n d -> b h m n', q, r)

        elif self.position_embedding_type == 'contextual(2)':
            r = self.embeddings(relative_position_bucket)

            kr = to_k(r)
            qr = to_q(r)

            kr = rearrange(kr, 'm n (h d) -> h m n d', h=h)
            qr = rearrange(qr, 'm n (h d) -> h m n d', h=h)

            bias1 = torch.einsum('b h m d, h m n d -> b h m n', q, kr)
            bias2 = torch.einsum('b h n d, h m n d -> b h m n', k, qr)

            bias = bias1 + bias2

        return bias

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets, max_distance=128):
        """
        relative_position: [m n]
        """

        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)

        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

Embedding层

class Embedddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.embedding_size, config.hidden_size)
        
        if config.position_embedding_type == 'learnable':
            self.position_embeddings = LearnableAbsolutePositionEmbedding(
                max_position_embeddings=config.max_position_embeddings, 
                hidden_size=config.hidden_size
            )
        
        elif config.position_embedding_type in ('fixed', 'rope'):
            self.position_embeddings = FixedAbsolutePositionEmbedding(
                max_position_embeddings=config.max_position_embeddings,
                hidden_size=config.hidden_size,
                position_embedding_type=config.position_embedding_type
            )

    def forward(self, input_ids):
        embeds = self.word_embeddings(input_ids)
        embeds = self.dropout(embeds)
        embeds = self.dense(embeds)

        if hasattr(self, 'position_embeddings'):
            embeds = self.position_embeddings(embeds)

        return embeds

注意力

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.n_heads = config.num_attention_heads
        dim_heads = config.hidden_size // config.num_attention_heads

        self.to_q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.to_k = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.to_v = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.to_out  = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.dropout = StableDropout(config.hidden_dropout_prob)

        if config.encoder_layer == 'transformer':
            self.attn_fn = TransformerAttention(config)

        elif config.encoder_layer == 'performer':
            self.attn_fn = PerformerAttention(config)

        else:
            raise NotImplementedError

    def forward(self, x, mask, pos_emb):
        h = self.n_heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b l (h d) -> b h l d', h=h), (q, k, v))

        context = self.attn_fn(q, k, v, mask, pos_emb, to_q=self.to_q, to_k=self.to_k)
        out = self.to_out(context)
        out = self.dropout(out)

        return out

自注意力

class TransformerAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        attention_head_size = config.hidden_size // config.num_attention_heads
        self.scale = attention_head_size ** -0.5
        # self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.dropout = StableDropout(config.attention_probs_dropout_prob)

    def forward(self, q, k, v, mask, pos_emb, to_q, to_k):
        """
        q, k, v: [b h l d]
        mask: [b l]
        """
        if pos_emb is not None and pos_emb.is_absolute is True:
            q = pos_emb(q)
            k = pos_emb(k)

        dots = torch.einsum('b h m d, b h n d -> b h m n', q, k)

        if pos_emb is not None and pos_emb.is_absolute is False:
            bias = pos_emb.compute_bias(q, k, to_q, to_k)
            dots = dots + bias

        # assert mask is not None
        # if mask is not None:
        mask = mask[:, None, None, :] & mask[:, None, :, None]
        # dots = dots.masked_fill(~mask, -10000.)
        # probs = dots.softmax(dim=-1)
        probs = XSoftmax.apply(dots, mask, -1)

        probs = self.dropout(probs)

        context = torch.einsum('b h m n, b h n d -> b h m d', probs, v)
        context = rearrange(context, 'b h m d -> b m (h d)')

        return context

线性注意力

class PerformerAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        attention_head_size = config.hidden_size // config.num_attention_heads
        self.attn = FastAttention(dim_heads=attention_head_size, causal=False)

    def forward(self, q, k, v, mask, pos_emb, **kwargs):
        """
        q, k, v: [b h l d]
        mask: [b l]
        """
        if pos_emb is not None:
            assert pos_emb.is_absolute is True
            q = pos_emb(q)
            k = pos_emb(k)

        mask = mask[:, None, :, None]
        v = v.masked_fill(~mask, 0.)

        context = self.attn(q, k, v)
        context = rearrange(context, 'b h l d -> b l (h d)')
        return context

Transformer Encoder

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        dim_heads = config.hidden_size // config.num_attention_heads 

        if config.position_embedding_type == 'layerwise_learnable':
            self.position_embeddings = LearnableAbsolutePositionEmbedding(
                max_position_embeddings=config.max_position_embeddings, 
                hidden_size=config.hidden_size
            )
        
        elif config.position_embedding_type in ('layerwise_fixed', 'layerwise_rope'):
            self.position_embeddings = FixedAbsolutePositionEmbedding(
                max_position_embeddings=config.max_position_embeddings,
                hidden_size=dim_heads,
                position_embedding_type=config.position_embedding_type.split('_')[-1],
            )

        elif config.position_embedding_type in ('layerwise_bias', 'layerwise_contextual(1)', 'layerwise_contextual(2)'):
            self.position_embeddings = RelativePositionEmbedding( 
                 config.relative_attention_num_buckets, 
                 config.num_attention_heads, 
                 config.hidden_size, 
                 position_embedding_type=config.position_embedding_type.split('_')[-1]
            )

        else:
            self.position_embeddings = None

        self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, x, mask):
        for layer_module in self.layer:
            x = layer_module(x, mask, self.position_embeddings)

        return x

完整模型

class TDLMModel(TDLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embedddings(config)
        self.encoder  = Encoder(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask):
        """
        input_ids: [b, l]
        attention_mask: [b, l]
        """
        attention_mask = attention_mask.bool()

        x = self.embeddings(input_ids)
        x = self.encoder(x, attention_mask)

        return x

实验及结果

实验设置

  • 参数量 由于GPU资源的限制,实验中所使用的Transformer模型(29M)在参数量上,要比BERT-base(110M)小很多。
  • 训练语料 模型在英语维基百科语料(13G文本)上训练,batch_size通过梯度累计的方式设置为2048,一共训练了20K步,这相当于全部语料的1/3左右。
  • 评价指标 选择bpd(Bits Per Dimension)作为语言模型的评价指标,bpd=loss/ln(2)。
TDLMConfig {                                                                                                                             [0/1855]
  "attention_probs_dropout_prob": 0.1,
  "embedding_size": 128,
  "encoder_layer": "transformer",
  "glu": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 8,
  "num_hidden_layers": 8,
  "pad_token_id": 0,
  "position_embedding_type": "layerwise_rope",
  "pre_norm": true,
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.6.1",
  "vocab_size": 30522
}

首先,进行了简单的超参数调节,最后发现初始学习率设为2e-4比较好(使用了linear_schedule_with_warmup)。

embedding后进行位置编码实现代码 transformer中位置编码_代码实现_09

在embedding层应用绝对位置编码,如下图,可以发现RoPE优于三角式位置编码和可学习的位置编码,bqd最低为3.05。

embedding后进行位置编码实现代码 transformer中位置编码_卷积_10

将RoPE与相对位置编码进行比较时,可以发现contextual模式的相对位置编码还是优于RoPE的。但是相对于相对位置编码,RoPE仍然有以下优势

  • RoPE本身不会给模型引入额外的参数
  • RoPE是直接作用于q,k上,因而无需修改注意力的计算过程。进一步说,RoPE可以直接且方便的作用在Transformer变体上,如Performer[14]、Reformer[15]等

 

Pre-Norm v.s. Post-Norm

此外,这里对比了transformer模型中的pre-norm和post-norm。如下图所示,对于语言模型而言,post-norm还是更好。也许pre-norm只适合CV领域上的任务吧。