代码地址：Advanced: Making Dynamic Decisions and the Bi-LSTM CRF — PyTorch Tutorials 1.11.0+cu102 documentation

pytorch加载训练好的bert模型预测 pytorch bilstm-crf模型_sed

https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html

这只是 PyTorch 官方给出的一个简单的 BiLSTM-CRF 示例。这里逐段分析它的源码，以便对 CRF 有一个清晰的认识。

开始分析代码

def argmax(vec):
    """Return the column index of the largest entry of ``vec`` as a Python int.

    ``vec`` is reduced along dimension 1 and the resulting index tensor is
    converted with ``.item()``, which only works when that tensor holds a
    single element -- i.e. ``vec`` must have exactly one row, shape ``(1, N)``.
    """
    # torch.max with a ``dim`` returns a (values, indices) named tuple;
    # only the position of the maximum is needed here.
    best = torch.max(vec, dim=1)
    return best.indices.item()


def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D ``torch.long`` tensor of ids.

    Args:
        seq: iterable of tokens (e.g. the words of one sentence).
        to_ix: mapping from token to integer id; it is indexed with every
            token, so a token it cannot resolve raises whatever the mapping
            raises (``KeyError`` for a plain ``dict``).

    Returns:
        A ``torch.long`` tensor holding one id per token, in input order.
    """
    token_ids = list(map(to_ix.__getitem__, seq))
    return torch.tensor(token_ids, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    """Return ``log(sum(exp(vec)))`` over all entries of ``vec``.

    The naive formula overflows for large scores, so the computation must be
    shifted by the maximum entry:

        log sum_i exp(x_i) = max + log sum_i exp(x_i - max)

    ``torch.logsumexp`` applies exactly this shift internally.  Compared with
    a hand-written version it also returns ``-inf`` (instead of NaN) when every
    entry is ``-inf``, and it needs no ``.item()`` call to locate the maximum.

    Args:
        vec: tensor of scores; the callers in this file pass shape ``(1, N)``.

    Returns:
        A 0-dim tensor holding the log-sum-exp of every element of ``vec``.
    """
    # Flatten first so the reduction covers every element and the result is
    # a 0-dim tensor, which is what callers reshape with ``.view(1)``.
    return torch.logsumexp(vec.reshape(-1), dim=0)

接下来看 BiLSTM_CRF类

class BiLSTM_CRF(nn.Module):
    """BiLSTM encoder with a linear-chain CRF output layer for tagging.

    The BiLSTM turns each token into a vector of emission scores (one score
    per tag) and the CRF adds a learned tag-to-tag transition matrix on top.
    Train by minimising ``neg_log_likelihood``; calling the module
    (``forward``) runs Viterbi decoding and returns the best tag sequence.

    The model processes one sentence at a time (LSTM batch size is fixed to 1).
    All helper tensors are created on the device/dtype of the model parameters
    or of the incoming features, so the module keeps working after
    ``.to(device)``.
    """

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        """Build the embedding, BiLSTM, projection and transition parameters.

        Args:
            vocab_size: number of rows in the word-embedding table.
            tag_to_ix: mapping from tag name to tag index.  It is indexed with
                the module-level ``START_TAG`` and ``STOP_TAG`` names, so both
                must be present.
            embedding_dim: size of each word embedding.
            hidden_dim: size of the concatenated BiLSTM output; each direction
                gets ``hidden_dim // 2`` units, so the value should be even
                (an odd value makes the reshape in ``_get_lstm_features`` and
                the ``hidden2tag`` projection disagree on the feature size).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        # Word-embedding lookup table.
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        # Single-layer bidirectional LSTM; the two directions together
        # produce ``hidden_dim`` features per token.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters.  Entry [i, j] is the score of
        # transitioning *to* tag i *from* tag j.  Note that this is the
        # transpose of the more common "row = from, column = to" layout.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraints that we never
        # transition *to* the start tag and never transition *from* the
        # stop tag.
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

        # Initial LSTM hidden state (re-sampled for every sentence in
        # ``_get_lstm_features``).
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh random ``(h_0, c_0)`` pair for the BiLSTM.

        Each tensor has shape ``(2, 1, hidden_dim // 2)``: two directions,
        batch size one.  The tensors are created on the same device and with
        the same dtype as the model parameters so the LSTM accepts them after
        the module has been moved or cast.
        """
        ref = self.transitions
        return (torch.randn(2, 1, self.hidden_dim // 2,
                            device=ref.device, dtype=ref.dtype),
                torch.randn(2, 1, self.hidden_dim // 2,
                            device=ref.device, dtype=ref.dtype))

    def _forward_alg(self, feats):
        """Run the forward algorithm to compute the log partition function.

        Args:
            feats: emission scores, one row of ``tagset_size`` scores per
                token (the output of ``_get_lstm_features``).

        Returns:
            The log of the summed exponentiated scores of *all* tag paths,
            as returned by ``log_sum_exp``.
        """
        # Scores at "time 0": every tag is effectively impossible ...
        init_alphas = torch.full((1, self.tagset_size), -10000.,
                                 device=feats.device, dtype=feats.dtype)
        # ... except START_TAG, which holds all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # forward_var[0][i] is the log-sum-exp score of all partial paths
        # that end in tag i at the current time step.
        forward_var = init_alphas

        # Iterate through the sentence: one emission row per token / step.
        for feat in feats:
            # alphas_t[i] will hold the accumulated score of all paths that
            # reach tag i at this time step.
            alphas_t = []
            for next_tag in range(self.tagset_size):
                # Broadcast the emission score: it is the same regardless of
                # the previous tag.
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size)
                # The ith entry of trans_score is the score of transitioning
                # to next_tag from i.
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the edge
                # (i -> next_tag) before we do log-sum-exp.
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is the log-sum-exp over
                # all previous tags.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)

        # Finally add the transition from every tag into STOP_TAG and sum
        # (in log space) over the last step.
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, sentence):
        """Compute the emission scores of one sentence with the BiLSTM.

        Args:
            sentence: 1-D tensor of word ids.

        Returns:
            Tensor of shape ``(len(sentence), tagset_size)`` with one
            emission score per token and tag.
        """
        # Start every sentence from a freshly sampled hidden state.
        self.hidden = self.init_hidden()
        # Shape (len(sentence), 1, embedding_dim): sequence-first, batch of 1.
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        # Drop the batch dimension before projecting into tag space.
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        """Return the score of a provided (gold) tag sequence.

        Args:
            feats: emission scores, one row per token.
            tags: 1-D ``torch.long`` tensor with the tag index of each token.

        Returns:
            One-element tensor: the sum of the transition and emission
            scores along the path ``START_TAG -> tags... -> STOP_TAG``.
        """
        score = torch.zeros(1, device=feats.device, dtype=feats.dtype)
        # Prepend START_TAG so that tags[i] is the tag *before* token i and
        # tags[i + 1] is the tag *of* token i.
        start = torch.tensor([self.tag_to_ix[START_TAG]],
                             dtype=torch.long, device=tags.device)
        tags = torch.cat([start, tags])
        for i, feat in enumerate(feats):
            # transitions[to, from], consistent with the layout chosen in
            # ``__init__``.
            score = score + \
                    self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        """Find the highest-scoring tag sequence with the Viterbi algorithm.

        Args:
            feats: emission scores, one row per token.

        Returns:
            ``(path_score, best_path)`` where ``path_score`` is a 0-dim
            tensor with the score of the best path and ``best_path`` is a
            list of tag indices (Python ints), one per token.
        """
        # backpointers[t][i] is the best previous tag for tag i at step t.
        backpointers = []

        # Initialize the viterbi variables in log space: only START_TAG is
        # reachable before the first token.
        init_vvars = torch.full((1, self.tagset_size), -10000.,
                                device=feats.device, dtype=feats.dtype)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        # forward_var at step i holds the viterbi variables for step i-1,
        # i.e. the best score of any path ending in each tag.
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below).
                next_tag_var = forward_var + self.transitions[next_tag]
                # Best previous tag for reaching next_tag, and its score.
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the
            # set of viterbi variables we just computed.
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            # bptrs_t[i] is the previous tag from which moving to tag i
            # scores highest at this step.
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG.
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers from the last step to the first to decode
        # the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we don't want to return that to the caller).
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        """Return the CRF training loss for one sentence.

        The model defines

            p(y | x) = exp(score(x, y)) / sum_{y'} exp(score(x, y'))

        so

            log p(y | x) = score(x, y) - log sum_{y'} exp(score(x, y'))

        ``gold_score`` is ``score(x, y)`` for the gold path and
        ``forward_score`` is the log partition term.  Maximising the log
        likelihood is the same as minimising
        ``forward_score - gold_score``, which is what is returned.

        Args:
            sentence: 1-D tensor of word ids.
            tags: 1-D ``torch.long`` tensor with the gold tag of each token.

        Returns:
            One-element tensor holding the negative log-likelihood.
        """
        # Emission scores from the BiLSTM.
        feats = self._get_lstm_features(sentence)
        # Log of the total score over all possible paths.
        forward_score = self._forward_alg(feats)
        # Score of the gold path.
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):  # dont confuse this with _forward_alg above.
        """Decode the best tag sequence for ``sentence``.

        Args:
            sentence: 1-D tensor of word ids.

        Returns:
            ``(score, tag_seq)`` as produced by ``_viterbi_decode``.
        """
        # Get the emission scores from the BiLSTM,
        # shape (len(sentence), tagset_size).
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq