多头注意力机制

  • 介绍
  • 代码实现
  • 使用pytorch函数


介绍

多头自注意力机制是自注意力机制(Self-Attention)的一种扩展形式,它通过将输入数据分为多个头(Head),并对每个头进行自注意力计算,最后将多个头的结果拼接起来,得到最终的输出。使用多头自注意力可以使得模型在处理长序列数据时更加有效。

代码实现

多头注意力机制(Multi-Head Attention)的源码实现可以分为以下几个步骤:

  1. 将输入张量 x 通过查询、键、值的线性变换得到三个张量 qkv,具体来说,就是分别通过三个 nn.Linear 层对输入张量 x 进行线性变换。
  2. qkv 分别通过头的数量 n_heads 个线性变换,得到 qkv 的多个头,具体来说,就是将每个张量分成 n_heads 个子张量,然后对每个子张量分别进行线性变换。
  3. 对每个头分别计算注意力得分,并将注意力得分通过 softmax 函数得到注意力权重。
  4. 对每个头的值 v,按照注意力权重进行加权平均,并将多个头的结果拼接起来得到最终的输出。

下面是一个简单的 PyTorch 实现:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        
        # Step 1: Linear transformations
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        # Step 2: Split heads
        q_heads = q.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
        k_heads = k.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
        v_heads = v.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
        
        # Step 3: Calculate attention scores
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / (embed_dim // self.n_heads) ** 0.5
        attn_weights = nn.functional.softmax(scores, dim=-1)
        
        # Step 4: Weighted average
        weighted_v = torch.matmul(attn_weights, v_heads)
        weighted_v = weighted_v.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        out = self.out_linear(weighted_v)
        
        return out, attn_weights

在这个实现中,我们首先定义了一个 MultiHeadAttention 类,包含了查询、键、值的线性变换层以及输出的线性变换层。在 forward 方法中,我们首先对输入张量 x 进行查询、键、值的线性变换,并将得到的张量分别分成多个头。然后,我们计算头之间的注意力得分,并将得分通过 softmax 函数得到注意力权重。最后,我们按照注意力权重对每个头的值进行加权平均,并将多个头的结果拼接起来得到最终的输出。

需要注意的是,这个实现中并没有考虑批次大小(batch size)的影响,如果要处理多个样本,需要对输入张量的维度进行适当调整。同时,这个实现中使用了 PyTorch 中的矩阵乘法和 softmax 函数,可以通过 GPU 进行加速。

使用pytorch函数

nn.MultiheadAttention 是 PyTorch 中的一个类,它实现了多头自注意力机制(Multi-Head Attention),是一种常用的深度学习模型组件。

在 PyTorch 中,nn.MultiheadAttention 的使用方法比较简单,需要指定输入数据的维度、每个头的维度、头的个数等参数。下面是一个简单的使用示例:

import torch
import torch.nn as nn

batch_size = 10
seq_len = 20
embed_dim = 128
num_heads = 8

input = torch.randn(seq_len, batch_size, embed_dim)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

output, attn_weights = multihead_attn(input, input, input)

在这个例子中,我们首先定义了一个输入张量 input,大小为 (seq_len, batch_size, embed_dim),其中 seq_len 表示序列长度,batch_size 表示批次大小,embed_dim 表示每个单词的嵌入维度。然后我们定义了一个 nn.MultiheadAttention 对象 multihead_attn,其中输入维度为 embed_dim,每个头的维度为 embed_dim // num_heads,头的个数为 num_heads。最后我们对输入张量进行 multihead_attn 的前向计算,并输出计算结果和注意力权重。

需要注意的是,nn.MultiheadAttention 中的注意力权重是相对于输入张量的,因此输入张量的维度需要满足一定的要求。具体来说,输入张量的维度应该为 (seq_len, batch_size, embed_dim),其中 seq_len 表示序列长度,batch_size 表示批次大小,embed_dim 应该是头的维度 embed_dim // num_heads 的整数倍。如果输入张量的维度不满足要求,可以使用 nn.Linear 等层进行维度转换。