NLP实践——基于SIFRank的中文关键短语抽取

  • 0. 本文介绍
  • 1. 运行环境
  • 2. 项目目录
  • 3. 代码实现
  • 3.1 utils
  • 3.2 初始化各类组件
  • 3.2.1 标点和停用词
  • 3.2.2 预训练词汇权重
  • 3.2.3 分词/词性标注模型
  • 3.2.4 候选短语抽取模型
  • 3.2.5 词形还原模型
  • 3.2.6 编码模型
  • 3.3 建立关键短语抽取模型
  • 3.4 抽取应用
  • 4. 改进
  • 4.1 增加候选关键短语
  • 4.2 自监督训练


0. 本文介绍

本文在《SIFRank: A New Baseline for Unsupervised Keyphrase Extraction Based on Pre-Trained Language Model》的基础上,借鉴原作者的思想,重写实现了一个好用的中文关键短语抽取工具。
首先声明一下,这篇论文我并没有看过,所有的理解全都是基于作者开源出来的代码,因而不保证所有的思想都与原作者保持一致。

这篇论文是一个抽取式的关键短语模型,相比近两年备受关注的生成式关键短语模型,其技术理念已经相对落后,但是在实际应用的生产环境中,尤其是对于无监督的垂直领域,我们更关心的是模型的可解释性以及抽取结果的可控性,因而抽取式的模型相比生成式,能够更加让我感到安心,这也是选择这篇论文作为参考的主要原因。在尝试这个思路之前,也对textrank,yake,autophrasex,UCphrase等关键短语抽取工具进行了尝试,但是效果都不太理想。

下面贴出原项目的地址:
https://github.com/yukuotc/SIFRank_zh

原项目的时间比较久,其中所应用到的elmo编码器的预训练模型的下载地址已经失效,并且词性标注模型也比较旧了,所以在此项目的基础上,我从中借鉴了一部分代码,并参考作者的思路,提出并实现了自己的解决方法,主要做出的修改如下:






熟悉我写作风格的同学们应该比较了解,我很少进行理论介绍,我的博客主要从易用的角度,关注一个具体功能的实现,接下来我将从运行环境开始讲起,介绍如何实现这一关键短语抽取模型。

1. 运行环境

首先介绍一下环境配置,我的运行环境如下:

torch 1.8.1
ltp 4.1.4
thulac 0.2.1
nltk 3.5
transformers 4.9.2
sentence-transformers 2.0.0

其中,

  1. thulac是参考原作者的环境,如果完全按照我的方法去做,不考虑原作者的方法,可以不安装;
  2. sentence-transformers是用于自监督训练,如果对领域迁移不感兴趣,可以不安装;
  3. transformers高版本是sentence-transformers的要求,如果不安装后者,估计前者4.0以上即可;
  4. ltp最好采用4.1或以上版本,其新版与旧版在效率和准确度上都有很大的差异;
  5. torch满足相应版本的ltp和transformers即可;
  6. nltk的版本相对随意,一般也不会与其他模块冲突。

2. 项目目录

然后介绍一下项目目录。建立一个项目根目录keyphrase_extractor,在此目录下建立一个jupyter笔记或py文件,建立一个utils.py(其中的内容后边会介绍),以及一个文件夹resources;

resources中,建立一个ner_usr_dict.txt,其中存放分词时的用户自定义实体表,每行写一个实体,例如:

南京市长
江大桥

这个文件的作用是,让分词模型在分词的时候,把“南京市长江大桥”分为[“南京市长”, “江大桥”],而非[“南京市”, “长江大桥”]。

然后去原项目中,下载auxiliary_data下的dict.txt,放在我们的resources下,命名为pretrained_weight_dict.txt。

 

nlp关键词提取java_自然语言处理

nlp关键词提取java_自然语言处理_02

全部准备好之后,整个项目目录应该是这个样子:

keyphrase_extractor
|--keyphrase_extract.ipynb        # 下面所有的代码放进这个笔记
|--utils.py                  # 辅助函数
|--resources
    |--ner_usr_dict.txt         # 自定义实体表
    |--pretrained_weight_dict.txt  # 预训练词汇权重
    |--chinese-electra-180g-small-discriminator   # electra 预训练模型
        |--config.json
        |--tokenizer_config.json
        |--tokenizer.json
        |--added_tokens.json
        |--special_tokens_map.json
        |--vocab.txt
        |--pytorch_model.bin

