Google发布的论文《Pre-training of Deep Bidirectional Transformers for Language Understanding》,提到的BERT模型刷新了自然语言处理的11项记录。算是NLP的里程碑事件,也开始了大公司之间的数据和算力的装备竞赛。放一篇比较好的中文论文翻译。
BERT在阅读理解领域带了很大的进展,在BERT的出现之前最好的模型基本都是使用的Seq2Seq方法,分五步走,相当复杂。BERT出现后,只需要在BERT后面加上简单的网络就可达到特别好的效果。所以理解BERT用于阅读理解是非常重要的。
下图是SQUAD2.0的排名,截止到19年7月1日。
BERT Base的参数
对于英文阅读理解任务来说,可以选择Base版或者Large版,对于中文来说只有Base版。BERT本身用的是Transformer的Encoder部分,只是堆了很多层,换了个训练任务而已。
下面简单看一下BERT的各层的参数量,全连接层占比半数以上,和TokenEmbedding加起来占比70左右。而最重要的Attention只有27.5,这是12层的参数,除以12后相当少了。这里有个小的思路,通过对全连接层进行压缩,以及对Embedding层压缩或许可以达到小而美的结果(PS. 下一篇顶会就是你)。
Masked Language Model
这里是BERT预训练的方法,训练出来的就是BERT本体,也是官网给的下载模型,实际在微调的时候基本不用到, 了解一下就好(PS. 人家都给你做好了,理解原理后直接拿来主义就行了)。
简单来说是Masked Language Model,分为词语级别和句子级别。对于阅读任务来说需要问题和文档之间的关系来得到最后的答案。这两个任务对阅读理解都很有帮助,尤其是第二个任务。下图具体展示了预训练时候这两个任务的具体做法。
Task specific-Models
实际在用BERT的时候需要根据下游任务在BERT后面接上不同的网络,然后可以只训练接的网络的参数,也可以解冻BERT最后几层一起训练,这就是迁移学习,跟CV领域的一致。BERT的四种主流应用场景:
在SQuAD,给出了问题和包含答案的段落,任务是预测答案文本的起始和结束位置(答案的跨度,span)。
BERT首先在问题前面添加special classification token[CLS]标记,然后问题和段落连在一起,中间使用special tokens[SEP]分开。序列通过token Embedding、segment embedding 和 positional embedding输入到BERT。最后,通过全连接层和softmax函数将BERT的最终隐藏状态转换为答案跨度的概率。
BERT的输入和输出
输入:
input_ids, input_mask, segment_ids
输出:
start_position, end_position
起始和结束位置的计算
example = SquadExample( qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position)
起始和结束位置的重新标注,这个是因为分词的原因。即便是bert中文版是按照字的,但是某些英文单词、年份还是会被单独分开,如1990等,但是标注的时候为了统一,是按照字符的个数来算的,就是不分词,所以在处理的时候要重新改写起始和结束位置,例如下面这段文本,原本起始位置是41,但是分词之后改为了38,注意这对生成的结果没有影响。
由于BERT的输入是这样拼接的:
[CLS]question[SEP]context[SEP]
所以最后输入模型的起始位置还要加上question的长度和([CLS],[SEP]),还是以上面两个为例。
第一个例子:
"question": "范廷颂是什么时候被任为主教的?"
,长度为15,加上[CLS]和[SEP],就是17,再加上原始的起始位置30,最后得到47.第二个例子:
"question": "1990年,范廷颂担任什么职务?"
,长度为13,加上[CLS]和[SEP],就是15,再加上原始的起始位置38,最后得到53.
和Debug的结果一致
BERT的Embedding层
理解了BERT的输入,我们看一下BERT的Embedding层,主要包括三个部分,最后相加就OK了。BERT的Token Embedding
也就是word Embeddings,BERT的分词用的是sub words的形式,会把词拆分成一些字词来减少OOV。
BERT的Segment Embedding
Segment Embeddings用来区分两段话,用于句子级别的Mask任务,直接加0,1区分,也是最简单的实现。
BERT的Position Embedding
对每个输入的位置进行编码,添加位置信息,Transformer是直接表示的,BERT是训练来的。
小结一下
Token Embeddings 形状为(1, n, 768),就是word Embedding,对于中文是字。
Segment Embeddings 形状为 (1, n, 768) ,用来区分两个句子
Position Embeddings 形状为 (1, n, 768),主要是为Transfomer提供位置信息。
最后把三个加起来就是BERT的Embedding层了(PS.直接加起来确实有点简单粗暴哈)。
BERT阅读理解标准模型
从文本Embedding层向量后,输入到BERT预训练模型里面(Transformer的Encoder),得到BERT的深层表示,再接上网络就可以得到输出了。这里先上代码:
from pytorch_pretrained_bert.modeling import BertModel, BertPreTrainedModel
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
class Bert_QA(BertPreTrainedModel):
def __init__(self, config, num_labels):
super(BertOrigin, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
logits = self.classifier(sequence_output) # (B, T, 2)
start_logits, end_logits = logits.split(1, dim=-1)# ((B, T, 1),(B, T, 1))
start_logits = start_logits.squeeze(-1) # (B, T)
end_logits = end_logits.squeeze(-1) # (B, T)
pooled_output: [batch_size, hidden_size=768], 取了最后一层Transformer的输出结果的第一个单词[cls]的hidden states
sequence_output:[batch_size, sequence_length, hidden_size=768],最后一个encoder层的输出
上面代码用的是BERT的SQUAD里面的做法得到start_logits和end_logits,就是把sequence_output接全连接转换hidden_size维度为2,然后split输出,最后的loss就是两个分别计算loss然后平均。
总结本文介绍了BERT用于阅读理解的基本架构,有了这些知识你也可以实现一个中文的阅读理解模型。最后我总结一下BERT阅读理解模型在实际使用场景中的问题和可能的改进措施:
BERT模型太大了,即便是在预测的时候也是接近400M,回到开头,我还是比较期望有小而美的模型实现的
BERT只支持单句子最长512的输入,这是有Transformer本身的Attention决定的,再大Google也训练不动。解决方案是要么截断,要么划分句子后拼接BERT的表示,Transformer XL说可以缓解这个问题,还没细看
BERT最多支持两句输入,对于多轮的阅读理解和对话系统不友好,所有很多文章都在尝试解决这个问题。
BERT默认的输出的最大答案长度为30,实际用的时候最好设成300,来应对长答案
原文链接:
https://juejin.im/post/5d05cd6c518825509075f7f9