多标签文本分类nlp_全连接


多标签文本分类的框架:ALBERT+Denses,即通过 多个二分类解决多标签分类问题。搭建这个框架的目的主要还是为了和其它几个不同的框架做一些对比,以及尝试一种新的方法来做多标签文本分类。


多标签文本分类nlp_全连接_02


目的:主要是兴趣,实现自己的想法,以及与其他框架对比下实验效果。

  • 这篇文章和之前写的一篇文章有一定的相似之处。

HelloNLP:多标签文本分类 [ALBERT](附代码)zhuanlan.zhihu.com


多标签文本分类nlp_文本分类_03


一、介绍

  1. 此项目是在tensorflow版本1.14.0的基础上做的训练和测试。
  2. 任务类型为中文多标签文本分类,一共有K个标签: 。
  3. 模型的输入为一个sentence,输出为一个或者多个label。

由于另外几篇文章有详细的介绍,所以这里就不不多说了。

二、框架及算法

1、Placeholder


多标签文本分类nlp_多标签文本分类nlp_04


首先,我们需要设置一些占位符(Placeholder),占位符的作用是在训练和推理的过程中feed模型需要的数据。我们这里需要4个占位符,分别是input_ids、input_masks、segment_ids和label_ids。前面3个是我们了解的BERT输入特征,最后面一个是标签的id。

2、ALBERT token-vectors


多标签文本分类nlp_文本分类_05


从图中红色的框内可以看出,ALBERT需要传入3个参数(input_ids,input_masks、segment_ids),就可以得到我们所需要的一个2维向量output_layer:(batch_size, hidden_size)。

有人在这里就会好奇,为什么ALBERT输出的是一个2维向量,而不是一个3维向量(batch_size, sequence_length, hidden_size)呢?那我们来看一下源码,弄清楚self.model.get_pooled_output()的来历。


多标签文本分类nlp_占位符_06


其中self.sequence_ouput就是我们所说的那个3维向量(batch_size, sequence_length, hidden_size)。我们对这个3维向量做了一个"pooler"的操作,从而使之变成了一个2维的向量。

蓝色方框内的解释为:”We "pool" the model by simply taking the hidden state corresponding to the first token. We assume that this has been pre-trained“。这句话怎么理解呢?意思是将整个句子的特征信息投射到句子第一个字的隐藏状态向量上面。并且,认为这个它是通过预训练得到的。

3、Full connection


多标签文本分类nlp_文本分类_07


我们可以发现,这里使用了很多个全连接层。其中,每一个全连接层服务于一个label,以及每一个全连接层都是一个2分类。在这里,我们使用了tf.nn.softmax_cross_entropy_with_logits的交叉熵计算方式,而非tf.nn.sigmoid_cross_entropy_with_logits。因为对于一个二分类来说,我们要解决的是独占类标签。

  • 在计算损失值时,我们单独计算了每一个标签的损失值。
  • 另外,self.probabilities在这里的维度是一个3维向量(batch_size,hp.num_labels,2),而非一个二维向量,即每一个标签都是一个二分类。

4、Inference


多标签文本分类nlp_文本分类_08


由于损失函数使用的是tf.nn.softmax_cross_entropy_with_logits,所以这里我们使用了tf.argmax来计算出预测值。另外,我们可以看出,在使用tf.argmax时,我们是对第3个维度使用了这个函数,这一点和我们平常使用的会有所区别。

三、实践及框架图

1、框架图


多标签文本分类nlp_占位符_09


2、模型Loss和Accuracy变化曲线图


多标签文本分类nlp_多标签文本分类nlp_10


  • 这个损失函数的收敛图和基于ALBERT框架的非常的类似,主要是由于他们的框架较为相似。
  • 根据上图,可以发现,当loss在0.0001左右时,模型接近收敛。这和常见的文本分类有较大的区别,一般文本分类在0.02和0.01区间模型已经收敛,这个主要是由于多标签文本分类中,有很多0,即空标签,在计算loss时,为空的0也会参与计算。

疑问:为什么accuracy曲线会这么奇怪呢?
Accuracy曲线出来后,也让我有一点吃惊。一开始以为是模型随机性导致的,所以又从新训练了一遍,发现还是一样的结果。
后来想了一下,因为Accuracy的计算方式导致的。

四、代码链接


以后开放!