3. 代码实现

终于来到了喜闻乐见的代码环节,在这一环节中的所有代码,除了3.1中,全部依次丢进keyphrase_extract.ipynb中运行即可。

代码的基本逻辑我随手花了一个图,同学们凑合着看。

nlp关键词提取java_nlp关键词提取java_03

3.1 utils

首先完善一下我们的辅助类函数,打开utils.py,加入以下三个函数:

  1. get_word_weight:用于获取词权重
  2. process_long_input:用于将bert支持的长度从512扩展为1024
  3. rematch:用于token-level到char-level的匹配

这三个函数是到处借鉴来的,其中1是本项目中改写的,2是此论文所述项目中搬来的,3是从bert4keras中搬来的。

import numpy as np
import unicodedata, re
import torch
import torch.nn.functional as F


def get_word_weight(weightfile="", weightpara=2.7e-4):
    """
    Get the weight of words by word_fre/sum_fre_words
    :param weightfile
    :param weightpara
    :return: word2weight[word]=weight : a dict of word weight
    """
    if weightpara <= 0:  # when the parameter makes no sense, use unweighted
        weightpara = 1.0
    word2weight = {}
    word2fre = {}
    with open(weightfile, encoding='UTF-8') as f:
        lines = f.readlines()
    # sum_num_words = 0
    sum_fre_words = 0
    for line in lines:
        word_fre = line.split()
        # sum_num_words += 1
        if (len(word_fre) >= 2):
            word2fre[word_fre[0]] = float(word_fre[1])
            sum_fre_words += float(word_fre[1])
        else:
            print(line)
    for key, value in word2fre.items():
        word2weight[key] = weightpara / (weightpara + value / sum_fre_words)
        # word2weight[key] = 1.0 #method of RVA
    return word2weight


def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):
    """

    Parameters
    ----------
    model: 编码模型
    input_ids: (b, l)
    attention_mask: (b, l)
    start_tokens: 对bert而言就是[101]
    end_tokens: [102]

    Returns
    -------

    """
    # Split the input to 2 overlapping chunks. Now BERT can encode inputs of which the length are up to 1024.
    n, c = input_ids.size()
    start_tokens = torch.tensor(start_tokens).to(input_ids)   # 转化为tensor放在指定卡上
    end_tokens = torch.tensor(end_tokens).to(input_ids)
    len_start = start_tokens.size(0)   # 1
    len_end = end_tokens.size(0)       # 1 if bert , 2 if roberta
    if c <= 512:
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        sequence_output = output[0]
        attention = output[-1][-1]
    else:
        new_input_ids, new_attention_mask, num_seg = [], [], []   # num_seg记录原来的样本被切成多少片,1 or 2
        seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()  # 在len维度上求和,即每个样本的1的个数,即长度
        for i, l_i in enumerate(seq_len):
            # 对batch中的每一个样本循环
            if l_i <= 512:
                # 如果长度小于512,就直接添加
                new_input_ids.append(input_ids[i, :512])
                new_attention_mask.append(attention_mask[i, :512])
                num_seg.append(1)
            else:
                # 超过512的样本
                # 第一段取开始到511,加结束符
                input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)
                # 第二段取开始符,加剩下的部分
                input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)
                # attention_mask同理
                attention_mask1 = attention_mask[i, :512]
                attention_mask2 = attention_mask[i, (l_i - 512): l_i]
                new_input_ids.extend([input_ids1, input_ids2])
                new_attention_mask.extend([attention_mask1, attention_mask2])
                num_seg.append(2)
        # 在batch维度上拼接
        # 原本的input_ids 是(b, l),经过上面的for循环new_input_ids每一项是(l,)
        # 然后在dim=0上stack,变回了(b, l)
        # 但是此时的b可能已经大于原来的batch_size
        input_ids = torch.stack(new_input_ids, dim=0)
        attention_mask = torch.stack(new_attention_mask, dim=0)
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        # 把新构建的输入进行建模,然后把建模结果拼回原来的
        sequence_output = output[0]   # (b, l, 768)
        attention = output[-1][-1]    # (b, ?, l, l)
        i = 0   # i是旧的batch号
        new_output, new_attention = [], []
        for (n_s, l_i) in zip(num_seg, seq_len):
            if n_s == 1:
                # 这个pad没看懂。n_s == 1的话,c - 512应该小于0
                output = F.pad(sequence_output[i], (0, 0, 0, c - 512))
                att = F.pad(attention[i], (0, c - 512, 0, c - 512))
                new_output.append(output)
                new_attention.append(att)
            elif n_s == 2:
                # 取第一个片段的建模结果
                output1 = sequence_output[i][:512 - len_end]
                mask1 = attention_mask[i][:512 - len_end]
                att1 = attention[i][:, :512 - len_end, :512 - len_end]  # 构建第一个样本的时候增加了结束符,所以要去掉它
                output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))
                mask1 = F.pad(mask1, (0, c - 512 + len_end))
                att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))

                # 第二个片段的建模结果
                output2 = sequence_output[i + 1][len_start:]
                mask2 = attention_mask[i + 1][len_start:]
                att2 = attention[i + 1][:, len_start:, len_start:]   # 构建第二个样本的时候增加了开始符,所以要从1开始索引,去掉它
                output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))
                mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))
                att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])

                # 把两个片段合并
                mask = mask1 + mask2 + 1e-10
                output = (output1 + output2) / mask.unsqueeze(-1)
                att = (att1 + att2)
                att = att / (att.sum(-1, keepdim=True) + 1e-10)
                new_output.append(output)
                new_attention.append(att)
            i += n_s
        sequence_output = torch.stack(new_output, dim=0)
        attention = torch.stack(new_attention, dim=0)

    return sequence_output, attention


