一. 概述

AC 自动机是一种多模式匹配算法。

AC 自动机构建在 Trie 的结构基础上,结合了 Kmp 算法的失配指针思想。

在进行多模式串匹配前,只有两个步骤需要去实现:

\(1.\)

\(2.\)

二.构建 Trie 树

只需要按照 Trie 树的基本构建方法搭建即可。

请注意,Trie 树节点的含义十分重要:

它表示的是某个模式串的前缀,也就是一个状态。

而 Trie 的边就是状态的转移。

对于概念理解不够透彻的同学可以看这里

代码如下:

void insert (char *s) {
	int slen = strlen (s), u = 0, c;
	for (int i = 0; i < slen; i ++) {
		c = s[i] - 'a';
		if (!trie[u][c]) {//无节点就添加节点。
			trie[u][c] = ++ tot;
		}
		u = trie[u][c];
	}
	tag[u] ++;
}

三.Fail 指针

这是 AC 自动机的核心

什么是 Fail 指针呢?

如果一个 Trie 树上的节点 u 的 Fail 指针指向 节点 v,那么这就表示根节点到节点 v 的字符串是 根节点到节点 u 的字符串的一个后缀。

如下图:

ac自动机 java pom ac自动机代码_ac自动机 java pom

\(3\) 号节点的 Fail 指针就指向 \(5\)

因为根节点到 \(3\) 号节点的字符串为 \(ABC\),

根节点到 \(5\) 号节点的字符串为 \(BC\),

由于 \(BC\) 是 \(ABC\) 的一个后缀,所以 \(3\) 号节点的 Fail 指针指向 \(5\)

四.构建 Fail 指针

对于一个 Trie 树上的节点 u,设它的父节点为 v,两个节点通过字符 c 连接,也就是说 \(trie_{v,c} = u\)。

那么求 Fail 指针的有两个,如下:

\(1.\) 如果 \(trie_{fail_p,c}\) 不是空节点,那么就将节点 u 的 Fail 指针指向 \(trie_{fail_p,c}\)。

\(2.\) 如果 \(trie_{fail_p,c}\) 是空节点,那么继续向上寻找 \(trie_{fail_{fail_p}, c}\),继续重复第 \(1\)

注意:如果找寻到了根节点,那么就将节点 u 的 Fail 指针指向根节点。

代码如下:

queue<int> q;

inline void GetFail () {
	for (int i = 0; i < 26; i ++) {
		if (trie[0][i]) {//非空节点入队。
			q.push (trie[0][i]); 
		}	
	}
	
	while (!q.empty()) {
		int u = q.front();
		
		q.pop();
		
		for (int i = 0; i < 26; i ++) {
			if (trie[u][i]) {
				q.push (trie[u][i]);//非空节点入队。
				
				fail[trie[u][i]] = trie[fail[u]][i]; 
			}
			
			else {
				trie[u][i] = trie[fail[u]][i];
			}
		}
	}
}

稍微对于代码做一个解释:

这里的 GetFail 函数将 Trie 树上所有节点按照 BFS 的顺序入队,最后依次求 Fail 指针。

首先我们单独处理根节点,代码中编号为 \(0\),将其非空的子节点入队。

然后每次取出队首处理 Fail 指针,遍历 \(26\)

\(Fail_u\)

五.查询出现个数

问题如下:

关于许多模式串,求有多少个模式串在文本串中出现。

根据 Fail 指针的定义,如果当前字符串匹配成功,那么它的 Fail 指针指向的字符串也可以成功匹配。

因为 Fail 指针指向的字符串与其后缀匹配。

这样就启发我们一直跳 Fail 指针,累计其答案。

代码如下:

int query (char *s) {
	int slen = strlen (s), u = 0, res = 0, c;
	for (int i = 0; i < slen; i ++) {
		c = s[i] - 'a';
		u = trie[u][c];
		for (int j = u; j && ~tag[j]; j = fail[j]) {
			res += tag[j];
			tag[j] = -1;//标记,重复的不累计答案。
		}
	}
	return res;
}

六.查询最大出现次数

P3796 【模板】AC自动机(加强版)

给出若干个模式串和一个文本串,求某个模式串在文本串中出现的最大次数和该模式。

我们考虑如何查询最大出现次数。

由于会出现文本串中可能会出现多次模式串,所以将 \(tag\) 数组转化为存储该字符串的顺序,在统计答案时用一个 \(vis\)

然后遍历 \(vis\) 数组,当 \(vis_i\) 与最大值相同时,就输出第 \(i\)

多测记得清空

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

const int N = 333333;

int n, vis[N];

struct AC_automaton {
	int trie[N][26], fail[N], tag[N], tot = 0;
	
