目录

  • 一、Graph Attention Network
  • 1.1 GAT的优点
  • 1.2 Graph Attention layer的输入输出
  • 1.3 Graph Attention layer的attention机制
  • 1.4 多头attention机制
  • 二、GAN的python复现
  • 三、GAT代码、论文、数据集下载


一、Graph Attention Network

1.1 GAT的优点

图注意力网络(GAT)是作者对图卷积网络(GCN)的改进。它的主要创新点在于利用了注意力机制(Attention Mechanism)来自动的学习和优化节点间的连接关系,这一作法有以下几个优点:

  1. 克服了GCN只适用于直推式学习的缺陷(在训练期间需要测试时的图数据),可以应用于我们熟悉的归纳式学习任务(在训练期间不需要测试时的图数据)。
  2. 使用注意力权重代替了原先非0即1的节点连接关系,即两个节点间的关系可以被优化为连续数值,从而获得更丰富的表达
  3. 由于attention值的计算是可以在节点间并行进行的,网络的计算相当高效

1.2 Graph Attention layer的输入输出

作为一层网络,图注意力层的输入为


图注意力 pytorch 图注意力网络改进_图注意力 pytorch这里的

图注意力 pytorch 图注意力网络改进_GAT_02是图的节点个数,

图注意力 pytorch 图注意力网络改进_图注意力网络_03表示节点的特征向量,

图注意力 pytorch 图注意力网络改进_图注意力 pytorch_04表示特征维度。


图注意力层的输出为

图注意力 pytorch 图注意力网络改进_GAT_05


同样的,

图注意力 pytorch 图注意力网络改进_GAT_06表示输出的特征维度。


从图注意力层的输入输出可以看出,其本质上也是对特征的一种变换,和其余的网络层功能是类似的。

1.3 Graph Attention layer的attention机制

首先需要定义一个特征变换矩阵图注意力 pytorch 图注意力网络改进_图注意力 pytorch_07

  1. GAT中的attention机制被称为self-attention,记为图注意力 pytorch 图注意力网络改进_图注意力 pytorch_08,其功能如下:图注意力 pytorch 图注意力网络改进_图神经网络_09
    如图所示,该式表示了self-attention利用节点图注意力 pytorch 图注意力网络改进_attention机制_10和节点图注意力 pytorch 图注意力网络改进_图注意力 pytorch_11的特征作为输入计算出了图注意力 pytorch 图注意力网络改进_attention机制_12, 而图注意力 pytorch 图注意力网络改进_attention机制_12则表示了节点图注意力 pytorch 图注意力网络改进_图注意力 pytorch_11对于节点图注意力 pytorch 图注意力网络改进_attention机制_10的重要性。
  2. 需要说明的是,这里的节点图注意力 pytorch 图注意力网络改进_图注意力网络_16是节点图注意力 pytorch 图注意力网络改进_attention机制_17的近邻,而节点图注意力 pytorch 图注意力网络改进_attention机制_17可能是拥有多个近邻的,因此就有了下面的图注意力 pytorch 图注意力网络改进_图注意力网络_19归一化操 图注意力 pytorch 图注意力网络改进_图注意力 pytorch_20图注意力 pytorch 图注意力网络改进_图注意力网络_21是节点 图注意力 pytorch 图注意力网络改进_attention机制_10 的近邻集合。
  3. 那么说了这么久,这个self-attention机制,也就是我们一开始提到的图注意力 pytorch 图注意力网络改进_图注意力 pytorch_23是怎么计算的呢?其实也很简单 图注意力 pytorch 图注意力网络改进_图注意力网络_24
    这里的 图注意力 pytorch 图注意力网络改进_图神经网络_25 表示需要训练的网络参数, 图注意力 pytorch 图注意力网络改进_图注意力网络_26表示的是矩阵拼接操作,图注意力 pytorch 图注意力网络改进_图注意力 pytorch_27则是一种激活函数,是图注意力 pytorch 图注意力网络改进_图神经网络_28的一种改进。
  4. 最后给出图感知层的定义,即 图注意力 pytorch 图注意力网络改进_图注意力 pytorch_29

上面就是GAT的attention计算方法了,其中会有两个知识点会影响理解

  1. self-attention机制为什么可以表示节点间的重要性
  2. 图注意力 pytorch 图注意力网络改进_图注意力 pytorch_27的定义

对于上面这两点,如果知道的话,再结合对GCN的理解,可以很容易的get到GAT的点和含义,如果不清楚的话可能会有点迷糊。

  1. attention机制实际上是在有监督的训练下计算两个向量的匹配程度,从而揭示其重要性和影响,由于本篇博客不是专门介绍attention的,这里不做多余的解释,日后会补上相应的博客。
  2. 图注意力 pytorch 图注意力网络改进_图注意力 pytorch_27的定义如下: 图注意力 pytorch 图注意力网络改进_图注意力网络_32 即引入了一个系数图注意力 pytorch 图注意力网络改进_GAT_33来取消图注意力 pytorch 图注意力网络改进_图注意力 pytorch_34的死区。

1.4 多头attention机制

  1. 为了稳定self−attention的学习过程,GAT还采用了一种多头机制,即独立的计算K个attention,然后将其获得的特征拼接起来,获得一个更全面的表述,表示如下 图注意力 pytorch 图注意力网络改进_attention机制_35这里的 || 表示矩阵拼接的操作,其余的符号和上面描述的一致。
  2. 同时,考虑到在网络的最后一层输出层如果还采用这种拼接的方式扩大特征维度,可能不合理,因此,GAT又为输出层定义了平均的操作 图注意力 pytorch 图注意力网络改进_图注意力 pytorch_36

多头attention机制如图所示

图注意力 pytorch 图注意力网络改进_attention机制_37

二、GAN的python复现

模型的核心代码如下

import numpy as np
import tensorflow as tf

from utils import layers
from models.base_gattn import BaseGAttN

class GAT(BaseGAttN):
    def inference(inputs, nb_classes, nb_nodes, training, attn_drop, ffd_drop,
            bias_mat, hid_units, n_heads, activation=tf.nn.elu, residual=False):
        attns = []
        for _ in range(n_heads[0]):
            attns.append(layers.attn_head(inputs, bias_mat=bias_mat,
                out_sz=hid_units[0], activation=activation,
                in_drop=ffd_drop, coef_drop=attn_drop, residual=False))
        h_1 = tf.concat(attns, axis=-1)
        for i in range(1, len(hid_units)):
            h_old = h_1
            attns = []
            for _ in range(n_heads[i]):
                attns.append(layers.attn_head(h_1, bias_mat=bias_mat,
                    out_sz=hid_units[i], activation=activation,
                    in_drop=ffd_drop, coef_drop=attn_drop, residual=residual))
            h_1 = tf.concat(attns, axis=-1)
        out = []
        for i in range(n_heads[-1]):
            out.append(layers.attn_head(h_1, bias_mat=bias_mat,
                out_sz=nb_classes, activation=lambda x: x,
                in_drop=ffd_drop, coef_drop=attn_drop, residual=False))
        logits = tf.add_n(out) / n_heads[-1]
    
        return logits