文章目录

  • 1 引言
  • 2 本文模型
  • 2.1 Seq2Seq 注意力模型
  • 2.2 指针生成网络
  • 2.3 覆盖机制(Coverage mechanism)


Reference
1. Get To The Point: Summarization with Pointer-Generator Networks


seq2seq模型可用于文本摘要(并非简单地选择、重排原始文本的段落),然后这些模型有两个缺点:不易关注真实细节,以及倾向于生成重复文本

本文提出一种新颖的架构:使用两种方式增强标准的seq2seq注意力模型。第一,使用混合指针生成网络,利用指针从原文精确地复制单词,同时保留生成器产生新单词的能力。第二,使用覆盖机制跟踪哪些单词已经被摘取,避免生成重复文本。

1 引言

文本摘要任务旨在输出仅包含原文主要信息的压缩文本,大致具有两种方法:抽取式摘要式。抽取式方法直接抽取段落原文,而摘要式方法可能生成一些原文中并非出现的单词或短语,类似于人类写的摘要。

抽取式方法相对容易,因为直接抽取原文大段文本,可保证语法和准确性。从另一个角度来说,模型的经验能力对于生成高质量摘要至关重要,如改写、总结,以及结合现实世界知识,仅摘要式框架具备这种可能性。

摘要式总结较为困难,先前模型大多采用抽取式,最近提出的seq2seq模型使得摘要式总结成为可能。尽管seq2seq模型具备很大潜力,但它们也暴露了一些不良行为,如无法准确复制真实细节无法处理OOV问题,以及倾向于自我重复

指针网络代码pytorch 指针生成网络_编码器

本文提出一种网络结构,在多句上下文总结中解决以上三个问题。最近的摘要式模型主要关注与标题生成(将一两句话缩减至单一标题),我们相信长文本摘要挑战与实用性并存,本文使用CNN/Daily Mail数据集,其包含新闻文章(平均39句)和多句摘要,结果显示,本文提出的模型高于SOTA模型2个ROUGE点。

本文的混合指针生成模型通过指针从原文中复制单词,文本生成准确性提高,并解决了OOV问题,同时保留生成原文中未出现的新单词的能力,该网络可视为摘要方法和抽取方法之间的平衡,类似于应用于短文本摘要的 CopyNetForced-Attention Sentence Compression 模型。我们提出一种新型的覆盖向量(源于NMT,可用于跟踪和控制原文的覆盖率),结果表明,覆盖机制对于消除重复性非常有效。

2 本文模型

2.1 Seq2Seq 注意力模型

本文基线模型类似于图2中的模型:

指针网络代码pytorch 指针生成网络_基线_02


文中各token依次输入至单层BiLSTM,网络输出编码器隐状态序列指针网络代码pytorch 指针生成网络_概率分布_03,在时间步指针网络代码pytorch 指针生成网络_基线_04,解码器(单层单向LSTM)接收到先前单词的词向量(训练阶段为参考摘要的前一个单词,测试阶段为解码器上一时刻输出的单词),输出隐状态指针网络代码pytorch 指针生成网络_概率分布_05。基于Bahdanau et al.(2015)注意力机制,计算注意力分布:

指针网络代码pytorch 指针生成网络_指针网络代码pytorch_06


式中,指针网络代码pytorch 指针生成网络_基线_07为可学习的参数。注意力分布可看作为源单词的概率分布,告诉解码器应关注哪些单词生成下一个单词。接着,使用注意力机制加权编码器隐状态,输出上下文向量指针网络代码pytorch 指针生成网络_指针网络代码pytorch_08:

指针网络代码pytorch 指针生成网络_基线_09


上下文向量可看作为固定维度的、当前时间步从源中读取的内容,将其与解码器隐状态指针网络代码pytorch 指针生成网络_概率分布_05拼接,输入至两层线性网络,产生词典概率分布指针网络代码pytorch 指针生成网络_编码器_11

指针网络代码pytorch 指针生成网络_编码器_12

式中,指针网络代码pytorch 指针生成网络_概率分布_13为可学习参数。指针网络代码pytorch 指针生成网络_编码器_11为词典中所有单词的概率分布,告知我们预测单词指针网络代码pytorch 指针生成网络_指针网络代码pytorch_15的最终概率分布:

指针网络代码pytorch 指针生成网络_指针网络代码pytorch_16

