P5109 归程

· · 题解

我们定义两个数列 a,b 的卷积为 c_n=\sum\limits_{i\oplus j=n}a_ib_j,其中 \oplus 表示 k 进制下的不进位加法。

考虑到不进位加法每一位相互独立,我们可以对每一位单独进行变换,考察单独一位的变换应该是什么形式,发现对于某一位,有 $i\oplus j=(i+j)\bmod k$,即循环卷积,使用 $\text{DFT}$ 即可,暴力 $\text{DFT}$ 的时间复杂度为 $O(k^{n+1}n)$,$n$ 为位数。 但是 $k$ 次单位根 $\omega_k$ 可能在模意义下不存在,考虑扩域,即人为定义一个 $x$,满足 $x^k=1$,那么我们所有数都可以表示成 $\sum\limits_{i=0}^{k-1}c_ix^i$ 的形式,两个数的运算实际上就是在 $\bmod x^k-1$ 意义下的运算。但是我们发现这样的数并不构成域,因为存在零因子,即一个数会有多种表示方法,形如(真实值)+(零因子的线性组合),我们要找到一种合适的扩域方法,避免零因子的存在。考虑分圆多项式 $\Phi_k(x)$,其满足在 $\mathbb Q$ 上不可约,且在 $\bmod \Phi_k(x)$ 下 $x$ 的阶为 $k$,此时就没有零因子,所以最后得到的结果一定只有常数项有值。但 $\bmod \Phi_k(x)$ 的常数很大,我们考虑到 $\Phi_k(x)\mid x^k-1$,所以我们可以在计算过程中先对 $x^k-1$ 取模,再把最后的结果对 $\Phi_k(x)$ 取模。 由于 $\text{DFT}$ 的时候,需要进行 $O(k^2)$ 次单项式乘多项式,所以时间复杂度 $O(k^{n+2}n)$。 例题就是 P5109 归程,给定一个长为 $n$ 的数列 $a$,每次询问两个数 $p,x$,问有多少长度最多为 $p$ 的数列 $b$ 满足 $1\le b_i\le n,b_1\oplus b_2\oplus \cdots\oplus b_k=x$,其中 $k$ 为 $b$ 的长度。 我们令 $c_i=\sum\limits_{j=1}^n[a_j=i]$,那么题目求的实际上就是 $(c^0+c^1+\cdots+c^p)_x$,由于我们不能求逆,所以不能用等比数列求和公式,考虑倍增,求出 $c^1,c^2,c^4,c^8,\cdots$ 和 $\sum\limits_{i=0}^0c^i,\sum\limits_{i=0}^1c^i,\sum\limits_{i=0}^3c^i,\sum\limits_{i=0}^7c^i,\cdots$ 即可快速求解。 ```cpp #include<iostream> using std::cin; using std::cout; int const mod=2333; int pow(int x,int y){ int res=1; while(y){ if(y&1) res=1ll*res*x%mod; x=1ll*x*x%mod,y>>=1; } return res; } int n,m,v,k,lim,iv,pw[14]; struct node{ int a[11]; node():a(){} int operator [](int x)const{return a[x];} int& operator [](int x){return a[x];} void operator +=(node const &x){for(int i=0;i<k;i++) a[i]+=x[i];} void fit(){for(int i=0;i<k;i++) a[i]%=mod;} void operator *=(int y){for(int i=0;i<k;i++) a[i]=a[i]*y%mod;} friend node operator *(node const &x,node const &y){ static int a[20]; node ans; for(int i=0;i<k+k;i++) a[i]=0; for(int i=0;i<k;i++) for(int j=0;j<k;j++) a[i+j]+=x[i]*y[j]; for(int i=0;i<k;i++) ans[i]=(a[i]+a[i+k])%mod; return ans; } friend node operator +(node const &x,node const &y){ node ans; for(int i=0;i<k;i++) ans[i]=(x[i]+y[i])%mod; return ans; } node apply(int x)const{ x=(x%k+k)%k; node ans; for(int i=0;i<k;i++) ans[(i+x)%k]=a[i]; return ans; } }T; struct poly{ node c[6600]; poly():c(){} node operator [](int x)const{return c[x];} node& operator [](int x){return c[x];} void dwt(){ for(int i=0;i<v;i++) for(int a=0;a<lim;a+=pw[i+1]) for(int b=0;b<pw[i];b++){ static node x[10],y[10]; for(int j=0;j<k;j++) x[j]=c[a+b+j*pw[i]],y[j]=node(); for(int j=0;j<k;j++) for(int u=0;u<k;u++) y[j]+=x[u].apply(j*u); for(int j=0;j<k;j++) y[j].fit(),c[a+b+j*pw[i]]=y[j]; } } friend poly operator +(poly const &x,poly const &y){ poly a; for(int i=0;i<lim;i++) a[i]=x[i]+y[i]; return a; } friend poly operator *(poly const &x,poly const &y){ poly a; for(int i=0;i<lim;i++) a[i]=x[i]*y[i]; return a; } void idwt(){ for(int i=0;i<v;i++) for(int a=0;a<lim;a+=pw[i+1]) for(int b=0;b<pw[i];b++){ static node x[10],y[10]; for(int j=0;j<k;j++) x[j]=c[a+b+j*pw[i]],y[j]=node(); for(int j=0;j<k;j++) for(int u=0;u<k;u++) y[j]+=x[u].apply(-j*u); for(int j=0;j<k;j++) y[j].fit(),c[a+b+j*pw[i]]=y[j]; } for(int i=0;i<lim;i++) c[i]*=iv; } }q,p[50],t[50]; int read(){ int ans=0; for(int i=0;i<v;i++){ int x; cin>>x; ans+=pw[i]*x; } return ans; } int const phi[]={0,1,1,2,2,4,2,6,4,6,4},mu[]={0,1,-1,-1,0,-1,1,-1,0,0,1}; void init(){ T[0]=1; for(int i=1;i<=k;i++)if(k%i==0){ if(mu[k/i]==1) for(int j=phi[k];j>=i;j--) T[j]-=T[j-i]; else if(mu[k/i]==-1) for(int j=i;j<=phi[k];j++) T[j]+=T[j-i]; } } int deal(node x){ int n=phi[k]; for(int i=k-1;i>=n;i--){ int u=x[i]; for(int j=1;j<=n;j++) x[i-j]=(x[i-j]-u*T[n-j]%mod+mod)%mod; } return x[0]; } int main(){ std::ios::sync_with_stdio(false),cin.tie(nullptr); cin>>n>>m>>v>>k; ++k,iv=pow(pow(k,mod-2),v),init(); pw[0]=1; for(int i=1;i<=v;i++) pw[i]=pw[i-1]*k; lim=pw[v]; for(int x;n--;) x=read(),q[x][0]++; for(int i=0;i<lim;i++) q[i][0]%=mod; q.dwt(); for(int i=0;i<lim;i++) t[0][i][0]=1; p[0]=q; for(int i=1;i<30;i++) p[i]=p[i-1]*p[i-1]; for(int i=1;i<30;i++) t[i]=t[i-1]+t[i-1]*p[i-1]; while(m--){ int c,x; cin>>c,x=read(); poly ans,w=t[0]; ++c; for(int i=29;~i;i--) if(c>=(1<<i)) ans=ans+w*t[i],w=w*p[i],c-=1<<i; ans.idwt(); cout<<deal(ans[x])<<'\n'; } return 0; } ```