def rematch(text, tokens, do_lower_case=True):
    if do_lower_case:
        text = text.lower()
        
    def is_control(ch):
        return unicodedata.category(ch) in ('Cc', 'Cf')
    
    def is_special(ch):
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
    
    def stem(token):
        if token[:2] == '##':
            return token[2:]
        else:
            return token
        
    normalized_text, char_mapping = '', []
    for i, ch in enumerate(text):
        if do_lower_case:
            ch = unicodedata.normalize('NFD', ch)
            ch = ''.join([c for c in ch if unicodedata.category(c) != 'mn'])
        ch = ''.join([c for c in ch if not (ord(c) == 0 or ord(c) == 0xfffd or is_control(c))])
        normalized_text += ch
        char_mapping.extend([i] * len(ch))
    text, token_mapping, offset = normalized_text, [], 0
    for token in tokens:
        if token.startswith('▁'):
            token = token[1:]
        if is_special(token):
            token_mapping.append([])
        else:
            token = stem(token)
            if do_lower_case:
                token = token.lower()
            try:
                start = text[offset:].index(token) + offset
            except Exception as e:
                print(e)
                print(token)
            end = start + len(token)
            token_mapping.append(char_mapping[start: end])
            offset = end
            
    return token_mapping

3.2 初始化各类组件

先import:

import time
import numpy as np
import thulac
import nltk
from nltk.corpus import stopwords
from ltp import LTP

import torch
import torch.nn.functional as F
from transformers import ElectraModel, ElectraTokenizerFast
from sentence_transformers.util import pytorch_cos_sim

from utils import get_word_weight, process_long_input, rematch

3.2.1 标点和停用词

english_punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%']
chinese_punctuations = '!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.'
punctuations = ''.join(i for i in english_punctuations) + chinese_punctuations


# 注意只对英文去停,中文停用词保留
stop_words = stopwords.words('english')

3.2.2 预训练词汇权重

weightfile_pretrain = './resources/pretrained_weight_dict.txt'
weightpara_pretrain = 2.7e-4
word2weight_pretrain = get_word_weight(weightfile_pretrain, weightpara_pretrain)

3.2.3 分词/词性标注模型

如果采用SIFRank原作者的策略,则实例化一个lac模型

lac_model = thulac.thulac()