训练阶段,时间步指针网络代码pytorch 指针生成网络_基线_04的损失为目标单词指针网络代码pytorch 指针生成网络_编码器_18的负对数似然:

指针网络代码pytorch 指针生成网络_概率分布_19


整个序列的全部损失为

指针网络代码pytorch 指针生成网络_概率分布_20

2.2 指针生成网络

本文模型为基线模型seq2seq和指针网络的混合,其允许通过指针复制单词,以及从固定大小的词典中生成单词。在图三所示的指针生成网络中,注意力分布指针网络代码pytorch 指针生成网络_基线_21和上下文向量指针网络代码pytorch 指针生成网络_指针网络代码pytorch_08可以利用2.1章节所述公式计算。

此外,时间步利用上下文向量指针网络代码pytorch 指针生成网络_指针网络代码pytorch_08,解码器隐状态指针网络代码pytorch 指针生成网络_概率分布_05,解码器输入指针网络代码pytorch 指针生成网络_基线_25计算生成概率分布:

指针网络代码pytorch 指针生成网络_编码器_26


式中,向量指针网络代码pytorch 指针生成网络_编码器_27和变量指针网络代码pytorch 指针生成网络_指针网络代码pytorch_28为可学习参数,指针网络代码pytorch 指针生成网络_基线_29为sigmoid函数。指针网络代码pytorch 指针生成网络_概率分布_30可看作为软开关,用于选择是利用指针网络代码pytorch 指针生成网络_编码器_11从词表中抽取单词,还是利用注意力分布指针网络代码pytorch 指针生成网络_指针网络代码pytorch_32从输入句抽取单词。对于每一篇文档,将原文中所有出现的单词和词典结合为扩充词典,获得在扩展词典上的概率分布:

指针网络代码pytorch 指针生成网络_指针网络代码pytorch_33


注意到,如果指针网络代码pytorch 指针生成网络_指针网络代码pytorch_15不存在与词典中,则指针网络代码pytorch 指针生成网络_编码器_35;类似地,如果指针网络代码pytorch 指针生成网络_指针网络代码pytorch_15不存在于原文中,则指针网络代码pytorch 指针生成网络_指针网络代码pytorch_37。产生OOV单词的能力是指针网络的主要优势之一,而我们的基线模型产生单词的数量局限于预设置的词典。损失函数如公式(6)和(7)所示,但我们修改为公式(9)所示的概率分布指针网络代码pytorch 指针生成网络_基线_38

2.3 覆盖机制(Coverage mechanism)

重复是seq2seq模型的常见问题,在生成多句时尤其明显(如图1所示),我们采用覆盖机制解决这个问题。覆盖机制模型中,我们维持之前所有解码步的注意力分布之和作为覆盖向量指针网络代码pytorch 指针生成网络_指针网络代码pytorch_39

指针网络代码pytorch 指针生成网络_指针网络代码pytorch_40


直观上,指针网络代码pytorch 指针生成网络_指针网络代码pytorch_39为原文单词上的分布(未归一化),表示这些单词到目前为止从注意力机制中所获得的覆盖度。注意到,指针网络代码pytorch 指针生成网络_基线_42为零向量,因为初始时刻源文中没有任何单词被覆盖。覆盖向量作为注意力机制的额外输入,将公式(1)改为

指针网络代码pytorch 指针生成网络_编码器_43


式中,指针网络代码pytorch 指针生成网络_编码器_44是与指针网络代码pytorch 指针生成网络_基线_45具有相同长度的可学习向量。覆盖机制使得注意力机制的当前决策受其先前决策(指针网络代码pytorch 指针生成网络_编码器_46之和)影响,因此应该更易避免注意力机制关注相同位置,从而避免生成重复文本。我们发现,额外定义覆盖损失惩罚重复关注相同位置是必要的,覆盖损失

指针网络代码pytorch 指针生成网络_概率分布_47


覆盖损失有界:指针网络代码pytorch 指针生成网络_基线_48,公式(12)中的覆盖损失有别于机器翻译中的覆盖损失。MT中,假定翻译率大致为1:1,如果覆盖向量大于或小于1,其将作为惩罚向量。本文损失函数比较灵活,因为摘要不需要一致覆盖率,本文仅惩罚注意力机制与到目前为止的覆盖向量之间的重叠部分,防止重复关注。最终,使用超参数指针网络代码pytorch 指针生成网络_概率分布_49加权覆盖损失至先前损失,产生新的合成损失:

指针网络代码pytorch 指针生成网络_基线_50