	inline void Clear() {
		memset (trie, 0, sizeof (trie));
		memset (tag, 0, sizeof (tag));
		memset (fail, 0, sizeof (fail));
		memset (vis, 0, sizeof (vis));
		
		tot = 0;
	}
	
	inline void Insert (char *s, int v) {
		int slen = strlen (s), u = 0, c;
		
		for (int i = 0; i < slen; i ++) {
			c = s[i] - 'a';
			
			if (!trie[u][c]) {
				trie[u][c] = ++ tot;
			}
			
			u = trie[u][c];
		} 
		
		tag[u] = v;
	}
	
	queue<int> q;

	inline void GetFail () {
		for (int i = 0; i < 26; i ++) {
			if (trie[0][i]) {
				q.push (trie[0][i]); 
			}
		}
		
		while (!q.empty()) {
			int u = q.front();
			
			q.pop();
			
			for (int i = 0; i < 26; i ++) {
				if (trie[u][i]) {
					q.push (trie[u][i]);
					
					fail[trie[u][i]] = trie[fail[u]][i]; 
				}
				
				else {
					trie[u][i] = trie[fail[u]][i];
				}
			}
		}
	}
	
	inline int Query(char *s) {
		int slen = strlen (s), u = 0, ans = 0;
		
		for (int i = 0; i < slen; i ++) {
			int c = s[i] - 'a';
			
			u = trie[u][c];
			
			for (int j = u; j; j = fail[j]) {
				if (!tag[j]) {
					//没有该节点,往下一个 Fail 指针跳。
					continue;
				}
				
				vis[tag[j]] ++;
				//统计出现次数。
			}
		}
		
		for (int i = 1; i <= n; i ++) {
			ans = max (ans, vis[i]);
			//取最大值。
		}
		
		return ans;
	}
}AC;

char c[200][90];

char TXT[1919810];

int main() {
	while (scanf ("%d", &n) && n != 0) {
		AC.Clear();//多测清空!!!!!
		
		for (int i = 1; i <= n; i ++) {
			scanf ("%s", c[i]);
			
			AC.Insert (c[i], i);	
		}
		
		AC.Build ();
		
		scanf ("%s", TXT);
		
		int mx = AC.Query (TXT);
		
		printf ("%d\n", mx);
		
		for (int i = 1; i <= n; i ++) {
			if (vis[i] == mx) {
				printf ("%s\n", c[i]);
			}
		}
	}
	
	return 0;
}

七.基础例题

P3808 【模板】AC自动机(简单版)

Ybtoj A. 【例题1】单词查询

这两道题就是以上模块的基本操作。

给定文本串和若干个模式串,求出有多少个不同的模式串在文本串中出现。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

const int N = 5e5 + 7;

char a[N * 20];

int n;

struct AC_automaton {
	int tag[N], trie[N][26], fail[N], tot;
	
	void insert (char *s) {
		int slen = strlen (s), u = 0, c;
		for (int i = 0; i < slen; i ++) {
			c = s[i] - 'a';
			if (!trie[u][c]) {
				trie[u][c] = ++ tot;
			}
			u = trie[u][c];
		}
		tag[u] ++;
	}
	
	queue<int> q;
	
	void build () {
		int u;
		for (int i = 0; i < 26; i ++) {
			if (trie[0][i]) {
				fail[trie[0][i]] = 0;
				q.push (trie[0][i]);
			}
		}
		while (!q.empty()) {
			u = q.front();
			q.pop();
			for (int i = 0; i < 26; i ++) {
				if (trie[u][i]) {
					fail[trie[u][i]] = trie[fail[u]][i];
					q.push (trie[u][i]);
				}
				else {
					trie[u][i] = trie[fail[u]][i];
				}
			}
 		}
	}
	
	int query (char *s) {
		int slen = strlen (s), u = 0, res = 0, c;
		for (int i = 0; i < slen; i ++) {
			c = s[i] - 'a';
			u = trie[u][c];
			for (int j = u; j && ~tag[j]; j = fail[j]) {
				res += tag[j];
				tag[j] = -1;
			}
		}
		return res;
	}
}AC;

int main() {
	scanf ("%d", &n);
	for (int i = 1; i <= n; i ++) {
		scanf ("%s", a);
		AC.insert (a);
	}
	AC.build();
	
	scanf ("%s", a);
	int ans = AC.query(a);
	cout << ans << endl;
	return 0;
}

P3966 [TJOI2013]单词

Ybtoj B. 【例题2】单词频率

首先,定义一个节点的权值为该节点属于的字符串个数。

那么,一个节点表示的字符串,在整个字典树中出现的次数就是子树的权值和。

代码如下:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>

using namespace std;

const int N = 1222222;

char c[N];

int n, ans[N], tot = 0;