我采用的是ltp模型,首先把自定义词表和模型路径加载一下。

ltp_model_path = '/ltp4_data/base/'  # 这个模型需要去ltp的git上下载
ltp_ner_usr_dict_path = './resources/ner_usr_dict.txt'

usr_dict = []
with open(ltp_ner_usr_dict_path) as f:
    for line in f.readlines():
        usr_dict.append(line.split('\n')[0])

然后构建一个类,用于做分词和词性分析。

class LTPForTokenizeAndPostag:
	"""
	用于分词和词性分析
	---------------
	ver: 2021-11-01
	by: changhongyu
	"""
	def __init__(self, ltp_model_path, ners=None, device='cpu'):
		"""
		:param ltp_model_path: str: ltp模型的路径
		:param ners: list: 用户输入的实体列表
		:param device: str: cpu还是cuda
		"""
		print('Initializing LTP model from {}.'.format(ltp_model_path))
		self.ltp_model = LTP(path=ltp_model_path, device=device)
		print('LTP model created.')
		if ners:
			self.ltp_model.add_words(words=ners, max_window=4)
		# 为了保持与thu-lac模型的词性标记形式一致,做了这个映射
		# 当然,也可以不映射,然后对3.2.4的抽取器进行适当修改
		self.ltp_to_lac_pos_map = {
								   'b': 'a',
								   'nd': 'f',
								   'nh': 'np',
								   'nl': 'ns',
								   'nt': 't',
								   'wp': 'w',
								   'ws': 'x',
							  	  }

	def _get_tokens(self, text):
		tokens, hidden = self.ltp_model.seg(self.ltp_model.sent_split([text]))
		self.hidden = hidden
		para_tokens = []
		for t in tokens:
			para_tokens += t
		
		return para_tokens
	def _get_pos(self, text):
		tags = self.ltp_model.pos(self.hiden)
		para_tags = []
		for t in tags:
			para_tags += t
		
		return para_tags
	
	# 因为lac模型的调用方法是cut,所以保持一致用cut命名
	def cut(self, text):
		"""
		:param text: str: 输入文本
		:return token_list: list: tokenized
		:return token_tag_list: list: token对应的词性
		"""
		token_list = self._get_tokens(text)
		token_tag_list = self._get_pos(text)
		assert len(token_list) == len(token_tag_list)
		token_tag_list_lac = []
		for tag in token_tag_list:
			if tag in self.ltp_to_lac_pos_map:
				token_tag_list_lac.append(self.ltp_to_lac_pos_map[tag])
			else:
				token_tag_list_lac.append(tag)
			
		return [[token, tag] for token, tag in zip(token_list, token_tag_list_lac)]

然后实例化这个模型,替换原来的lac模型:

ltp_pos_model = LTPForTokenizeAndPostag(ltp_model_path, ners=usr_dict, device='cuda:0')

3.2.4 候选短语抽取模型

这个模型的作用是以nltk的正则工具抽取候选关键短语。我在原项目的基础上做了一点点修改,原项目每次抽取都重新实例化抽取器,让我觉得很别扭。

class CandidateExtractor:
    """
    参考SIFRank项目的词性正则抽取候选短语
    """
    def __init__(self):
        grammar = """  NP:
                    {<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>} # Adjective(s)(optional) + Noun(s)"""
        self.parser = nltk.RegexpParser(grammar)
    
    def extract_candidates(self, tokens_tagged):
        keyphrase_candidate = []
        np_pos_tag_tokens = self.parser.parse(tokens_tagged)
        count = 0
        for token in np_pos_tag_tokens:
            if (isinstance(token, nltk.tree.Tree) and token._label == "NP"):
                np = ''.join(word for word, tag in token.leaves())
                length = len(token.leaves())
                start_end = (count, count + length)
                count += length
                keyphrase_candidate.append((np, start_end))
            else:
                count += 1
        
        return keyphrase_candidate
        
candidate_extractor = CandidateExtractor()

3.2.5 词形还原模型

这个没什么好说的,就是一个简单的词形还原,对中文来说作用不大。

lemma_model = nltk.WordNetLemmatizer()

3.2.6 编码模型

