详解BERT阅读理解_javaBERT的简单回顾

Google发布的论文《Pre-training of Deep Bidirectional Transformers for Language Understanding》,提到的BERT模型刷新了自然语言处理的11项记录。算是NLP的里程碑事件,也开始了大公司之间的数据和算力的装备竞赛。放一篇比较好的中文论文翻译。

BERT在阅读理解领域带了很大的进展,在BERT的出现之前最好的模型基本都是使用的Seq2Seq方法,分五步走,相当复杂。BERT出现后,只需要在BERT后面加上简单的网络就可达到特别好的效果。所以理解BERT用于阅读理解是非常重要的。

下图是SQUAD2.0的排名,截止到19年7月1日。

详解BERT阅读理解_java_02

BERT Base的参数

对于英文阅读理解任务来说,可以选择Base版或者Large版,对于中文来说只有Base版。BERT本身用的是Transformer的Encoder部分,只是堆了很多层,换了个训练任务而已。

下面简单看一下BERT的各层的参数量,全连接层占比半数以上,和TokenEmbedding加起来占比70左右。而最重要的Attention只有27.5,这是12层的参数,除以12后相当少了。这里有个小的思路,通过对全连接层进行压缩,以及对Embedding层压缩或许可以达到小而美的结果(PS. 下一篇顶会就是你)。

详解BERT阅读理解_java_03

Masked Language Model

这里是BERT预训练的方法,训练出来的就是BERT本体,也是官网给的下载模型,实际在微调的时候基本不用到, 了解一下就好(PS. 人家都给你做好了,理解原理后直接拿来主义就行了)。

简单来说是Masked Language Model,分为词语级别和句子级别。对于阅读任务来说需要问题和文档之间的关系来得到最后的答案。这两个任务对阅读理解都很有帮助,尤其是第二个任务。下图具体展示了预训练时候这两个任务的具体做法。

详解BERT阅读理解_java_04详解BERT阅读理解_java_05

Task specific-Models

实际在用BERT的时候需要根据下游任务在BERT后面接上不同的网络,然后可以只训练接的网络的参数,也可以解冻BERT最后几层一起训练,这就是迁移学习,跟CV领域的一致。BERT的四种主流应用场景:

详解BERT阅读理解_java_06BERT应用于阅读理解

在SQuAD,给出了问题和包含答案的段落,任务是预测答案文本的起始和结束位置(答案的跨度,span)。
BERT首先在问题前面添加special classification token[CLS]标记,然后问题和段落连在一起,中间使用special tokens[SEP]分开。序列通过token Embedding、segment embedding 和 positional embedding输入到BERT。最后,通过全连接层和softmax函数将BERT的最终隐藏状态转换为答案跨度的概率。

BERT的输入和输出

  • 输入:input_ids, input_mask, segment_ids

  • 输出:start_position, end_position

起始和结束位置的计算

example = SquadExample( qas_id=qas_id,
                        question_text=question_text,
                        doc_tokens=doc_tokens,
                        orig_answer_text=orig_answer_text,
                        start_position=start_position,
                        end_position=end_position)

详解BERT阅读理解_java_07
起始和结束位置的重新标注,这个是因为分词的原因。即便是bert中文版是按照字的,但是某些英文单词、年份还是会被单独分开,如1990等,但是标注的时候为了统一,是按照字符的个数来算的,就是不分词,所以在处理的时候要重新改写起始和结束位置,例如下面这段文本,原本起始位置是41,但是分词之后改为了38,注意这对生成的结果没有影响。
详解BERT阅读理解_java_08
详解BERT阅读理解_java_09

由于BERT的输入是这样拼接的:

[CLS]question[SEP]context[SEP]

所以最后输入模型的起始位置还要加上question的长度和([CLS],[SEP]),还是以上面两个为例。

  • 第一个例子:"question": "范廷颂是什么时候被任为主教的?",长度为15,加上[CLS]和[SEP],就是17,再加上原始的起始位置30,最后得到47.

  • 第二个例子:"question": "1990年,范廷颂担任什么职务?",长度为13,加上[CLS]和[SEP],就是15,再加上原始的起始位置38,最后得到53.

和Debug的结果一致

详解BERT阅读理解_java_10

BERT的Embedding层

理解了BERT的输入,我们看一下BERT的Embedding层,主要包括三个部分,最后相加就OK了。详解BERT阅读理解_java_11BERT的Token Embedding

也就是word Embeddings,BERT的分词用的是sub words的形式,会把词拆分成一些字词来减少OOV。

详解BERT阅读理解_java_12Bert嵌入层

BERT的Segment Embedding

Segment Embeddings用来区分两段话,用于句子级别的Mask任务,直接加0,1区分,也是最简单的实现。

详解BERT阅读理解_java_13

BERT的Position Embedding

对每个输入的位置进行编码,添加位置信息,Transformer是直接表示的,BERT是训练来的。

小结一下

  • Token Embeddings 形状为(1, n, 768),就是word Embedding,对于中文是字。

  • Segment Embeddings 形状为 (1, n, 768) ,用来区分两个句子

  • Position Embeddings 形状为 (1, n, 768),主要是为Transfomer提供位置信息。

  • 最后把三个加起来就是BERT的Embedding层了(PS.直接加起来确实有点简单粗暴哈)。

BERT阅读理解标准模型

从文本Embedding层向量后,输入到BERT预训练模型里面(Transformer的Encoder),得到BERT的深层表示,再接上网络就可以得到输出了。这里先上代码:

from pytorch_pretrained_bert.modeling import BertModel, BertPreTrainedModel

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

class Bert_QA(BertPreTrainedModel):

    def __init__(self, config, num_labels):
        super(BertOrigin, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):

        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
         logits = self.classifier(sequence_output)  # (B, T, 2)
        start_logits, end_logits = logits.split(1, dim=-1)# ((B, T, 1),(B, T, 1))
        start_logits = start_logits.squeeze(-1# (B, T)
        end_logits = end_logits.squeeze(-1# (B, T)
  • pooled_output: [batch_size, hidden_size=768], 取了最后一层Transformer的输出结果的第一个单词[cls]的hidden states

  • sequence_output:[batch_size, sequence_length, hidden_size=768],最后一个encoder层的输出

上面代码用的是BERT的SQUAD里面的做法得到start_logits和end_logits,就是把sequence_output接全连接转换hidden_size维度为2,然后split输出,最后的loss就是两个分别计算loss然后平均。

总结

本文介绍了BERT用于阅读理解的基本架构,有了这些知识你也可以实现一个中文的阅读理解模型。最后我总结一下BERT阅读理解模型在实际使用场景中的问题和可能的改进措施:

  • BERT模型太大了,即便是在预测的时候也是接近400M,回到开头,我还是比较期望有小而美的模型实现的

  • BERT只支持单句子最长512的输入,这是有Transformer本身的Attention决定的,再大Google也训练不动。解决方案是要么截断,要么划分句子后拼接BERT的表示,Transformer XL说可以缓解这个问题,还没细看

  • BERT最多支持两句输入,对于多轮的阅读理解和对话系统不友好,所有很多文章都在尝试解决这个问题。

  • BERT默认的输出的最大答案长度为30,实际用的时候最好设成300,来应对长答案




原文链接:

https://juejin.im/post/5d05cd6c518825509075f7f9