CF1103E Radix sum

· · 题解

其他题解怎么感觉都是在硬凑,这里介绍一下 k 进制 FWT 的通用做法。

这题求的就是 (\sum x^{a_i})^n

由于每一位独立,所以我们对每一位单独做 DFT 即可,在单位根存在时,这样的复杂度是 O(k^{m+1}m),使用暴力 DFT,m 为位数,k 为进制。

由于 k 次单位根可能不存在,我们考虑扩域,令 \omega^k=1,这样每个数就是关于 \omegak-1 次多项式,两个数的乘法就是循环卷积,即在 \bmod \omega^k-1 意义下进行。

但这样是不对的,由于 \omega^k-1 存在零因子,所以一个数有多种表示方法,我们无法确定一个数的真实值,所以考虑分圆多项式 \Phi_k(\omega),其满足 \omega 的阶恰为 k,且在 \mathbb Q 内不可约,所以我们定义两个数的乘法是在 \bmod \Phi_k(\omega) 意义下进行。

值得注意的是,我们大多数时候并不需要证明其在 F_p 内不可约,这是因为,我们计算过程中并未用到 F_p 的性质,所以我们的行为其实等价于先算出了真实结果,再对 p 取模,所以无需证明其在 F_p 内不可约。

还有一个问题是,\bmod \Phi_k(\omega) 太慢,\Phi_k(\omega)\varphi(k) 次多项式,也就是说,我们每次两个数乘法都要对 O(k) 次的多项式取模,就连 DFT 时的乘上 \omega^i 时都要 \bmod \Phi_k(\omega),让时间复杂度多了个 k,由于 \Phi_k(\omega)\mid \omega^k-1,所以我们可以先在 \bmod \omega^k-1 意义下进行运算,最后求答案时再 \bmod \Phi_k(\omega)

关于分圆多项式怎么算,我们有 \Phi_k(x)=(-1)^{[k=1]}\prod\limits_{d\mid k}(1-x^d)^{\mu(\frac nd)},这是由定义式 \prod\limits_{i=0}^{k-1}(x-\omega_n^k)^{[\gcd(n,k)=1]} 也即 x^k-1=\prod\limits_{d\mid k}\Phi_d(x) 得来的,就不在此赘述了。根据这个式子,我们可以在 O(2^{\omega(k)}\varphi(k)) 的时间复杂度内求出分圆多项式,代码如下。

void init(){
    T[0]=1;
    for(auto i:d[k]){//d 为 k 的所有 square-free 因子
        if(mu[k/i]==1) for(int j=phi[k];j>=i;j--) T[j]-=T[j-i];
        else if(mu[k/i]==-1) for(int j=i;j<=phi[k];j++) T[j]+=T[j-i];
    }
}

时间复杂度 O(k^{m+2}m),算出来只有 5\times 10^7 轻松过题。

#include<iostream>
#include<string>
using std::cin;
using std::cout;
#define int unsigned long long
int const iv5=14757395258967641293ull;
int n,k,m,pw[10],lim,iv;
int pow(int x,int y){
    int res=1;
    while(y){
        if(y&1) res*=x;
        x*=x,y>>=1;
    }
    return res;
}
struct node{
    int a[11];
    node():a(){}
    int operator [](int x)const{return a[x];}
    int& operator [](int x){return a[x];} 
    void operator +=(node const &x){for(int i=0;i<k;i++) a[i]+=x[i];}
    void operator *=(int y){for(int i=0;i<k;i++) a[i]*=y;}
    friend node operator *(node const &x,node const &y){
        static int a[30];
        node ans;
        for(int i=0;i<k+k;i++) a[i]=0;
        for(int i=0;i<k;i++) for(int j=0;j<k;j++) a[i+j]+=x[i]*y[j];
        for(int i=0;i<k;i++) ans[i]=a[i]+a[i+k];
        return ans;
    }
    friend node operator +(node const &x,node const &y){
        node ans;
        for(int i=0;i<k;i++) ans[i]=x[i]+y[i];
        return ans;
    }
    node apply(int x)const{
        x%=k;
        node ans;
        for(int i=0;i<k;i++) ans[(i+x)%k]=a[i];
        return ans;
    }
    node iapply(int x)const{
        x%=k;
        x=(k-x)%k;
        node ans;
        for(int i=0;i<k;i++) ans[(i+x)%k]=a[i];
        return ans;
    }
}T,ct[100010];
node pow(node x,int y){
    node res;
    res[0]=1;
    while(y){
        if(y&1) res=res*x;
        x=x*x,y>>=1;
    }
    return res;
}
void dwt(node *c){
    for(int i=0;i<m;i++)
        for(int a=0;a<lim;a+=pw[i+1])
            for(int b=0;b<pw[i];b++){
                static node x[11],y[11];
                for(int j=0;j<k;j++) x[j]=c[a+b+j*pw[i]],y[j]=node();
                for(int j=0;j<k;j++) for(int u=0;u<k;u++) y[j]+=x[u].apply(j*u);
                for(int j=0;j<k;j++) c[a+b+j*pw[i]]=y[j];
            }
}
void idwt(node *c){
    for(int i=0;i<m;i++)
        for(int a=0;a<lim;a+=pw[i+1])
            for(int b=0;b<pw[i];b++){
                static node x[11],y[11];
                for(int j=0;j<k;j++) x[j]=c[a+b+j*pw[i]],y[j]=node();
                for(int j=0;j<k;j++) for(int u=0;u<k;u++) y[j]+=x[u].iapply(j*u);
                for(int j=0;j<k;j++) c[a+b+j*pw[i]]=y[j];
            }
}
void init(){
    T[0]=1,T[1]=-1ull,T[2]=1,T[3]=-1ull,T[4]=1;
}
int deal(node x){
    int n=4;
    for(int i=k-1;i>=n;i--){
        int u=x[i];
        for(int j=1;j<=n;j++) x[i-j]-=u*T[n-j];
    }
    int u=x[0];
    u*=iv;
    u>>=m;
    return u%(1ull<<58);
}
signed main(){
    std::ios::sync_with_stdio(false),cin.tie(nullptr);
    k=10,m=5;
    pw[0]=1;
    for(int i=1;i<=m;i++) pw[i]=pw[i-1]*k;
    lim=pw[m],iv=pow(iv5,m);
    cin>>n;
    for(int i=1;i<=n;i++){
        int x;
        cin>>x;
        ++ct[x][0];
    }
    dwt(ct);
    for(int i=0;i<lim;i++) ct[i]=pow(ct[i],n);
    idwt(ct);
    init();
    for(int i=0;i<n;i++) cout<<deal(ct[i])<<'\n';
    return 0;
}