RNG XOR
云浅知处
·
·
题解
$$
f_0=0\\
f_i=1+\sum_{j}p_jf_{i\text{ xor }j}=1+\sum_{x\text{ xor }y=i}p_xf_y\quad(i\neq 0)
$$
所以有 $(f\star p)_i=f_i-1\quad(i\neq 0)$,现在已知 $p$ 要解 $f$;其中 $\star$ 表示异或卷积。
当然可以高消,但复杂度 $O(2^{3n})$。
我们发现这个形式就很像 FWT,但是 $f_0$ 不满足这个条件,不能直接求逆。
考虑 $(f\star p)_0$ 是啥,我们发现由于 $\sum p_i=1$,必然有 $\sum f_i=\sum (f\star p)_i$,于是应当有
$$
(f\star p)_0=f_0+\sum_{i\neq 0}f_i-(f\star p)_i=f_0+2^N-1
$$
于是我们有这个方程:
$$
(f_0,f_1,\cdots,f_{2^N-1})\star(p_0,p_1,\cdots,p_{2^N-1})=(f_0+2^N-1,f_1-1,\cdots,f_{2^N-1}-1)
$$
我们发现方程的两边都有 $f$,没办法直接求逆。但是只要令 $p_0\leftarrow p_0-1$ 就有
$$
(f_0,f_1,\cdots,f_{2^N-1})\star(p_0-1,p_1,\cdots,p_{2^N-1})=(2^N-1,-1,\cdots,-1)
$$
此时看上去可以直接 FWT 求逆,但注意这个式子会得到 $\text{FWT}(f)_0\times 0=0$,无法得到 $\text{FWT}(f)_0$ 的真实值。解决方案是任取一个值作为 $\text{FWT}(f)_0$,我们发现如果他和真实的 $\text{FWT}(f)_0$ 的偏差是 $\Delta$,那么 IFWT 回去之后每个值 $f_i$ 都会变成 $f_i+\frac{\Delta}{2^N}$。再结合 $f_0=0$ 即可得到 $\Delta$ 的值,从而解得所有 $f_i$ 的真实值。
时间复杂度 $O(N2^N)$。
```cpp
#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
int x=0,f=1;char c=getchar();
for(;(c<'0'||c>'9');c=getchar()){if(c=='-')f=-1;}
for(;(c>='0'&&c<='9');c=getchar())x=x*10+(c&15);
return x*f;
}
const int mod=998244353;
int ksm(int x,int y,int p=mod){
int ans=1;
for(int i=y;i;i>>=1,x=1ll*x*x%p)if(i&1)ans=1ll*ans*x%p;
return ans%p;
}
int inv(int x,int p=mod){return ksm(x,p-2,p)%p;}
int randint(int l,int r){return rand()*rand()%(r-l+1)+l;}
void add(int &x,int v){x+=v;if(x>=mod)x-=mod;}
void Mod(int &x){if(x>=mod)x-=mod;}
const int N=18;
vector<int>p,f;
void FWT(vector<int>&A,int k,int tag){
for(int i=0;i<k;i++){
for(int S=0;S<(1<<k);S++){
if(S&(1<<i))continue;
int x=A[S],y=A[S^(1<<i)];
Mod(A[S]=x+y),Mod(A[S^(1<<i)]=x+mod-y);
}
}
if(tag==-1){
int iv=inv(1<<k);
for(int S=0;S<(1<<k);S++)A[S]=1ll*A[S]*iv%mod;
}
}
signed main(void){
#ifdef YUNQIAN
freopen("in.in","r",stdin);
#endif
int sum=0,n=read();f.resize(1<<n),p.resize(1<<n);
for(int i=0;i<(1<<n);i++)p[i]=read(),add(sum,p[i]);sum=inv(sum);
for(int i=0;i<(1<<n);i++)p[i]=1ll*p[i]*sum%mod;add(p[0],mod-1);
FWT(p,n,1);
f[0]=(1<<n)-1;for(int i=1;i<(1<<n);i++)f[i]=mod-1;
FWT(f,n,1);
for(int i=0;i<(1<<n);i++)f[i]=1ll*f[i]*inv(p[i])%mod;
FWT(f,n,-1);
for(int i=1;i<(1<<n);i++)add(f[i],mod-f[0]);f[0]=0;
for(int i=0;i<(1<<n);i++)cout<<f[i]<<'\n';
return 0;
}
```