题解 P5404 【[CTS2019]重复】

· · 题解

S^{\infty} 为无穷多个 S 顺序连接在一起得到的串。

题意即求存在多少个 T,满足 |T| = mT^{\infty} 存在一个长为 |S| 的子串的字典序 < S

首先将限制转化为 T^{\infty} 的所有长为 |S| 的子串字典序都 \geq S。 考虑什么样的 T 是合法的,将 T^{\infty} 放在 S 的 kmp 自动机上匹配时,走的所有转移边都必须大于当前节点沿后缀链接跳到根的路径上的所有节点的后继转移边。

设串 X 匹配结束之后所在的节点为 P(X)

考虑枚举 P(T^\infty) = u,如果我们再往后接一个 T 继续匹配,一定会遍历一遍一个环并回到 u,因为 T^\infty + T = T^\infty

同时,对于所有 T,如果再往后接一个 T 继续匹配恰好回到节点 u,这样的 T 对应的 T^\infty 匹配结束之后所在的节点一定是 u。形式化地说,对于任意串 T,如果从节点 u 往后匹配一个 T 得到的节点为 u,则 P(T^\infty) = u

我们考虑 P(T^\infty) 的意义,即最长的同时是 T^\infty 的后缀,S 的前缀的串。显然有,对于任意串 N,有 P(N+T^{\infty}) = P(T^\infty)。因此,假设 u 代表的串为 N,因为往后匹配任意个 T 得到的节点都为 u,有 P(T^\infty) = P(N + T^\infty) = u

那么我们知道,对应节点 u 的合法的串 T,和从节点 u 接着走 m 条合法转移边形成的一个环回到 u 的方案一一对应。至此直接 dp 可以得到 O(n^2m) 的做法。

考虑这个自动机的形式,对于任意节点,其不指向根的合法出边至多只有一条,否则对应字符较小的一条是不合法的。因此这个自动机的形式是,节点 u 有一条转移边指向 u+1 或者一个节点 xx<u,有若干条转移边指向根。

同样枚举节点 P(T^\infty),考虑这个环是否经过根节点,如果不经过的话,方案(如果存在)是唯一的,暴力验证即可。否则,枚举走几步时跳到根,然后对根 dp,f_{i,j} 表示从根走 i 步到达节点 j 的方案数,就可以快速计算了,复杂度 O(nm)

#include <bits/stdc++.h>
#define ll long long
#define maxn 2005
const int mod = 998244353;
int n,m,max[maxn],nxt[maxn],ch[maxn][26],f[maxn][maxn],ans,sum=1;
char s[maxn];

int main(){
    scanf("%d%s",&m,s+1);
    n = std::strlen(s+1);
    nxt[0] = -1; max[0] = 0;
    for (int i = 1; i <= m; ++ i) sum = (ll) sum * 26 % mod;
    for (int i = 1; i <= n; ++ i) {
        int j = nxt[i-1];
        while (j != -1 && s[j+1] != s[i]) j = nxt[j];
        nxt[i] = j+1;
    } 
    for (int i = 0; i <= n; ++ i) 
    for (int c = 0; c < 26; ++ c) {
        if (s[i+1] - 'a' == c) ch[i][c] = i+1;
        else ch[i][c] = (i!=0? ch[nxt[i]][c] : 0);
        if (ch[i][c]) max[i] = c;
    }
    f[0][0] = 1;
    for (int i = 0; i <= m; ++ i) 
    for (int j = 0; j <= n; ++j) 
    for (int c = max[j]; c < 26; ++ c)
        f[i+1][ch[j][c]] = (f[i+1][ch[j][c]] + f[i][j]) % mod;
    for (int i = 0; i <= n; ++ i) {
        int u = i;
        for (int j = 0; j < m; ++ j) {
            ans = (ans + (ll) (26 - max[u] - 1) * f[m-j-1][i] % mod) % mod;
            u = ch[u][max[u]];
            if (!u) break;
        } 
        ans = (ans + (u !=0 && u==i)) % mod;
    } printf("%d",(sum - ans + mod) % mod);
    return 0;
}