题解:P10581 [蓝桥杯 2024 国 A] 重复的串

· · 题解

单串 ACAM(可能是 KMP 自动机)+ 矩阵优化 dp。

我们设 dp_{k,i,j} 为当前字符串长为 k,匹配成功 i 次,字符串结尾为 ACAM 上点 j 的方案数。

朴素转移是容易的。注意正好匹配时的边界问题。

然后发现 n 很大,|S| 很小。所以我们考虑上矩阵乘法来优化。

:::error[错误解法]

B_k= \begin{bmatrix} dp_{0,0} & dp_{0,1} & \dots & dp_{0,|S|} \\ dp_{1,0} & dp_{1,1} & \dots & dp_{1,|S|} \\ dp_{2,0} & dp_{2,1} & \dots & dp_{2,|S|} \\ \end{bmatrix}

发现各行之间相互独立,上一行不能转移到下一行,意味着当匹配成功时无法正确转移,次数无法被更新。卒。

:::

正确解法:

B_k= \begin{bmatrix} dp_{0,0} & \dots & dp_{0,|S|} & dp_{1,0} & \dots & dp_{1,|S|} & dp_{2,0} & \dots & dp_{2,|S|} \end{bmatrix}

这样当匹配成功时,次数可以正确被更新,转移正确。

然后构造转移矩阵,先初始化为 0

base=|S|+1

所以匹配 i 次在矩阵中的下标范围为 [i\times base,i \times (base+1))

然后对于每个 i \in [0,tot),遍历 j \in [0,26)(就是遍历字典树),并令转移矩阵中 A[i][trie_{i,j}],A[i+base][trie_{i,j}+base],A[i+base+base][trie_{i,j}+base+base] 自增 1。意为匹配成功 0/1/2 次时,构造转移矩阵使得 dp_0,dp_1,dp_2 这三维互不干扰单独转移。

对于 i=tot,遍历 j \in [0,26),并令转移矩阵中 A[i][trie_{i,j}+base],A[i+base][trie_{i,j}+base+base] 自增 1。意为由匹配成功 x 次转移到匹配成功 x+1 次。

然后对这个矩阵求它的 n 次幂,设结果为 X

答案矩阵 B_n = B_0\times X。其实不用真正乘,发现 B_0 只有 B_0[0][0]=1,所以最终答案为

ans=\sum_{i=2base-1}^{3base-2}X[0][i]

Code:

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int Size=(1<<20)+1;
char buf[Size],*p1=buf,*p2=buf;
char buffer[Size];
int op1=-1;
const int op2=Size-1;
#define getchar()                                                              \
(tt == ss && (tt=(ss=In)+fread(In, 1, 1 << 20, stdin), ss == tt)     \
    ? EOF                                                                 \
    : *ss++)
char In[1<<20],*ss=In,*tt=In;
inline int read()
{
    int x=0,c=getchar(),f=0;
    for(;c>'9'||c<'0';f=c=='-',c=getchar());
    for(;c>='0'&&c<='9';c=getchar())
        x=(x<<1)+(x<<3)+(c^48);
    return f?-x:x;
}
inline void write(int x)
{
    if(x<0) x=-x,putchar('-');
    if(x>9)  write(x/10);
    putchar(x%10+'0');
}

#ifndef ONLINE_JUDGE
#define ONLINE_JUDGE
#endif

const int N=31*3,mod=998244353;
int t[N][26];
int tot;
int fail[N];
struct Martix{
    int a[N][N]={};
}I;
Martix operator*(const Martix &x,const Martix &y)
{
    Martix ans;
    const int base=(tot+1)*3;
    memset(ans.a,0,sizeof(ans.a));
    for(int i=0;i<base;i++)
    for(int j=0;j<base;j++)
    for(int k=0;k<base;k++)
    (ans.a[i][j]+=x.a[i][k]*y.a[k][j]%mod)%=mod;
    return ans;
}

void insert(const string &s)
{   
    int nw=0;
    for(char i:s)
    {
        if(!t[nw][i-'a']) t[nw][i-'a']=++tot;
        nw=t[nw][i-'a'];
    }
}   

Martix build()
{
    queue<int> q;
    Martix ans;
    memset(ans.a,0,sizeof(ans.a));
    for(int i=0;i<26;i++)
    {
        if(t[0][i]) q.push(t[0][i]);
    }
    while(q.size())
    {
        int nw=q.front();
        q.pop();
        for(int i=0;i<26;i++)
        {
            if(t[nw][i])
            {
                fail[t[nw][i]]=t[fail[nw]][i];
                q.push(t[nw][i]);
            }
            else t[nw][i]=t[fail[nw]][i];
        }
    }
    const int base=tot+1;
    for(int i=0;i<tot;i++)
    for(int j=0;j<26;j++)
    {
        ans.a[i][t[i][j]]+=1;
        ans.a[i+base][t[i][j]+base]+=1;
        ans.a[i+base+base][t[i][j]+base+base]+=1;
    }

    for(int i=tot;i==tot;i++)
    for(int j=0;j<26;j++)
    {
        ans.a[i][t[i][j]+base]+=1;
        ans.a[i+base][t[i][j]+base+base]+=1;
    }
    return ans;
}

Martix ksm(Martix x,int p)
{
    Martix ans=I;
    while(p)
    {
        if(p&1) ans=ans*x;
        x=x*x;
        p>>=1;
    }
    return ans;
}

// void output(Martix x)
// {
//  cout<<"\n------------------------------------\n\n";

//  for(int i=0;i<(tot+1)*3;i++,cout<<"\n") 
//  for(int j=0;j<(tot+1)*3;j++) 
//  cout<<x.a[i][j]<<" ";     
//  // cout<<"\n------------------------------------\n\n";
// }

signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("a.in","r",stdin);
    freopen("a.out","w",stdout);
    #endif
    memset(I.a,0,sizeof(I.a));
    for(int i=0;i<N;i++) I.a[i][i]=1;

    string s;
    int n;
    cin>>s>>n;
    insert(s);
    Martix x=build();

    Martix ans=ksm(x,n);
    int nw=0;
    int base=tot+1;
    for(int i=2*base-1;i<3*base-1;i++) (nw+=ans.a[0][i])%=mod;
    cout<<nw<<"\n";

    return 0;
}