AC自动机需要自备两个前置技能:KMP和trie树。
不要看代码,先理解思路。都不复杂,不理解的可以看我前面的博客。 

1、问题来源

ac自动机其实就是一种多模匹配算法,那么什么叫做多模匹配算法

单模就是 一个大长字符串里 找 个 单词
多模就是 一个大长字符串里 找 个 单词

单模的问题 用 KMP 算法

多模的问题 用 ac自动机

单模就是给你一个单词,然后给你一个字符串,问你这个单词是否在这个字符串中出现过(匹配),这个问题可以用kmp算法在比较高效的效率上完成这个任务。
那么现在我们换个问题,给你很多个单词,然后给你一段字符串,问你有多少个单词在这个字符串中出现过,当然我们暴力做,用每一个单词对字符串做kmp,这样虽然理论上可行,但是时间复杂度非常之高,当单词的个数比较多并且字符串很长的情况下不能有效的解决这个问题,所以这时候就要用到我们的ac自动机算法了。

KMP是用于一对一的字符串匹配,即在母串中寻求一个模式串的匹配

trie虽然能用于多模式匹配,但是每次匹配失败都需要进行回溯,如果模式串很长的话会很浪费时间

2、核心算法思路

ac自动机,就是在tire树的基础上,增加fail指针,如果当前点匹配失败,则将指针转移到fail指针指向的地方,这样就不用回溯,而可以一路匹配下去了。

3、算法示例

给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。

注意:是出现过,就是出现多次只算一次。

3.1 建立Trie树

我们将n个模式串建成一颗Trie树,建树的方式和建Trie完全一样。

NLP AC自动机 ac自动机原理_NLP AC自动机

3.2 构造fail指针

如何确定fail指针??
重点:如果一个点 X 的Fail指针指向 Y 。root到Y的字符串root到X的字符串 的一个后缀
重点:如果一个点 X 的Fail指针指向 Y 。root到Y的字符串root到X的字符串 的一个后缀
重点:如果一个点 X 的Fail指针指向 Y 。root到Y的字符串root到X的字符串 的一个后缀

3.2.1构造第一步

把所有第一层的节点的 fail指针指向 root

NLP AC自动机 ac自动机原理_NLP AC自动机_02

3.2.2构造第二步

用BFS(广度优先遍历)方法把其他层的节点,逐层 构造 fail指针 说一下构造的思路 重点:如果一个点 X 的Fail指针指向 Y 。root到Y的字符串 是 root到X的字符串 的一个后缀。 举一个例子: root到5C 的串 是 BC ,root到 3C的 串 是 C root到7C 的后缀 包括:ABC、BC、C 则 7C的 fail指针应该指向 交集 最长 的那个。即,5C。因为他们都有BC的交集。

后面的以此类推:
4B 可以指向2B。root到2B的 串 是 B。root到4B的后缀包括:AB、B。有交集。所以可以指向。
5C 只能指向 3C,他们有交集 是 C
6D 只能指向root. 6D和8D没有交集. root到8D的 串 是 BCD,root到6D的后缀 包括:BD、D。没有交集。
8D 只能指向root。8D和6D没有交集,6D的 串 是 BD,root到8D的后缀包括:BCD、CD、D。没有交集的 都指向root。

开始检索

为了避免重复计算,我们每经过一个点就打个标记为−1,下一次经过就不重复计算了。

同时,如果一个字符串匹配成功,那么他的Fail也肯定可以匹配成功(后缀嘛),于是我们就把Fail再统计答案,同样,Fail的Fail也可以匹配成功,以此类推……经过的点累加flag,标记为−1。

最后主要还是和Trie的查询是一样的。

划重点:

模式串 [‘hert’,‘er’,‘rtv’]

匹配串‘hertvc’

NLP AC自动机 ac自动机原理_字符串_03

如果一次遍历,会出现 检索出 hert ,然后 直接指向rt,然后检索出 rtv。而忽略了er !
因此,每行进一个字符 都必须检索一次trie树!防止漏项!

完整代码

from collections import defaultdict
class Node:
    def __init__(self, state_num, ch=None):
        self.state_num = state_num
        self.ch = ch
        self.children = []


class Trie(Node):
    """
    实现了一个简单的字典树
    """

    def __init__(self):
        Node.__init__(self, 0)

    def init(self):
        self._state_num_max = 0
        self.goto_dic = defaultdict(lambda: -1)
        self.fail_dic = defaultdict(int)
        self.output_dic = defaultdict(list)

    def build(self, patterns):
        """
        参数 patterns 如['he', 'she', 'his', 'hers']
        """
        for pattern in patterns:
            self._build_for_each_pattern(pattern)
        self._build_fail()

    def _build_for_each_pattern(self, pattern):
        """
        将pattern添加到当前的字典树中
        """
        current = self
        for ch in pattern:
            # 判断字符 ch 是否为节点 current 的子节点
            index = self._ch_exist_in_node_children(current, ch)
            # 不存在 添加新节点并转向
            if index == -1:
                current = self._add_child_and_goto(current, ch)
            # 存在 直接 goto
            else:
                current = current.children[index]
        self.output_dic[current.state_num] = [pattern]

    def _ch_exist_in_node_children(self, current, ch):
        """
        判断字符 ch 是否为节点 current 的子节点,如果是则返回位置,否则返回-1
        """
        for index in range(len(current.children)):
            child = current.children[index]
            if child.ch == ch:
                return index
        return -1

    def _add_child_and_goto(self, current, ch):
        """
        在当前的字典树中添加新节点并转向
        新节点的编号为 当前最大状态编号+1
        """
        self._state_num_max += 1
        next_node = Node(self._state_num_max, ch)
        current.children.append(next_node)
        # 修改转向函数
        self.goto_dic[(current.state_num, ch)] = self._state_num_max
        return next_node

    def _build_fail(self):
        node_at_level = self.children
        while node_at_level:
            node_at_next_level = []
            for parent in node_at_level:
                node_at_next_level.extend(parent.children)
                for child in parent.children:
                    v = self.fail_dic[parent.state_num]
                    while self.goto_dic[(v, child.ch)] == -1 and v != 0:
                        v = self.fail_dic[v]
                    fail_value = self.goto_dic[(v, child.ch)]
                    self.fail_dic[child.state_num] = fail_value
                    if self.fail_dic[child.state_num] != 0:
                        self.output_dic[child.state_num].extend(self.output_dic[fail_value])
            node_at_level = node_at_next_level


class AC(Trie):
    def __init__(self):
        Trie.__init__(self)

    def init(self, patterns):
        Trie.init(self)
        self.build(patterns)

    def goto(self, s, ch):
        if s == 0:
            if (s, ch) not in self.goto_dic:
                return 0
        return self.goto_dic[(s, ch)]

    def fail(self, s):
        return self.fail_dic[s]

    def output(self, s):
        return self.output_dic[s]

    def search(self, text):
        current_state = 0
        ch_index = 0
        while ch_index < len(text):
            ch = text[ch_index]

            if self.goto(current_state, ch) == -1:
                current_state = self.fail(current_state)

            current_state = self.goto(current_state, ch)

            patterns = self.output(current_state)
            if patterns:
                print(current_state, *patterns)

            ch_index += 1


if __name__ == "__main__":
    ac = AC()
    ac.init
    ac.init(['hert','er','rtv'])
    ac.search("hertvc")