题解 CF932G 【Palindrome Partition】

eee_hoho

2021-01-26 14:18:03

Solution

PAM好题QAQ 我们先把题目要求转化成回文串,假设分成了 $k$ 段为 $s_1,s_2,s_3\dots s_k$ ,我们单独看 $s_1$ 和 $s_k$ ,假设他们的长度是 $x$ ,那么对于原串 $a$ 就有: $$a_1=a_{n-x+1},a_2=a_{n-x+2},\dots,a_x=a_{n}$$ 于是我们构造一个串 $b=a_1a_na_2a_{n-1}a_3a_{n-2}\dots$ ,这样子我们求长度为偶数且总和为 $n$ 的回文串方案数就可以了。 然后我们考虑dp,设 $f_i$ 表示以 $i$ 结尾的方案数,那么就有 $f_i=\sum_jf_{j-1}$ ,其中 $[j,i]$ 于是我们一直跳 $fail$ 就可以得到所有的 $j$ 。 但是这样子做过不去的,全是相同字符就会退化到 $O(n^2)$ ,于是我们再进一步考虑。 对于字符串 $border$ 有:一个字符串的所有 $border$ 会形成 $\log$ 个不相交的等差数列。 考虑证明,我们对于比 $\frac{len}{2}$ 大的两个 $border$ $A,B$ ,假设 $B$ 是该串的最长 $border$ ,那么 $len-|B|$ 是其最小周期,那么 $len-|B| \ | \ len-|A|$ ,于是对于所有满足 $s=k(len-|B|),len-s\ge len$ 的都是字符串的 $border$ ,然后递归考虑 $\frac{len}{2}$ 的问题就可以了。 那么对于回文串,我们显然有如果一个串 $S$ 的后缀 $T$ 是回文的,那么 $T$ 是 $S$ 的 $border$ ,于是我们通过这个来优化复杂度。 设 $g_p$ 表示以 $p$ 这个结点结束的等差数列的贡献,具体点就是如果这段等差数列长度分别为 $s_1,s_2,\dots s_k$ ,那么 $g_p=\sum_{i=1}^kf_{x-s_i}$ ,$x$ 是当前位置。 于是我们考虑求这个东西,根据回文串的性质,发现 $g_p=g_{fail[p]}+f_{x-s_k}$ ,因为 $fail_p$ 会把 $p$ 之前的都算到,然后加上自己的贡献就可以了,这样子复杂度就变成了 $O(n\log n)$ **Code** ``` cpp #include <iostream> #include <cstdio> #include <algorithm> #include <cstring> const int N = 1e6; const int mod = 1e9 + 7; using namespace std; int n; char s[N + 5],ns[N + 5]; struct PAM { int ch[N + 5][26],fail[N + 5],len[N + 5],nc,las,f[N + 5],sf[N + 5],d[N + 5],pos[N + 5]; char s[N + 5]; void init(char *ch) { for (int i = 1;i <= n;i++) s[i] = ch[i]; s[0] = '#'; len[1] = -1; pos[0] = 1; fail[0] = 1; f[0] = 1; nc = 1; } void insert(int x) { int p = las; while (s[x - len[p] - 1] != s[x]) p = fail[p]; if (!ch[p][s[x] - 'a']) { len[++nc] = len[p] + 2; int j = fail[p]; while (s[x - len[j] - 1] != s[x]) j = fail[j]; fail[nc] = ch[j][s[x] - 'a']; d[nc] = len[nc] - len[fail[nc]]; if (d[nc] == d[fail[nc]]) pos[nc] = pos[fail[nc]]; else pos[nc] = fail[nc]; ch[p][s[x] - 'a'] = nc; } las = ch[p][s[x] - 'a']; int j = las; while (j > 1) { sf[j] = f[x - len[pos[j]] - d[j]]; if (pos[j] != fail[j]) sf[j] += sf[fail[j]],sf[j] %= mod; if (x % 2 == 0) f[x] += sf[j],f[x] %= mod; j = pos[j]; } } void build() { for (int i = 1;i <= n;i++) insert(i); } }pa; int main() { scanf("%s",s + 1); n = strlen(s + 1); int l = 1,r = n; for (int i = 1;i <= n;i++) if (i % 2 == 1) ns[i] = s[l++]; else ns[i] = s[r--]; pa.init(ns);pa.build(); cout<<pa.f[n]<<endl; return 0; } ```