GAT机制应用于nlp_world 统计字数代码实现原理


图方法分为谱方法(spectral method)和空间方法( spatial method),谱方法是将图映射到谱域上,例如拉普拉斯矩阵经过特征分解得到的空间,代表方法之一是GCN;空间方法是直接在图上进行操作,代表方法之一GAT。本文主要介绍GAT方法的基本原理,以及代码实现。

GAT论文网址:https://arxiv.org/abs/1710.10903


给定图



表示点,


表示边,节点的个数



输入:


个节点的特征,




输出:





1. 基本原理

首先 ,为了更加充分地表示节点的特征,对节点


进行特征变换,



,即将节点的特征维度


映射到维度


上。关键的步骤来了,对图中的每个节点进行

self-attention操作,计算任意两个节点之间的注意力权重。节点

对节点


的重要性计算公式如下:



masked attention将图结构注入这个机制中,即对于节点


来说,只计算其一阶邻居节点集合


中节点 对


的作用,



为了使系数在不同节点之间易于比较,论文中使用softmax函数在集合


中对它们进行归一化,如下所示。在实验中,注意力机制是一个单层的前馈神经网络,激活函数采用LeakyReLU。



最终,将归一化的注意力系数与其对应的特征进行线性组合,以作为每个节点的最终输出特征。



此外,为了稳定自我注意力的学习过程,论文中发现采用多头注意力(Multi-head Attention)扩展注意力对模型是有提升的。采用


头注意力机制的两种计算公式如下:


GAT机制应用于nlp_归一化_02

(1)拼接方式


GAT机制应用于nlp_world 统计字数代码实现原理_03

(2)均值方式


2. 代码实现

首先,定义GraphAttentionLayer层,实现单个注意力机制层。


class GraphAttentionLayer(nn.Module):

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.mm(input, self.W) # shape [N, out_features]
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) # shape[N, N, 2*out_features]
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  # [N,N,1] -> [N,N]

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)  # [N,N], [N, out_features] --> [N, out_features]

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime


接下来,定义GAT层,用于实现完整的网络模型。


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)


最后,对模型进行训练,优化模型。


model = GAT(nfeat=features.shape[1], nhid=args.hidden, nclass=int(labels.max()) + 1, 
            dropout=args.dropout, nheads=args.nb_heads, alpha=args.alpha)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

features, adj, labels = Variable(features), Variable(adj), Variable(labels)

def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    if not args.fastmode:
        model.eval()
        output = model(features, adj)

    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.data.item()),
          'acc_train: {:.4f}'.format(acc_train.data.item()),
          'loss_val: {:.4f}'.format(loss_val.data.item()),
          'acc_val: {:.4f}'.format(acc_val.data.item()),
          'time: {:.4f}s'.format(time.time() - t))

    return loss_val.data.item()