目录
- MANN
- METHOD
- memory layer
- GMN
- MEMGNN
- TRAINING
- EXPERIMENTS
- datasets
- RESULTS
- ABLATION STUDY
本文作者来自University of Toronto,Sharif University of Technology。
本文提出了一种memory layer,借助multi-head array of memory keys和卷积算子,学习soft cluster assignment。不像是GCN依赖于local信息,memory layer依赖于全局信息,因此不用担心过平滑。根据memory layer为基础提出了两种不同的模型:memory-based GNN (MemGNN)以及graph memory network (GMN)。其中,Memory augmented neural networks (MANNs)是模型的基础,所以在看模型之前还是学习一下MANN。
MANN
对于一般的神经网络,比如RNN,我们将其视为一个黑盒,一个封装好的函数,只需要在执行的时候调用相关参数。而MANN则具有“互动能力”,或者直白地说,可以与内存进行交互,通过访问内存中的相关数据,使得模型具有更强的思考与记忆的能力。一个记忆网络(memory networks,简称为MemNN),包括了记忆m,还包括以下4个组件I、G、O、R(lstm的三个门,然后m像cell的list):
- Input:将输入向量投影到更高的特征维度。
- Generalization:更新记忆,对于数组来说可能只是简单的插入。
- Output:结合输入,从记忆里抽取出适合的结果,返回一个向量。
- Response:将记忆向量转化为合适的输出格式。相当于一个逆向的Input。
但是个人觉得本文中提及的memory虽然在原理上属于记忆网络,但是好像也更贴近Transformer中提及的KQV注意力。可能知识学着学着就学杂了。
METHOD
memory layer
记忆层被定义为:
这是一个函数,将输入转化为,并且。没错,这部分对应了图粗化。并且,特征维度从变成了,顺便也学习了节点的特征。记忆层如图一所示:
记忆层的本质是多头的注意力数组keys,是head数。这里的多头就是注意力的多头,多次提取特征效果更好。对于每一个输入,首先通过所有输入共享的query将输入特征变换为高阶特征,可以对应记忆网络的Input。然后再经过记忆层,同每一个key进行比较(本质上就是比较query和key的相似程度,这个相似度需要根据实际情况自己定义,欧氏距离余弦相似度都是可以的),得到个注意力矩阵。然后接着使用卷积层将其聚合成一个注意力矩阵。
本文将input query表示为,作为输入的图的特征表示,keys为,作为query的聚类中心。然后,使用一种对集群友好的分布作为键和查询之间的距离度量(Student’s t-distribution):
其中就是normalized score,对于本文的就是节点被分配给集群的概率,表示自由度。那么多头的注意力则表示为:。为了将这些头集合成一个赋值矩阵,我们在标准卷积类比中将这些头和赋值矩阵作为深度、高度和宽度通道,并对它们应用一个卷积运算符(也就是图中的蓝色框框),形式化描述为:
其中表示为[1,1,|h|]的卷积。然后,粗化的节点表示V为:
之后使用单层前馈神经网络得到下一层的Q:
对于图分类任务,可以简单地将内存层堆叠到输入图被粗化为表示全局图表示的单个节点的水平,然后将其提供给完全连接的层来预测图类,如下所示:
其中,是最初始的query,通过对图应用query网络得到。本文介绍了两种不同的基于memory network的架构,分别为GMN与MemGNN。
GMN
GMN是一系列记忆网络的堆叠,并且不使用任何消息传递机制生成query,最上层是一个query network ,将初始节点特性投射到表示初始查询空间的潜在空间中。因此,每个节点的拓扑信息需要以某种方式编码到其初始表示中。本文使用带有重启的随机游走策略(RWR)计算这个网络拓扑的嵌入,然后按行对它们进行排序,以强制嵌入顺序不变。然后使用一个两层前馈神经网络的查询网络将拓扑嵌入与初始节点特性融合到初始查询表示中:
这个其实就是拓扑结构特征经过线性变换再和属性特征进行拼接。
MEMGNN
这个直接使用GCN去聚合拓扑结构的特征:
为了考虑边的特殊性,本文对GAT进行了一些修改,称为e-GAT,也考虑了边的信息:
表示节点,表示边,也就是节点和边的特征拼接起来,作为最终的输入。
TRAINING
训练的时候除了分类的损失函数,还考虑了聚类的效果。聚类的loss被定义为KL散度。KL散度被衡量两个不同的概率分布之间的相似程度。soft assignments 与目标分布 之间的KL散度为:
目标分布被定义为:
如此一来总体的损失函数为:
其中,。
EXPERIMENTS
datasets
使用了7个图分类数据集以及2个图回归任务的数据集。数据集为:
其统计数字如下:
RESULTS
除了一些深度学习的方法,还与图核进行了对比。其中GMN比MemGNN的结果更好,说明用全局拓扑嵌入替换局部邻接信息可以为模型提供更有用的信息。AUC-ROC也可以说明算法的效果(table 2)。
对于图回归任务,使用均方根误差(RMSE):
ABLATION STUDY
e-GAT与GAT,从曲线来看当然是带有边的e-GAT更好。由于节点比边具有更丰富的特征,我们将节点和边的特征维数分别设置为16和4。
此外,为了研究拓扑结构对模型的影响,使用了三种不同的拓扑结构特征获取方式:adjacency matrix, normalized adjacency matrix, and RWR。当然RWR最好,这部分文中没有给出图表。
对于邻居节点较多的图,需要对节点的邻居进行下采样。对比了随机下采样和使用RWR分数排名之后下采样的两种不同方式,交叉验证精度分别为73.9%和73.1%,表明随机抽样略优于基于RWR的抽样。
NUMBER OF KEYS AND HEADS:这个对于不同的数据集有着不同的结论,尽管ESOL数据集有13.3个节点,但是却使用了64个keys,而平均节点数为32.69的Enzymes却只用了10个keys。这说明虽然键代表cluster,但键的数量不一定与输入图中的节点数量成比例。
最后,通过可视化的方式去探究这个key的含义,可以见模型能够有效地对化学集团进行聚类。
并且,添加了聚类损失函数的结果要比不添加的更好。下图中的(b)(d)是没有添加聚类损失的,其聚类效果不佳。