题解 P3176 [HAOI2015]数字串拆分

· · 题解

P3176 [HAOI2015]数字串拆分

题意简述:定义 f(s) 为将 s 拆分成若干个不大于 m 个数的方案数。给出数字字符串 t,求 g(t)=\sum f(x),其中 x 为将 t 分割成若干个数后它们的和。例如 t=\texttt{12} 时,答案为 f(1+2)+f(12)

本文节选自 DP 大锅乱炖 IV.

注意到 m 的范围很小,只有 5,所以 f(s) 可以用矩阵快速幂 m^3\log s 计算(基本操作)。记转移矩阵为 G,那么答案即为 (\sum G^x)_{0,0}。我们记 d_ig'(t[1:i])=\sum G^x(可以和上面 g(s) 的柿子对比一下,g 是数值,g' 是矩阵),则答案为 (d_n)_{0,0}

不难求出转移方程 d_i=\sum_{j=0}^{i-1}d_j\times f(t[j+1:i])。因为 t[j+1:i] 会很大(达到 10^n),所以快速幂的时候需要对指数高精(矩阵没有费马小定理!)。因此,这样的时间复杂度为 \mathcal{O}(n^3m^3),无法承受。

注意到 f(t)=G^t=\prod_i G^{10^{|t|-i}t_i},那么我们可以预处理出形如 G^{c\times 10^k}\ (1\leq c<10,1\leq k\leq n) 的矩阵。然后记 D_{l,r}f(t[l:r]),那么它可以由 D_{l+1,r}\times G^{10^{r-l}t_{l}} 转移过来。nm^3|\Sigma|+n^2m^3 预处理后可以 n^2m^3 DP(d_i=\sum_{j=0}^{i-1} d_jD_{j+1,i})。可以通过本题。

/*
    Powered by C++11.
    Author : Alex_Wei.
*/

#include <bits/stdc++.h>
using namespace std;

//#pragma GCC optimize(3)
//#define int long long

using ll = long long;
#define mem(x,v) memset(x,v,sizeof(x))

const int N=500+5;
const int M=5;
const int mod=998244353;

int n,m;
char s[N];
void add(int &x,ll y){
    x+=y%mod;
    if(x>mod)x-=mod;
}

struct mat{
    int a[M][M];
    friend mat operator * (mat x,mat y){
        mat z; mem(z.a,0);
        for(int i=0;i<m;i++)
            for(int j=0;j<m;j++)
                for(int k=0;k<m;k++)
                    add(z.a[i][j],1ll*x.a[i][k]*y.a[k][j]);
        return z;
    }
    friend mat operator + (mat x,mat y){
        mat z; mem(z.a,0);
        for(int i=0;i<m;i++)
            for(int j=0;j<m;j++)
                add(z.a[i][j],x.a[i][j]+y.a[i][j]);
        return z;
    }
}base,f[N],c[N][N],d[N][10];

mat ksm(mat a,int b){
    mat x; mem(x.a,0);
    for(int i=0;i<m;i++)x.a[i][i]=1;
    while(b){
        if(b&1)x=x*a;
        a=a*a,b>>=1;
    } return x;
}

int main(){
    scanf("%s%d",s+1,&m),n=strlen(s+1),f[0].a[0][0]=1;
    for(int i=0;i<m;i++)base.a[i][0]=1;
    for(int i=1;i<m;i++)base.a[i-1][i]=1;
    for(int i=0;i<10;i++){
        d[0][i]=ksm(base,i);
        for(int j=1;j<=n;j++)d[j][i]=ksm(d[j-1][i],10);
    } for(int r=n;r;r--)
        for(int l=r;l;l--)
            if(l==r)c[l][r]=d[0][s[r]-'0'];
            else c[l][r]=c[l+1][r]*d[r-l][s[l]-'0'];
    for(int i=1;i<=n;i++)
        for(int j=0;j<i;j++)
            f[i]=f[i]+f[j]*c[j+1][i];
    cout<<f[n].a[0][0]<<endl;
    return 0;
}