struct AC_automaton {
	int tag[N], trie[N][26], fail[N], q[N], siz[N];
	//手写队列方便。 
	
	inline void Insert (char *s, int k) {
		int slen = strlen (s), u = 0, c;
		
		for (int i = 0; i < slen; i ++) {
			c = s[i] - 'a';
			
			if (!trie[u][c]) {
				trie[u][c] = ++ tot;
			}
			
			u = trie[u][c];
			siz[u] ++;
		}
		
		tag[k] = u;//记录第k个字符串的最后状态。 
	}
	
	inline void GetFail () {
		int head = 0, tail = 0, u = 0, c;
		
		for (int i = 0; i < 26; i ++) {
			if (trie[0][i]) {
				q[++ tail] = trie[0][i];
			}
		}
		
		while (head < tail) {
			u = q[++ head];
			
			for (int i = 0; i < 26; i ++) {
				if (trie[u][i]) {
					q[++ tail] = trie[u][i];
					fail[trie[u][i]] = trie[fail[u]][i];
				}
				
				else {
					trie[u][i] = trie[fail[u]][i];
				}
			}
		} 
	}
	
	inline void Query () {
		for (int i = tot; i >= 0; i --) {
			siz[fail[q[i]]] += siz[q[i]];//倒推计算子树和。 
		} 
		
		for (int i = 1; i <= n; i ++) {
			printf ("%d\n", siz[tag[i]]);
		}
	}
}AC;

int main() {
	scanf ("%d", &n);
	
	for (int i = 1; i <= n; i ++) {
		scanf ("%s", c);
		
		AC.Insert (c, i); 
	}
	
	AC.GetFail ();
	AC.Query ();
	
	return 0; 
}

P5231 [JSOI2012]玄武密码

Ybtoj C. 【例题3】前缀匹配

要求对于每一个模式串,求出其最长的前缀 \(p\),满足 \(p\)

题目稍有变化,思维难度还是比较低的。

我们可以设 \(tag_i\) 表示 Trie 树上的 \(i\)

那么我们就可以匹配出 \(tag\)

代码如下:

#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <cstdio>

using namespace std;

const int N = 10000007;
const int M = 100007;
const int T = 107;

int n, m;

char TXT[N];

char p[M][T];

struct AC_automaton {
	int trie[N][4], tag[N], fail[N], tot = 0;
	
	inline int Change (char c) {
		if (c == 'E') {
			return 0;
		}
		
		else if (c == 'S') {
			return 1;
		}
		
		else if (c == 'W') {
			return 2;
		}
		
		else if (c == 'N') {
			return 3;
		}
	}
	
	inline void Insert (char *s) {
		int slen = strlen (s), u = 0, c;
		
		for (int i = 0; i < slen; i ++) {
			c = Change (s[i]);
			
			if (!trie[u][c]) {
				trie[u][c] = ++ tot;
			}
			
			u = trie[u][c];
		}
	}
	
	queue<int> q;
	
	inline void GetFail () {
		for (int i = 0; i < 4; i ++) {
			if (trie[0][i]) {
				q.push (trie[0][i]); 
			}
		}
		
		while (!q.empty()) {
			int u = q.front();
			
			q.pop();
			
			for (int i = 0; i < 4; i ++) {
				if (trie[u][i]) {
					q.push (trie[u][i]);
					
					fail[trie[u][i]] = trie[fail[u]][i];
				}
				
				else {
					trie[u][i] = trie[fail[u]][i];
				}
			}
		}
	}
	
	inline void Find (char *T) {
		int Tlen = strlen (T), u = 0, ans = 0;
		
		for (int i = 0; i < Tlen; i ++) {
			int c = Change (T[i]);
			u = trie[u][c];
			
			for (int j = u; j && !tag[j]; j = fail[j]) {
				tag[j] = 1;
				//求tag。 
			}
		}
	}
		
	inline int Query (char *T) {
		int u = 0, ans = 0, Tlen = strlen (T);
	 
		for (int i = 0; i < Tlen; i ++) {
			int cc = Change (T[i]);
			u = trie[u][cc];
			
			if (tag[u]) {
				ans = i + 1;//下标从0开始,要+1。 
			}
			
			else {
				break;
			}
		}
		
		return ans;
	}
	
}AC;

int main() {
	scanf ("%d%d", &n, &m);
	
	scanf ("%s", TXT);
	
	for (int i = 1; i <= m; i ++) {
		scanf ("%s", p[i]);
		
		AC.Insert (p[i]);
	}
	
	AC.GetFail ();
	AC.Find (TXT); 
	
	for (int i = 1; i <= m; i ++) { 
		printf ("%d\n", AC.Query (p[i]));
	} 
	return 0; 
}