这里可以采用多种编码模型,可以多实验几个预训练模型测试一下效果。注意,Roberta系列的模型和XMLRoberta系列的模型由于tokenizer比较特殊,我没有做相应的适配。

Electra模型:

electra_path = './resources/chinese-electra-180g-small-discriminator'
electra_tokenizer = ElectraTokenizerFast.from_pretrained(electra_path)
electra_model = ElectraModel.from_pretrained(electra_path)

Bert模型:

from transformers import BertTokenizerFast, BertModel
bert_path = './resources/bert-base-chinese/'
bert_model = BertModel.from_pretrained(bert_path)
bert_tokenizer = BertTokenizerFast.from_pretrained(bert_path)

Sentence-bert提供的一个语义相似度预训练bert:

from transformers import DistilBertTokenizerFast, DistilBertModel
# distil_bert_path = './resources/distiluse-base-multilingual-cased-v2/'  # 这个是原来的
distil_bert_path = './finetune_embedding_model/SimCSE/4500/'  # 这个是我用SimCSE训练之后的
distil_bert_model = DistilBertModel.from_pretrained(distil_bert_path)
distil_bert_tokenizer = DistilBertTokenizerFast.from_pretrained(distil_bert_path)

这些模型都可以在huggingface网站上找到,参考本文第2部分。

3.3 建立关键短语抽取模型

万事俱备,接下来就把这些组件放在一起,构建一个大类,用于抽取关键短语。这个大类包含一下几个方法:

  1. 构造方法:加载3.2中构建的各个组件;
  2. 添加新的停用词和标点词;
  3. 获取每个token的编码特征列表;
  4. 获取每个token的权重列表;
  5. 获取候选短语列表;
  6. 从候选短语抽取关键短语;
  7. 调用方法,给入文本,抽取关键短语;
  8. 静态方法:获取一个候选的加权表征;
  9. 静态方法:输入文本预处理。

以上方法将会依次呈现在下面的类中:

