P12602 指鹿为马 题解

· · 题解

下文为了区分,用 \text{cnt1}_i,\text{cnt2}_{i,j} 代指题目中的 \text{cnt}_i,\text{cnt}_{i,j}

首先对于打出的东西,我们只关心其某一个后缀是否是 s 的前缀。可以想到的 dp 是,设 f_i 表示对于目前打出的字符所构成的字符串,其所有后缀中,满足是 s 的前缀的最大长度为 i,对应的期望打字数。

\text{nxt}_{i,j} 表示 \overline{s[1:i]j} 的 border,那么当 \text{nxt}_{i,j}\neq0 时,j 会对 f_i 贡献 \dfrac{\text{cnt2}_{i,j}}{\text{cnt1}_i}\cdot(f_{\text{nxt}_{i,j}}+1)。如果 \text{nxt}_{i,j}=0,那么我们就需要重新开始,即要算出从字符 j 开始打字,直到打出 s_1 的期望打字数,记为 g_j。这个可以很容易得到关系式,即枚举下一个打的字符,有关系式:

g_i=1+\sum\limits_{j\in\Sigma}\dfrac{\text{cnt2}_{i,j}}{\text{cnt1}_i}\cdot g_j

并且我们有 g_{s_1}=0。考虑把这个关系式的未知数放到左边,常数项放到右边,就可以做一个高斯消元了。

同时我们也就有 f_i 的转移式:

f_i=\sum\limits_{j\in\Sigma}\dfrac{\text{cnt2}_{i,j}}{\text{cnt1}_i}\cdot\begin{cases}f_1+1+g_j&(\text{nxt}_{i,j}=0)\\f_{\text{nxt}_{i,j}}+1&(\text{nxt}_{i,j}\neq0)\end{cases}

答案即为 f_1+1,边界有 f_{|s|}=0

这个东西虽然也能高斯消元,但是时间爆完了。

考虑一下,我们之所以做高斯消元,是因为不能只从之前的状态转移过来。然而发现 \text{nxt}_{i,j}\le i+1,也就是说只有恰好一个 j 会从一个 >i 的位置转移过来,并且这个位置还恰好是 i+1。于是我们把 f_{i+1} 这一项挪到左边,f_i 挪到右边,就有一个从小到大的转移。

但是我们现在是知道的大的要反推小的。归纳一下的话,可以得到,对于任意的 f_i,都可以表示成一个 f_i=k_i\cdot f_1+b_i。于是去转移这个 k_i,b_i 即可。最终的 f_{|s|}=k_{|s|}\cdot f_1+b_{|s|}=0 可以直接得到 f_1 的值。时间复杂度 O(|\Sigma|^3+|s|\cdot|\Sigma|)

#include<bits/stdc++.h>
#define N 300005
#define M 65
#define V 62
#define ll long long
#define mod 998244353
using namespace std;

ll qpow(ll x,ll y)
{
    ll res=1;
    x%=mod;
    while(y)
    {
        if(y&1) res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}

struct gaus{
    ll a[M][M];
    int n,m;
    ll* operator[](int i){return a[i];}
    void init(int x,int y){n=x,m=y;}
    auto gauss()
    {
        for(int j=1;j<m;j++)
        {
            ll inv=qpow(a[j][j],mod-2);
            for(int i=1;i<=n;i++)
            {
                if(i==j) continue;
                ll t=a[i][j]*inv%mod;
                for(int k=1;k<=m;k++) a[i][k]=(a[i][k]-t*a[j][k]%mod+mod)%mod;
            }
        }
        vector<int>vec;
        for(int i=1;i<=n;i++) vec.push_back(a[i][m]*qpow(a[i][i],mod-2)%mod);
        return vec;
    }
}equ;
struct info{
    ll k,b;
}f[N];
int n,a[N],cnt1[M],cnt2[M][M],nxt[N][M],bd[N];
ll invc[M],g[N];

void border()
{
    for(int i=2;i<=n;i++)
    {
        int j=bd[i-1];
        while(j&&a[i]!=a[j+1]) j=bd[j];
        bd[i]=j+(a[i]==a[j+1]);
    }
    for(int j=1;j<=V;j++)
    {
        for(int i=0;i<n;i++)
        {
            nxt[i][j]=(a[i+1]==j?i+1:nxt[bd[i]][j]);
        }
    }
}

int main()
{
    //freopen("typewriter.in","r",stdin);
    //freopen("typewriter.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    string s;
    cin>>s;
    n=s.size();
    for(int i=0;i<n;i++) a[i+1]=s[i]-(isdigit(s[i])?47:isupper(s[i])?54:60);
    for(int i=1;i<=n;i++) cnt1[a[i]]++,cnt2[a[i]][a[i%n+1]]++;
    border();
    equ.init(V,V+1);
    for(int i=1;i<=V;i++)
    {
        equ[i][i]=equ[i][V+1]=mod-1;
        invc[i]=qpow(cnt1[i],mod-2);
        // cerr<<invc[i]<<' ';
        for(int j=1;j<=V;j++) (equ[i][j]+=cnt2[i][j]*invc[i]%mod)%=mod;
    }
    fill(equ[a[1]]+1,equ[a[1]]+V+2,0);
    equ[a[1]][a[1]]=1;
    // for(int i=1;i<=V;i++){for(int j=1;j<=V+1;j++)cerr<<equ[i][j]<<' ';cerr<<'\n';}
    // cerr<<'\n';
    auto vec=equ.gauss();
    for(int i=1;i<=V;i++) g[i]=vec[i-1];//,cerr<<g[i]<<' ';cerr<<'\n';
    f[1]={1,0};
    for(int i=2;i<=n;i++)
    {
        ll inv=invc[a[i-1]];
        f[i].k=f[i-1].k;
        f[i].b=(f[i-1].b-cnt2[a[i-1]][a[i]]*inv%mod+mod)%mod;
        for(int j=1;j<=V;j++)
        {
            ll p=cnt2[a[i-1]][j]*inv%mod;
            if(nxt[i-1][j]==i) continue;
            if(!nxt[i-1][j])
            {
                (f[i].k+=(mod-f[1].k)*p%mod)%=mod;
                (f[i].b+=(mod-f[1].b-1-g[j])*p%mod+mod)%=mod;
                continue;
            }
            (f[i].k+=(mod-f[nxt[i-1][j]].k)*p%mod)%=mod;
            (f[i].b+=(mod-f[nxt[i-1][j]].b-1)*p%mod)%=mod;
        }
        ll p=cnt1[a[i-1]]*qpow(cnt2[a[i-1]][a[i]],mod-2)%mod;
        (f[i].k*=p)%=mod;
        (f[i].b*=p)%=mod;
        // cerr<<f[i].k<<' '<<f[i].b<<'\n';
    }
    cout<<(1-f[n].b*qpow(f[n].k,mod-2)%mod+mod)%mod;
    return 0;
}