P3804 【模板】后缀自动机 (SAM)

这玩意真不是人学的东西

目前只看懂了一般,解释待补把

先存个代码

结构体版本

#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e6+10;
vector<int>vec[maxn];
char s[maxn];
struct node
{
	int zi[26],len,fa;
	node() { memset( zi,0,sizeof zi );len = fa = 0; }
}sam[maxn];
int id=1,las=1,siz[maxn];
long long ans;
void insert(int c)
{
	int p = las, np = ++id; las = np; 
	siz[np] = 1;
	sam[np].len = sam[p].len+1;
	for( ;p&&!sam[p].zi[c];p=sam[p].fa )	sam[p].zi[c] = np;//endpos多加了n
	if( !p ) { sam[np].fa = 1; return; }//之前没出现过这个字母
	int q = sam[p].zi[c];
	if( sam[q].len==sam[p].len+1 )	sam[np].fa = q;//加上c,还是新串的后缀
	else
	{
		int nq = ++id;//q存在不是新串后缀的节点,需要分裂出来
		sam[nq] = sam[q]; sam[nq].len = sam[p].len+1;//nq的endpos包含n
		sam[q].fa = sam[np].fa = nq;
		//nq从q分裂出来,q比nq长,理所应当
		//np的endpos包含n,nq也是,所以nq是np的父亲
		for(;p&&sam[p].zi[c]==q;p=sam[p].fa )	sam[p].zi[c] = nq;
		//这里longest(p)+c必定为新串的后缀,endpos包含n
		//但是p节点的endpos不包含n!!所以为了不矛盾需要修改一下 	
	} 
}
void dfs(int u)
{
	for(auto v:vec[u] )
	{
		dfs(v);
		siz[u] += siz[v];
	}
	if( siz[u]>1 )	ans = max( ans,1ll*siz[u]*sam[u].len );
}
signed main()
{
	scanf("%s",s); int n = strlen( s );
	for(int i=0;i<n;i++)	insert( s[i]-'a' );
	for(int i=2;i<=id;i++)	vec[sam[i].fa].push_back( i );
	dfs(1);
	cout << ans;
}

数组的

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 3e6+10;
char s[maxn];
int n,t[maxn],rk[maxn],ans;
int mp[maxn][26],fa[maxn],len[maxn],siz[maxn],ed=1,id=1;
void insert(int c)
{
	int las = ed, p = ++id; ed = p;
	len[p] = len[las]+1, siz[p] = 1;
	while( las&&!mp[las][c] )	mp[las][c] = p, las = fa[las];
	if( !las ) { fa[p] = 1; return; }
	int q = mp[las][c];
	if( len[las]+1==len[q] )	fa[p]=q;
	else
	{
		int nq = ++id;
		fa[nq] = fa[q]; fa[q] = fa[p] = nq;
		memcpy( mp[nq],mp[q],sizeof mp[nq] ); len[nq] = len[las]+1;
		while( las&&mp[las][c]==q )	mp[las][c]=nq,las = fa[las];
	}
}
signed main()
{
	scanf("%s",s+1); n = strlen( s+1 );
	for(int i=1;i<=n;i++)	insert(s[i]-'a');
	for(int i=1;i<=id;i++)	t[len[i]]++;
	for(int i=1;i<=id;i++)	t[i] += t[i-1];
	for(int i=id;i>=1;i--)	rk[t[len[i]]--] = i;
	for(int i=id;i>=1;i--)
	{
		int now = rk[i];
		siz[fa[now]] += siz[now];
		if( siz[now]>1 )	ans = max( ans,1ll*siz[now]*len[now] );
	}
	printf("%lld",ans);
}