class SIFRank:
    """
    用于抽取关键短语的SIFRank模型
    [步骤]
    1. 对原句进行tokenize和词性标注
    2. 对原句进行编码,并根据1中tokenize的结果获取embedding_list
    3. 根据1中tokenize的结果获取weight_list
    4. 抽取原句中的候选关键短语
    5. 对候选关键短语进行评分,得到关键短语
    ---------------
    ver: 2021-11-02
    by: changhongyu
    """
    def __init__(self, tokenize_and_postag_model, candidate_extractor, lemma_model,
                 encoding_model, encoding_tokenizer, encoding_pooling, encoding_device, 
                 word2weight_pretrain, stop_words, punctuations):
        """
        :param tokenize_and_postag_model: 分词和词性标注模型
        :param candidate_extractor: 用于抽取候选短语的模型
        :param lemma_model: 用于词根还原的模型, 如果None,则忽略
        :param encoding_model: PretrainedModel: 编码预训练模型
        :param encoding_tokenizer: PretrainedTokenizer: 编码时的tokenizer
        :param encoding_pooling: str: 编码时的池化策略, 'mean'或'max'
        :param encoding_device: str: 编码时的设备, 'cpu'或'cuda'
        :param word2weight_pretrain: dict: 词汇对应权重的大list
        :param stop_words: list: 停用词表
        :param punctuations: list: 标点符号表
        """
        assert encoding_pooling in ['mean', 'max'], Exception("Pooling must be either mean or max.")
        assert encoding_device.startswith('cuda') or encoding_device == 'cpu'
        self.tokenize_and_postag_model = tokenize_and_postag_model
        self.extractor = candidate_extractor
        self.lemma_model = lemma_model
        self.encoding_model = encoding_model
        self.encoding_tokenizer = encoding_tokenizer
        self.encoding_pooling = encoding_pooling
        self.encoding_device = torch.device(encoding_device)
        self.word2weight_pretrain = word2weight_pretrain
        self.stop_words = stop_words
        self.punctuations = punctuations
        print(self)
    
    def __repr__(self):
        infos = ['------SIFRank for key-phrase extract------\n',
                 'SETTINGS: \n'
                 'tokenize_and_postag_model:  {}\n'.format(str(type(self.tokenize_and_postag_model)).replace("'>", "").split('.')[-1]),
                 'lemma_model:  {}\n'.format(str(type(self.lemma_model)).replace("'>", "").split('.')[-1]),
                 'encoding_model:  {}\n'.format(str(type(self.encoding_model)).replace("'>", "").split('.')[-1]),
                 'encoding_device:  {}\n'.format(self.encoding_device),
                 'encoding_pooling:  {}\n'.format(self.encoding_pooling),
                ]
        
        return ''.join(info for info in infos)
    
    def add_stopword(self, stop_word):
        """
        添加停用词,注意停用词是指英文停用词
        """
        self.stop_words.append(stop_word)
        
    def add_punctuation(self, punctuation):
        """
        添加标点符
        """
        self.punctuations.append(punctuation)
    
    def _get_embedding_list(self, text, target_tokens):
        """
        获取以token为划分的embedding的list
        TODO: 对原句进行清洗,过滤掉对encoding_tokenizer而言OOV的词(耗时太长)
        :param text: str: 原文
        :param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
        """
        embedding_list = []
        self.encoding_model.to(self.encoding_device)

        ## <1. 获取编码
        features = self.encoding_tokenizer(text.lower().replace(' ', '-'),
                                           max_length=1024,
                                           truncation=True,
                                           padding='longest',
                                           return_tensors='pt')
        input_ids = features['input_ids'].to(self.encoding_device)
        # token_type_ids = features['token_type_ids'].to(self.encoding_device)
        attention_mask = features['attention_mask'].to(self.encoding_device)

        with torch.no_grad():
            # enconding_out = self.encoding_model(input_ids, token_type_ids, attention_mask)
            # last_hidden_state = enconding_out['last_hidden_state'].squeeze(0).detach().cpu().numpy()
            enconding_out, _ = process_long_input(self.encoding_model, 
                                                  input_ids, 
                                                  attention_mask, 
                                                  [self.encoding_tokenizer.cls_token_id], 
                                                  [self.encoding_tokenizer.sep_token_id])
            # last_hidden_state: (len, hidden)
            last_hidden_state = enconding_out.squeeze(0).detach().cpu().numpy()

        ## 1>

        ## <2. token对齐
        t_mapping = rematch(text, target_tokens, do_lower_case=True)
        s_mapping = rematch(text, self.encoding_tokenizer.tokenize(text), do_lower_case=True)
        
        token_lens = []
        t_pointer = 0
        t = t_mapping[t_pointer]
        cur_len = 0
        cur_in_t = 0
        for s in s_mapping:
            # print(s, t[cur_in_t: cur_in_t + len(s)])
            if s == t[cur_in_t: cur_in_t + len(s)]:
                cur_len += 1
                cur_in_t += len(s)
                if cur_in_t == len(t):
                    # 判断当前target结束
                    token_lens.append(cur_len)
                    cur_len = 0
                    cur_in_t = 0
                    t_pointer += 1
                    if t_pointer >= len(t_mapping):
                        break
                    t = t_mapping[t_pointer]
        ## 2>
        assert len(token_lens) == len(target_tokens), \
                Exception("Token_lens and target_tokens shape unmatch: {} vs {}.".format(len(token_lens), len(target_tokens)))

        ## <3 根据token_len获取对应的embedding池化
        cur_pos = 0
        for token_len in token_lens:
            if token_len == 0:
                # 如果是空字符,则置为全零
                cur_emb = np.zeros(last_hidden_state.shape[1])
                embedding_list.append(cur_emb)
                continue
            if self.encoding_pooling == 'mean':
                cur_emb = np.mean(last_hidden_state[cur_pos: cur_pos + token_len][:], axis=0)
            elif self.encoding_pooling == 'max':
                cur_emb = np.max(last_hidden_state[cur_pos: cur_pos + token_len][:], axis=0)
            else:
                raise ValueError("Pooling Strategy must be either mean or max.")
            cur_pos += token_len
            embedding_list.append(cur_emb)
        ## 3>

        assert len(embedding_list) == len(target_tokens), \
                Exception("Result embedding list must have same length as target.")

        return embedding_list
    
    def _get_weight_list(self, target_tokens):
        """
        获取weight列表
        :param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
        :return weight_list: list of float: 每个token对应的预训练权重列表
        """
        weight_list = []
        _max = 0.
        for token in target_tokens:
            token = token.lower()
            if token in self.stop_words or token in self.punctuations:
                weight = 0.
            elif token in self.word2weight_pretrain:
                weight = word2weight_pretrain[token]
            else:
                # 如果OOV,返回截至当前句中最大的token
                weight = _max
            _max = max(weight, _max)
            weight_list.append(weight)
        
        return weight_list
    
    def _get_candidate_list(self, target_tokens, target_poses):
        """
        用词性正则抽取候选关键短语列表
        :param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
        :param target_poses: list: tokenize_and_postag_model对当前输入词性标注结果
        :return candidates: list of tuples like: ('自然语言', (5, 7))
            NOTE: tuple[1]是在target_tokens中的span,对target_tokens索引,得到tuple[0]
        """
        assert len(target_tokens) == len(target_poses)
        tokens_tagged = [(tok, pos) for tok, pos in zip(target_tokens, target_poses)]
        candidates = self.extractor.extract_candidates(tokens_tagged)
        
        return candidates
    
    def _extract_keyphrase(self, candidates, weight_list, embedding_list, max_keyphrase_num):
        """
        对候选的关键短语计算与原文编码的相似度,获取关键短语
        :param candidates: list of tuples: 候选关键短语list
        :param weight_list: list of float: 每个token的预训练权重列表
        :param embedding_list: list of array: 每个token的编码结果
        :param max_keyphrase_num: int: 最多保留的关键词个数
        :return key_phrases: list of tuple: [(k1, 0.9), ...]
        """
        assert len(weight_list) == len(embedding_list)
        # 获取每个候选短语的编码
        candidate_embeddings_list = []
        for cand in candidates:
            cand_emb = self.get_candidate_weight_avg(weight_list, embedding_list, cand[1])
            candidate_embeddings_list.append(cand_emb)
            
        # 计算候选短语与原文的相似度
        sent_embeddings = self.get_candidate_weight_avg(weight_list, embedding_list, (0, len(embedding_list)))
        sim_list = []
        for i, emb in enumerate(candidate_embeddings_list):
            sim = float(pytorch_cos_sim(sent_embeddings, candidate_embeddings_list[i]).squeeze().numpy())
            sim_list.append(sim)
            
        # 对候选短语归并,词根相同的短语放在一起
        dict_all = {}
        for i, cand in enumerate(candidates):
            if self.lemma_model:
                cand_lemma = self.lemma_model.lemmatize(cand[0].lower()).replace('▲', ' ')
            else:
                cand_lemma = cand[0].lower().replace('▲', ' ')
            if cand_lemma in dict_all:
                dict_all[cand_lemma].append(sim_list[i])
            else:
                dict_all[cand_lemma] = [sim_list[i]]
        
        # 对归并结果求平均
        final_dict = {}
        for cand, sim_list in dict_all.items():
            sum_sim = sum(sim_list)
            final_dict[cand] = sum_sim / len(sim_list)
            
        return sorted(final_dict.items(), key=lambda x: x[1], reverse=True)[: max_keyphrase_num]
    
    def __call__(self, text, max_keyphrase_num):
        """
        抽取关键词
        :param text: str: 待抽取原文
        :param max_keyphrase_num: int: 最多保留的关键词个数
        :return key_phrases: list of tuple: [(k1, 0.9), ...]
        """
        text = self.preprocess_input_text(text)
        t0 = time.time()
        
        ## <1. 对原句进行tokenize和词性标注
        token_and_pos = self.tokenize_and_postag_model.cut(text)
        target_tokens = [t_p[0] for t_p in token_and_pos]
        target_poses = [t_p[1] for t_p in token_and_pos]
        
        for i, token in enumerate(target_tokens):
            if token in self.stop_words:
                target_poses[i] = "u"
            if token == '-':
                target_poses[i] = "-"
            if token in ['"', "'"]:
                target_poses[i] = '"'
                
        t1 = time.time()
        print("耗时统计")
        print("<1. 对原句进行tokenize和词性标注: ", round(t1 - t0, 2), 's')
        ## 1>
        
        ## <2. 对原句进行编码,并根据1中tokenize的结果获取embedding_list
        embedding_list = self._get_embedding_list(text, target_tokens)
        t2 = time.time()
        print("<2. 对原句进行编码: ", round(t2 - t1, 2), 's')
        ## 2>
        
        ## <3. 根据1中tokenize的结果获取weight_list
        weight_list = self._get_weight_list(target_tokens)
        t3 = time.time()
        print("<3. 结果获取weight_list: ", round(t3 - t2, 2), 's')
        ## 3>
        
        ## <4. 抽取原句中的候选关键短语
        candidate_list = self._get_candidate_list(target_tokens, target_poses)
        t4 = time.time()
        print("<4. 抽取原句中的候选关键短语: ", round(t4 - t3, 2), 's')
        ## 4>
        
        ## <5. 对候选关键短语进行评分,得到关键短语
        key_phrases = self._extract_keyphrase(candidate_list, weight_list, 
                                              embedding_list, max_keyphrase_num)
        t5 = time.time()
        print("<5. 对候选关键短语进行评分: ", round(t5 - t4, 2), 's')
        ## 5>
        
        return key_phrases
        
    @staticmethod
    def get_candidate_weight_avg(weight_list, embedding_list, candidate_span):
        """
        获取一个候选词的加权表征
        :param weight_list: list of float: 每个token的预训练权重列表
        :param embedding_list: list of array: 每个token的编码结果
        :param candidate_span: tuple: 候选短语的start和end
        """
        assert len(weight_list) == len(embedding_list)
        start, end = candidate_span
        num_words = end - start
        embedding_size = embedding_list[0].shape[0]

        sum_ = np.zeros(embedding_size)
        for i in range(start, end):
            tmp = embedding_list[i] * weight_list[i]
            sum_ += tmp
        
        return sum_
    
    @staticmethod
    def preprocess_input_text(text):
        """
        对输入原文进行预处理,主要防止两个tokenizer对齐时出现问题
        """
        text = text.lower()
        # 全部判断过于耗时
        # text = ''.join(char for char in text if char in self.encoding_tokenizer.vocab)
        text = text.replace('“', '"').replace('”', '"')
        text = text.replace('‘', "'").replace('’', "'")
        text = text.replace('⁃', '-')
        text = text.replace('\u3000', ' ').replace('\n', ' ')
        text = text.replace(' ', '▲')
        # text = text.replace(' ', '¤')
        
        return text[: 1024]

