题解 CF1411G No Game No Life

· · 题解

首先,我们可以在 O(m) 的时间复杂度内求出每个点的 SG 值,直接按照 SG 函数的定义求解即可。

题目转化为给定一个数组 a,长度为 n,和一个初值为 0 的变量 v,执行如下操作:

  1. [1,n+1] 中等概率随机一个整数 p。若 p\le n,进行操作 2;否则,进行操作 3。
  2. 游戏结束,若 v=0 则输,否则赢。

求赢的概率。

f_i 为一次随机过程使得 v\larr v\oplus i 的概率,设其集合幂级数为 F,设游戏结束时 v=i 的概率为 g_iG 为其集合幂级数,乘法为异或卷积,则有等式

\begin{aligned} G=&\frac{\sum\limits_{i=0}^{\infin}F^i}{n+1}\\ =&\frac{1}{(n+1)(1-F)} \end{aligned}

这里我们要证明 1-F 有逆,有逆的充要条件是 FWT 后每一项都不为零。

证明需要先了解 FWT 的一个性质, FWT 是线性变换且每一项系数都为 \pm1,且常数项对 FWT 后每一项的系数贡献都为 1

因为 F 的和为 \frac n{n+1} 且每一项都为非负数,所以对 F 进行 FWT 后每一项取值应该在 -\frac n{n+1}\frac n{n+1} 之间,-F 进行 FWT 后的结果为 F 的结果取反,所以每一项取值也在 -\frac n{n+1}\frac n{n+1} 之间,给原集合幂级数常数项加 1,就是给 FWT 后每一项都加 1,所以 1-F 在 FWT 后每一项的取值应该在 1-\frac n{n+1}1+\frac n{n+1} 之间,不包括零,所以 1-F 有逆。

知道这点就可以直接做了,求逆的方式就是点值每一项都求逆,时间复杂度就是 FWT 的复杂度。

不过 SG 值的值域是 \sqrt m,所以复杂度是 O(m+\sqrt m \log m)

#include<cstdio>
#include<algorithm>
#include<vector>
int const mod=998244353,inv2=(mod+1)/2; 
std::vector<int> v[100010],p[100010];
int n,m,d[100010],a[100010],f[600],lim;
int pow(int x,int y){
    int res=1;
    while(y){
        if(y&1) res=1ll*res*x%mod;
        x=1ll*x*x%mod,y>>=1; 
    }
    return res;
}
void fwt(int *a){
    for(int mid=1;mid<lim;mid<<=1)
        for(int r=mid<<1,j=0;j<lim;j+=r)
            for(int k=0;k<mid;k++){
                int x=a[k+j],y=a[mid+j+k];
                a[k+j]=(x+y)%mod,a[mid+j+k]=(x-y+mod)%mod;  
            }
}
void ifwt(int *a){
    for(int mid=1;mid<lim;mid<<=1)
        for(int r=mid<<1,j=0;j<lim;j+=r)
            for(int k=0;k<mid;k++){
                int x=a[k+j],y=a[mid+j+k];
                a[k+j]=1ll*(x+y)*inv2%mod,a[mid+j+k]=1ll*inv2*(x-y+mod)%mod;    
            }
}
int main(){
    scanf("%d%d",&n,&m);
    int lm=pow(n+1,mod-2);
    for(int i=1,x,y;i<=m;i++)scanf("%d%d",&x,&y),v[y].push_back(x),++d[x];
    static int que[100010];
    int *hd=que,*tl=que;
    for(int i=1;i<=n;i++) if(!d[i]) *tl++=i;
    while(hd!=tl){
        int x=*hd++;
        std::sort(p[x].begin(),p[x].end());
        p[x].erase(std::unique(p[x].begin(),p[x].end()),p[x].end());
        for(a[x]=0;a[x]<(int)p[x].size()&&p[x][a[x]]==a[x];a[x]++);
        lim=std::max(a[x],lim);
        f[a[x]]=(f[a[x]]-lm+mod)%mod;
        for(auto u:v[x]){
            p[u].push_back(a[x]);
            if(!--d[u]) *tl++=u;
        }
    }
    if(lim==0) lim=2;
    else do lim+=lim&-lim; while((lim&-lim)!=lim);
    f[0]=(f[0]+1)%mod;
    for(int i=0;i<lim;i++) f[i]=1ll*f[i]*(n+1)%mod;
    fwt(f);
    for(int i=0;i<lim;i++) f[i]=pow(f[i],mod-2);
    ifwt(f);
    printf("%d\n",(1-f[0]+mod)%mod);
    return 0;
}