[BZOJ3676][Apio2014]回文串
试题描述
输入
输入只有一行,为一个只包含小写字母(a -z)的非空字符串s。
输出
输出一个整数,为逝查回文子串的最大出现值。
输入示例
abacaba
输出示例
7
数据规模及约定
数据满足1≤字符串长度≤300000。
题解
首先我们用 manacher 求出所有极长的回文子串,然后通过这个我们可以求出以每个位置 i 为结尾,向左延伸到的最长回文子串长度是多少(记为 len[i])。方法就是先对于极长的回文子串打一下标记,然后从后往前扫一遍,每往前移一个位置值减一,然后和当前位置的标记取个 max。
接下来我们构造后缀自动机,对于每个状态的 right 集合,我们只要知道这个集合的大小就知道当前串出现了几次;现在我们还需要保证所有找到的串是回文的,于是就可以用上面求出来的 len[i] 做。不难发现对于一个状态 u 的 right 集合,把集合内所有位置上的 len[i] 取 min(注意判断这个 min 值必须要小于等于 Max[u])就是当前节点的答案了。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; #define maxn 600010 #define maxnode 600010 #define maxa 26 #define oo 2147483647 #define LL long long char S[maxn], Str[maxn]; int n, Len[maxn], len[maxn], alp[maxn]; int rt, last, ToT, to[maxnode][maxa], par[maxnode], Max[maxnode], val[maxnode], mnl[maxnode]; void extend(int pos) { int x = Str[pos] - 'a', p = last, np = ++ToT; Max[np] = Max[p] + 1; val[np] = 1; mnl[np] = len[pos]; last = np; while(p && !to[p][x]) to[p][x] = np, p = par[p]; if(!p){ par[np] = rt; return ; } int q = to[p][x]; if(Max[q] == Max[p] + 1){ par[np] = q; return ; } int nq = ++ToT; Max[nq] = Max[p] + 1; mnl[nq] = oo; memcpy(to[nq], to[q], sizeof(to[q])); par[nq] = par[q]; par[q] = par[np] = nq; while(p && to[p][x] == q) to[p][x] = nq, p = par[p]; return ; } int sa[maxnode], Ws[maxnode]; int main() { scanf("%s", Str + 1); n = strlen(Str + 1); for(int i = 1; i <= n; i++) S[(i<<1)-1] = Str[i], S[i<<1] = 'z' + 1; n <<= 1; for(int i = 1; i <= n; i++) alp[i] = alp[i-1] + (isalpha(S[i]) ? 1 : 0); int mxp = 0; for(int i = 1; i <= n; i++) { int mxr = mxp + Len[mxp] - 1; if(mxr >= i) Len[i] = min(Len[(mxp<<1)-i], mxr - i + 1); else Len[i] = 1; while(1 <= i - Len[i] + 1 && i + Len[i] - 1 <= n && S[i-Len[i]+1] == S[i+Len[i]-1]) Len[i]++; Len[i]--; if(mxr < i + Len[i] - 1) mxp = i; } for(int i = 1; i <= n; i++) len[i+Len[i]-1] = max(len[i+Len[i]-1], Len[i]); // for(int i = 1; i <= n; i++) printf("%d%c", Len[i], i < n ? ' ' : '\n'); // for(int i = 1; i <= n; i++) printf("%d%c", len[i], i < n ? ' ' : '\n'); int mx = 1; for(int i = n; i; i--) { mx--; if(len[i] > mx) mx = len[i]; else len[i] = mx; } // for(int i = 1; i <= n; i++) printf("%d%c", len[i], i < n ? ' ' : '\n'); for(int i = 1; i <= n; i++) { len[i] = (len[i] << 1) - 1; int l = len[i]; len[i] = alp[i] - alp[i-l]; } for(int i = 1; i <= (n >> 1); i++) len[i] = max(len[(i<<1)-1], len[i<<1]); n >>= 1; // puts(S + 1); // for(int i = 1; i <= n; i++) printf("%d%c", len[i], i < n ? ' ' : '\n'); rt = last = ToT = 1; for(int i = 1; i <= n; i++) extend(i); for(int i = 1; i <= ToT; i++) Ws[n-Max[i]]++; for(int i = 1; i <= n; i++) Ws[i] += Ws[i-1]; for(int i = ToT; i; i--) sa[Ws[n-Max[i]]--] = i; LL ans = 0; for(int i = 1; i <= ToT; i++) { int u = sa[i]; mnl[par[u]] = min(mnl[par[u]], mnl[u]); val[par[u]] += val[u]; if(Max[u] >= mnl[u]) ans = max(ans, (LL)mnl[u] * val[u]); } printf("%lld\n", ans); return 0; }