注意,在上面的类中调用了sentence-transformer中的pytorch_cos_sim方法计算两个张量之间的余弦相似度,如果没有安装这个包,可以自己写个方法实现余弦相似度的计算,这个不难,可以直接百度到。

3.4 抽取应用

将上面的大类实例化:

keyphrase_extractor = SIFRank(tokenize_and_postag_model=ltp_pos_model,
                              candidate_extractor=candidate_extractor,
                              lemma_model=lemma_model,
                              encoding_model=electra_model,
                              encoding_tokenizer=electra_tokenizer,
                              encoding_pooling='mean',
                              encoding_device='cuda:1',
                              word2weight_pretrain=word2weight_pretrain,
                              stop_words=stop_words,
                              punctuations=punctuations)

然后对输入的text,调用:

keyphrase_extractor(text, max_keyphrase_num=10)

即可返回关键短语的降序排列,以及每个关键短语对应的得分。

4. 改进

4.1 增加候选关键短语

候选关键短语是通过正则的方式对词性进行匹配得到的,其关键代码在这一句:

grammar = """  NP:
                    {<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>} # Adjective(s)(optional) + Noun(s)"""

通过修改正则语句,我们可以获得自己想要的候选短语。例如,我希望拿到*'"花岗岩"超声速反舰导弹*这样的短语作为关键短语,通过观察词性发现,这类短语的词性构成是:引号+名词+引号+若干名词,翻译成正则语句就是:

<"><n.*><"><n.*>*<n.*>

把它拼接到原来的语句上:

grammar = """  NP:
                    {<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>|<"><n.*><"><n.*>*<n.*>}"""