题解:P5577 [CmdOI2019] 算力训练

· · 题解

[CmdOI2019] 算力训练

看到 k 进制不进位加法很自然会想到 \rm FWT 。不会 k 进制 \rm FTW 的可以去 \rm OI-Wiki

不难发现答案为

\prod\limits_{i=1}^n 1 + x^{a_i}

这里下标的加法定义为 k 进制不进位加法。

一个朴素的想法就是,对于每个 i ,对 1 + x^{a_i} 进行 \rm FWT ,并将每一位都乘一起,最后 \rm UFWT 回去。由于不存在 6 次单位根,因此还需要将每一位都设为关于 \omega_k 的多项式。这么做的复杂度是 O( n m k^{m+2} ) ,显然过不去。

考虑优化。不难发现,对 x^{a_i} 进行 \rm FWT 后,每一位都是形如 \omega_k^c 的,而对 1 进行 \rm FWT 后,每一位都是 1 。因此,对 1 + x^{a_i} 进行 \rm FWT 后,每一位都是形如 1 + \omega_k^c 的。

因此,对 \sum\limits_{i=1}^n x^{a_i} 进行 \rm FWT 后,对于某个 i,设结果的第 i 位为 \sum\limits_{j=0}^{k-1} c_j \omega_k^{j} ,则说明对于每个 j ,都恰好存在 c_j a ,使得 [x^i] \operatorname{FWT}( 1 + x^a ) = 1 + \omega_k^j ,所以 [x^i] \prod\limits_{j=1}^n 1 + x^{a_j} = \prod\limits_{j=0}^{k-1} ( 1 + \omega_k^j )^{c_j} 。由于 c_j \le n ,因此对每个 j, c 预处理出 ( 1 + \omega_k^j )^c 即可。复杂度为 O( n k^2 + k^{m+3} ) 。注意预处理 ( 1 + \omega_k^j )^c 时不要把所有结果都存下来,最好对每个 j 单独处理,否则会爆空间。

#include<bits/stdc++.h>
#define ll long long
#define pn putchar('\n')
#define mset(a,x) memset(a,x,sizeof a)
#define mcpy(a,b) memcpy(a,b,sizeof a)
#define all(a) a.begin(),a.end()
#define fls() fflush(stdout)
#define int ll
#define maxn 1000005
#define mod 998244353
using namespace std;
int re()
{
    int x=0;
    bool t=1;
    char ch=getchar();
    while(ch>'9'||ch<'0')
        t=ch=='-'?0:t,ch=getchar();
    while(ch>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return t?x:-x;
}
int ksm(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%mod;
        x=x*x%mod,y>>=1;
    }
    return ret;
}
void dq(int &x)
{
    if(x>=mod)x-=mod;
    if(x<0)x+=mod;
}
int n,K,m,N;
int toK(int x)
{
    int ret=0;
    for(int i=1;i<N;i*=K)
    {
        ret+=x%10*i;
        x/=10;
    }
    return ret;
}
struct Poly
{
    int a[6];
    Poly()
    {
        mset(a,0);
    }
    void clr()
    {
        mset(a,0);
    }
    void print()
    {
        for(int i=0;i<K;i++)
            printf("%lld ",a[i]);
        pn;
    }
    void operator = (int x)
    {
        mset(a,0);
        a[0]=x;
    }
    int& operator [] (int x)
    {
        return a[x];
    }
    Poly operator << (int t)
    {
        t%=K;
        if(t<0)
            t+=K;
        Poly ret;
        for(int i=0;i<K;i++)
            ret[t+i>=K?t+i-K:t+i]=a[i];
        return ret;
    }
    Poly operator + (Poly t)
    {
        Poly ret;
        for(int i=0;i<K;i++)
            dq(ret[i]=a[i]+t[i]);
        return ret;
    }
    void operator += (Poly t)
    {
        *this=*this+t;
    }
    Poly operator * (Poly t)
    {
        Poly ret;
        for(int i=0;i<K;i++)
        {
            for(int j=0;j<K;j++)
                (ret[i+j>=K?i+j-K:i+j]+=a[i]*t[j])%=mod;
        }
        return ret;
    }
    void operator *= (Poly t)
    {
        *this=*this*t;
    }
    Poly operator * (int t)
    {
        Poly ret;
        for(int i=0;i<K;i++)
            ret[i]=a[i]*t%mod;
        return ret;
    }
    void operator *= (int t)
    {
        *this=*this*t;
    }
}a[100000],b[100000],P,tmp[100000],f[maxn];
void FWT(Poly a[],bool ty=0)
{
    for(int i=1;i<N;i*=K)
    {
        for(int j=0;j<N;j+=i*K)
        {
            for(int k=j;k<j+i;k++)
            {
                for(int l=0;l<K;l++)
                {
                    int t=k+l*i;
                    tmp[t]=a[t];
                    a[t].clr();
                }
                for(int x=0;x<K;x++)
                {
                    for(int y=0;y<K;y++)
                        a[k+x*i]+=tmp[k+y*i]<<(!ty?x*y:-x*y);
                }
            }
        }
    }
    if(ty)
    {
        int iv=ksm(N,mod-2);
        for(int i=0;i<N;i++)
            a[i]*=iv;
    }
}
void Mod(Poly& a)
{
    for(int i=K-1;i>=m;i--)
    {
        int t=a[i];
        for(int j=0;j<=m;j++)
            (a[i-j]-=t*P[m-j])%=mod;
    }
}
signed main()
{
    n=re(),K=re(),m=re();
    N=1;
    while(m--)
        N*=K;
    if(K==5)
    {
        m=4;
        P[0]=P[1]=P[2]=P[3]=P[4]=1;
    }
    else
    {
        m=2;
        P[0]=P[2]=1;
        P[1]=-1;
    }
    for(int i=1;i<=n;i++)
        a[toK(re())][0]++;
    FWT(a);
    for(int i=0;i<N;i++)
        b[i]=1;
    for(int i=0;i<K;i++)
    {
        f[0]=1;
        for(int j=1;j<=n;j++)
        {
            for(int k=0;k<K;k++)
                f[j][k]=f[j-1][k];
            for(int k=0;k<K;k++)
                dq(f[j][k+i>=K?k+i-K:k+i]+=f[j-1][k]);
        }
        for(int j=0;j<N;j++)
            b[j]*=f[a[j][i]];
    }
    FWT(b,1);
    for(int i=0;i<N;i++)
    {
        Mod(b[i]);
        int ans=b[i][0];
        if(ans<0)
            ans+=mod;
        printf("%lld\n",ans);
    }
    return 0;
}