题解 P8350【[SDOI/SXOI2022] conversion】

· · 题解

考虑一个暴力的 DP:关于三进制 DP,令 f_{i,j} 表示三进制意义下高于 i 的信息已经确定,且已确定的数的二进制表达为 j,所有满足上述条件的数 ix^iz^{b(i)} 之和。y^{a(i)} 的系数可以在 f_{-1,j} 处统计。

这个做法是 O(n) 的,不可承受。

注意到三进制下填入后 i 位时,只会影响二进制下最低的 l_i 位,其中 l_i 是满足 2^{l_i}>3^{i+1} 的最小自然数。对于更高位的影响,其至多进一位。

于是我们可以把状态更新成 f_{i,j,0/1},表示已确定的数的二进制表达的后 l_i 位(也即模 2^{l_i} 的结果)是 j,且向更高位无/有进位时的答案。

但是这个 DP 并没有任何实质性的改变……吗?

考虑 根号分治。高 3^{m/2} 位一共只会访问根号个态;低 3^{m/2} 位的状态总数是根号的;故只对低 3^{m/2} 位记忆化,复杂度就是正确的根号。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;
int ksm(int x,ll y){int z=1;for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;return z;}
const int LG=28;
const int HF=14;
ll n;int X,Y,Z,py[1010],pz[1010];
ll tri[30];
int px[30][3],l[30],bd[30],sum[30],pc[1<<10];
int f[1<<25][2];
int dfs(int pos,ll num,bool car,bool lim){
    if(pos==-1){return !car?py[pc[num]]:0;}
    if(!lim&&pos<HF&&f[sum[pos]+num][car]!=-1)return f[sum[pos]+num][car];
    int res=0;
    for(int i=0;i<=(lim?bd[pos]:2);i++)for(int CAR=0;CAR<2;CAR++){
        ll NUM=num+tri[pos]*i+((ll)CAR<<l[pos]);
        if((NUM>=(1ll<<l[pos+1]))!=car)continue;
        NUM&=(1ll<<l[pos+1])-1;
        int PX=px[pos][i],PY=py[pc[NUM>>l[pos]]],PZ=pz[i];
        res=(1ll*dfs(pos-1,NUM&((1ll<<l[pos])-1),CAR,lim&&(i==bd[pos]))*PX%mod*PY%mod*PZ+res)%mod;
    }
    // printf("%d %d %d %d:%d\n",pos,num,car,lim,res);
    if(!lim&&pos<HF)f[sum[pos]+num][car]=res;
    return res;
}
int main(){
    freopen("conversion.in","r",stdin);
    freopen("conversion.out","w",stdout);
    scanf("%lld%d%d%d",&n,&X,&Y,&Z);
    for(int i=0;i<(1<<10);i++)pc[i]=pc[i>>1]+(i&1);
    py[0]=pz[0]=1;for(int i=1;i<=1000;i++)py[i]=1ll*py[i-1]*Y%mod,pz[i]=1ll*pz[i-1]*Z%mod;
    tri[0]=1;for(int i=1;i<=LG+1;i++)tri[i]=tri[i-1]*3;
    for(int i=0;i<LG;i++)px[i][0]=1,px[i][1]=ksm(X,tri[i]),px[i][2]=1ll*px[i][1]*px[i][1]%mod;
    for(int i=0;i<=LG;i++)while((1ll<<l[i])<=tri[i])l[i]++;
    // for(int i=0;i<=LG;i++)printf("%d ",l[i]);puts("");
    for(int i=0;i<HF;i++)sum[i+1]=sum[i]+(1<<l[i+1]);
    for(int i=0;i<LG;i++)bd[i]=(n/tri[i])%3;
    memset(f,-1,sizeof(f));
    printf("%d\n",(dfs(LG-1,0,false,true)+mod-1)%mod);
    return 0;
}