多头注意力机制
- 介绍
- 代码实现
- 使用pytorch函数
介绍
多头自注意力机制是自注意力机制(Self-Attention)的一种扩展形式,它通过将输入数据分为多个头(Head),并对每个头进行自注意力计算,最后将多个头的结果拼接起来,得到最终的输出。使用多头自注意力可以使得模型在处理长序列数据时更加有效。
代码实现
多头注意力机制(Multi-Head Attention)的源码实现可以分为以下几个步骤:
- 将输入张量
x
通过查询、键、值的线性变换得到三个张量q
、k
、v
,具体来说,就是分别通过三个nn.Linear
层对输入张量x
进行线性变换。 - 将
q
、k
、v
分别通过头的数量n_heads
个线性变换,得到q
、k
、v
的多个头,具体来说,就是将每个张量分成n_heads
个子张量,然后对每个子张量分别进行线性变换。 - 对每个头分别计算注意力得分,并将注意力得分通过 softmax 函数得到注意力权重。
- 对每个头的值
v
,按照注意力权重进行加权平均,并将多个头的结果拼接起来得到最终的输出。
下面是一个简单的 PyTorch 实现:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, n_heads):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.n_heads = n_heads
self.q_linear = nn.Linear(embed_dim, embed_dim)
self.k_linear = nn.Linear(embed_dim, embed_dim)
self.v_linear = nn.Linear(embed_dim, embed_dim)
self.out_linear = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.size()
# Step 1: Linear transformations
q = self.q_linear(x)
k = self.k_linear(x)
v = self.v_linear(x)
# Step 2: Split heads
q_heads = q.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
k_heads = k.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
v_heads = v.view(batch_size, seq_len, self.n_heads, embed_dim // self.n_heads).transpose(1, 2)
# Step 3: Calculate attention scores
scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / (embed_dim // self.n_heads) ** 0.5
attn_weights = nn.functional.softmax(scores, dim=-1)
# Step 4: Weighted average
weighted_v = torch.matmul(attn_weights, v_heads)
weighted_v = weighted_v.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
out = self.out_linear(weighted_v)
return out, attn_weights
在这个实现中,我们首先定义了一个 MultiHeadAttention
类,包含了查询、键、值的线性变换层以及输出的线性变换层。在 forward
方法中,我们首先对输入张量 x
进行查询、键、值的线性变换,并将得到的张量分别分成多个头。然后,我们计算头之间的注意力得分,并将得分通过 softmax 函数得到注意力权重。最后,我们按照注意力权重对每个头的值进行加权平均,并将多个头的结果拼接起来得到最终的输出。
需要注意的是,这个实现中并没有考虑批次大小(batch size)的影响,如果要处理多个样本,需要对输入张量的维度进行适当调整。同时,这个实现中使用了 PyTorch 中的矩阵乘法和 softmax 函数,可以通过 GPU 进行加速。
使用pytorch函数
nn.MultiheadAttention
是 PyTorch 中的一个类,它实现了多头自注意力机制(Multi-Head Attention),是一种常用的深度学习模型组件。
在 PyTorch 中,nn.MultiheadAttention
的使用方法比较简单,需要指定输入数据的维度、每个头的维度、头的个数等参数。下面是一个简单的使用示例:
import torch
import torch.nn as nn
batch_size = 10
seq_len = 20
embed_dim = 128
num_heads = 8
input = torch.randn(seq_len, batch_size, embed_dim)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
output, attn_weights = multihead_attn(input, input, input)
在这个例子中,我们首先定义了一个输入张量 input
,大小为 (seq_len, batch_size, embed_dim)
,其中 seq_len
表示序列长度,batch_size
表示批次大小,embed_dim
表示每个单词的嵌入维度。然后我们定义了一个 nn.MultiheadAttention
对象 multihead_attn
,其中输入维度为 embed_dim
,每个头的维度为 embed_dim // num_heads
,头的个数为 num_heads
。最后我们对输入张量进行 multihead_attn
的前向计算,并输出计算结果和注意力权重。
需要注意的是,nn.MultiheadAttention
中的注意力权重是相对于输入张量的,因此输入张量的维度需要满足一定的要求。具体来说,输入张量的维度应该为 (seq_len, batch_size, embed_dim)
,其中 seq_len
表示序列长度,batch_size
表示批次大小,embed_dim
应该是头的维度 embed_dim // num_heads
的整数倍。如果输入张量的维度不满足要求,可以使用 nn.Linear
等层进行维度转换。