题解:P12602 指鹿为马

· · 题解

m\le 62 为字符集大小。

如 B 站视频建出 DFA,注意手模样例可以发现本题没有“空”这个节点,s_n 的后面应当紧接着 s_1

注意到当 $i\ge 1$ 时,只有 $c=s_i$ 的位置有意义,$\mathcal{O}((n+m)^3)$,具体地: --- $f_c=g_{c,s_1}f_1+\sum_{c'\neq s_1} g_{c,c'}f_{c'}+1(I)

对于 1\le i<nf_i=g_{s_i,s_{i+1}}f_{i+1}+\sum_{c\ne s_{i+1}\land kmp_{i,c}\neq 0}g_{s_i,c}f_{kmp_{i,c}}+\sum_{c\neq s_{i+1}\land kmp_{i,c}=0}g_{s_i,c}f_c+1(II)

其中 kmp_{i,c} 表示当前匹配了 i 位,又匹配了一个字符 c,当前最多匹配多少。可以 \mathcal{O}(nm) 递推预处理。

f_n=0

以下简记记 f_{0,c}f_cf_{i,s_i}f_i

观察 I 类方程,发现 f_c 只可能依赖 f_{c'},f_1 的值,这意味着如果我们知道所有 f_c,则我们可以直接推出 f_1 的值。

观察 II 类方程,发现 f_i(i\ge 1) 只可能依赖 f_{i+1},f_{<i},f_{c} 的值,这意味着如果我们知道所有 f_c,f_{\le i} 的值,则我们可以直接推出 f_{i+1} 的值。

这意味着,对于每个字符集里的元素 c,我们可以先用所有 f_{c'} 表示出 f_1,接着表示出 f_2,f_3,\dots,f_n,而 f_n=0,这意味着每个 c 都可以造出一个关于所有 f_{c'} 的方程,且这样的方程可以造出 m 个,解方程部分复杂度降到 \mathcal{O}(m^3)

然而对一个 c,造方程需要 \mathcal{O}(nm^2),也即我们造方程组需要 \mathcal{O}(nm^3),可以获得 98pts

#include<cstdio>
#include<string>
#include<vector>
#include<cassert>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
const int N=3e5+9,P=998244353;
int qmi(int a,int b){
    int res=1;
    while(b){
        if(b&1) res=(LL)res*a%P;
        a=(LL)a*a%P;
        b>>=1;
    }
    return res;
}
string str;
char s[N]; int n,m,id[128],g[128][128],deg[128],a[128][128];
int kmp[N][128],nxt[N],tmp[N][128],f[N];
void add(int a,int b){
    g[a][b]++;
    deg[a]++;
}
void gauss(){
    for(int i=1;i<=m;i++){
        int id=i;
        for(int j=i;j<=m;j++) if(a[j][i]) id=j;
        if(id^i) swap(a[id],a[i]);
        if(!a[i][i]) assert(0);
        int inv=qmi(a[i][i],P-2);
        for(int j=i;j<=m+1;j++) a[i][j]=(LL)a[i][j]*inv%P;
        for(int k=1;k<=m;k++) if(k!=i&&a[k][i])
            for(int j=m+1;j>=i;j--) a[k][j]=(a[k][j]-(LL)a[k][i]*a[i][j]%P+P)%P;
    }
}
int main(){
    getline(cin,str);
    for(int i=0,len=str.size();i<len;i++) s[++n]=str[i];
    memset(id,-1,sizeof(id));
    for(int i=1;i<=n;i++) if(!~id[s[i]]) id[s[i]]=m++; m--;
    if(!m) return printf("%d\n",n),0;
    for(int i=1;i<n;i++) add(id[s[i]],id[s[i+1]]); add(id[s[n]],id[s[1]]);
    for(int i=0;i<=m;i++){
        int inv=qmi(deg[i],P-2);
        for(int j=0;j<=m;j++) g[i][j]=(LL)inv*g[i][j]%P;
    }
    for(int i=2,j=0;i<=n;i++){
        while(j&&s[j+1]!=s[i]) j=nxt[j];
        if(s[j+1]==s[i]) j++;
        nxt[i]=j;
    }
    for(int i=0;i<n;i++) for(int c=0;c<=m;c++){
        if(id[s[i+1]]==c) kmp[i][c]=i+1;
        else kmp[i][c]=kmp[nxt[i]][c];
//      printf("kmp %d %d %d\n",i,c,kmp[i][c]);
    }
    for(int i=1;i<=m;i++){
        if(!g[i][0]){
            for(int j=1;j<=m;j++) if(i^j) a[i][j]=(P-g[i][j])%P;
            a[i][i]=(P-g[i][i]+1)%P;
            a[i][m+1]=1;
        }
        else{
            int inv=qmi(g[i][0],P-2);
            for(int j=1;j<=m;j++) if(i^j) tmp[1][j]=(LL)(P-g[i][j])*inv%P;
            tmp[1][i]=(LL)(P-g[i][i]+1)*inv%P;
            tmp[1][m+1]=(LL)(P-1)*inv%P;
//          for(int p=1;p<=m+1;p++) printf("%d ",tmp[1][p]); puts("");
            for(int j=2;j<=n;j++){
                int inv=qmi(g[id[s[j-1]]][id[s[j]]],P-2);
                for(int c=1;c<=m+1;c++) tmp[j][c]=(LL)tmp[j-1][c]*inv%P;
//              for(int p=1;p<=m+1;p++) printf("%d ",tmp[j][p]); puts("");
                for(int c=0;c<=m;c++) if(c!=id[s[j]]){
                    int coef=(LL)(P-g[id[s[j-1]]][c])*inv%P;
//                  printf("kmp %d %d %d %d\n",j-1,c,kmp[j-1][c],coef);
                    if(!kmp[j-1][c]) tmp[j][c]=(tmp[j][c]+coef)%P;
                    else{
                        int k=kmp[j-1][c];
                        for(int c=1;c<=m+1;c++) tmp[j][c]=(tmp[j][c]+(LL)coef*tmp[k][c])%P;
                    }
                }
                tmp[j][m+1]=(tmp[j][m+1]+(LL)(P-1)*inv)%P;
//              for(int p=1;p<=m+1;p++) printf("%d ",tmp[j][p]); puts("");
            }
            for(int j=1;j<=m;j++) a[i][j]=tmp[n][j];
            a[i][m+1]=(P-tmp[n][m+1])%P;
        }
    }
//  for(int i=1;i<=m;i++,puts("")) for(int j=1;j<=m+1;j++) printf("%d ",a[i][j]);
    gauss();
    int c=1;
    while(!g[c][0]) c++;
    int f1=(a[c][m+1]-1+P)%P;
    for(int j=1;j<=m;j++) f1=(f1+(LL)(P-g[c][j])*a[j][m+1])%P;
    f1=(LL)f1*qmi(g[c][0],P-2)%P;
    printf("%d\n",(f1+1)%P);
    return 0;
}

进一步考察上述过程,我们发现我们是将每个 f_i 用一个向量 tmp_{i,1},tmp_{i,2},\dots,tmp_{i,m},tmp_{i,m+1} 表示,表示 f_i=\sum_{c'\in[1,m]}tmp_{i,c'}f_{c'}+tmp_{i,m+1}。求 f_i 的过程是 \mathcal{O}(m)f_{j}(j<i) 的线性组合再改 O(m) 项,而无论是线性组合的系数,还是改的常数,都是与 c 无关的。

这意味着无论 c 取何值,tmp_{i,c'} 必然可以表示为 k_{1,c'}tmp_{1,c'}+b_{1,c'} 其中 k,b 为常数组,与 c 无关。

所以我们其实只要 \mathcal{O}(nm^2) 递推一次求出 k,b,对于每个 c 就可以单次 \mathcal{O}(m) 的造出方程。

总时间复杂度 \mathcal{O}(nm^2),空间复杂度 \mathcal{O}(nm),可以通过:

#include<cstdio>
#include<string>
#include<vector>
#include<cassert>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
const int N=3e5+9,P=998244353;
int qmi(int a,int b){
    int res=1;
    while(b){
        if(b&1) res=(LL)res*a%P;
        a=(LL)a*a%P;
        b>>=1;
    }
    return res;
}
string str;
char s[N]; int n,m,id[128],g[128][128],deg[128],a[128][128];
int kmp[N][128],nxt[N],f[N];
PII tmp[N][128];
void add(int a,int b){
    g[a][b]++;
    deg[a]++;
}
void gauss(){
    for(int i=1;i<=m;i++){
        int id=i;
        for(int j=i;j<=m;j++) if(a[j][i]) id=j;
        if(id^i) swap(a[id],a[i]);
        if(!a[i][i]) assert(0);
        int inv=qmi(a[i][i],P-2);
        for(int j=i;j<=m+1;j++) a[i][j]=(LL)a[i][j]*inv%P;
        for(int k=1;k<=m;k++) if(k!=i&&a[k][i])
            for(int j=m+1;j>=i;j--) a[k][j]=(a[k][j]-(LL)a[k][i]*a[i][j]%P+P)%P;
    }
}
int main(){
    getline(cin,str);
    for(int i=0,len=str.size();i<len;i++) s[++n]=str[i];
    memset(id,-1,sizeof(id));
    for(int i=1;i<=n;i++) if(!~id[s[i]]) id[s[i]]=m++; m--;
    if(!m) return printf("%d\n",n),0;
    for(int i=1;i<n;i++) add(id[s[i]],id[s[i+1]]); add(id[s[n]],id[s[1]]);
    for(int i=0;i<=m;i++){
        int inv=qmi(deg[i],P-2);
        for(int j=0;j<=m;j++) g[i][j]=(LL)inv*g[i][j]%P;
    }
    for(int i=2,j=0;i<=n;i++){
        while(j&&s[j+1]!=s[i]) j=nxt[j];
        if(s[j+1]==s[i]) j++;
        nxt[i]=j;
    }
    for(int i=0;i<n;i++) for(int c=0;c<=m;c++){
        if(id[s[i+1]]==c) kmp[i][c]=i+1;
        else kmp[i][c]=kmp[nxt[i]][c];
//      printf("kmp %d %d %d\n",i,c,kmp[i][c]);
    }
    for(int i=1;i<=m+1;i++) tmp[1][i]={1,0};
    for(int j=2;j<=n;j++){
        int inv=qmi(g[id[s[j-1]]][id[s[j]]],P-2);
        for(int c=1;c<=m+1;c++) tmp[j][c].fi=(LL)tmp[j-1][c].fi*inv%P,tmp[j][c].se=(LL)tmp[j-1][c].se*inv%P;
        for(int c=0;c<=m;c++) if(c!=id[s[j]]){
            int coef=(LL)(P-g[id[s[j-1]]][c])*inv%P;
            if(!kmp[j-1][c]) tmp[j][c].se=(tmp[j][c].se+coef)%P;
            else{
                int k=kmp[j-1][c];
                for(int c=1;c<=m+1;c++) tmp[j][c].fi=(tmp[j][c].fi+(LL)tmp[k][c].fi*coef)%P,tmp[j][c].se=(tmp[j][c].se+(LL)tmp[k][c].se*coef)%P;
            }
        }
        tmp[j][m+1].se=(tmp[j][m+1].se+(LL)(P-1)*inv)%P;
    }
    for(int i=1;i<=m;i++){
        if(!g[i][0]){
            for(int j=1;j<=m;j++) if(i^j) a[i][j]=(P-g[i][j])%P;
            a[i][i]=(P-g[i][i]+1)%P;
            a[i][m+1]=1;
        }
        else{
            int inv=qmi(g[i][0],P-2);
            for(int j=1;j<=m;j++) if(i^j) a[i][j]=((LL)(P-g[i][j])*inv%P*tmp[n][j].fi+tmp[n][j].se)%P;
            a[i][i]=((LL)(P-g[i][i]+1)*inv%P*tmp[n][i].fi+tmp[n][i].se)%P;
            a[i][m+1]=(P-((LL)(P-1)*inv%P*tmp[n][m+1].fi+tmp[n][m+1].se)%P)%P;
        }
    }
//  for(int i=1;i<=m;i++,puts("")) for(int j=1;j<=m+1;j++) printf("%d ",a[i][j]);
    gauss();
    int c=1;
    while(!g[c][0]) c++;
    int f1=(a[c][m+1]-1+P)%P;
    for(int j=1;j<=m;j++) f1=(f1+(LL)(P-g[c][j])*a[j][m+1])%P;
    f1=(LL)f1*qmi(g[c][0],P-2)%P;
    printf("%d\n",(f1+1)%P);
    return 0;
}