AGC008E Next or Nextnext 题解

· · 题解

设从 i 指向 p_i 所形成的图为 G_1,从 i 指向 a_i 所形成的图为 G_2,考虑考查 G_1G_2 之间的关系。

显然 G_1 由若干个环组成,于是考虑 G_1 中的每一个环:

接下来考虑 G_2 如何推回 G_1

先考虑 G_2 中的环。设 G_2 中有 c_i 个长度为 i 的环,枚举要合并的环的数量 2j

将每个步骤的方案数相乘即可得到这个部分的结果,将每个 j 的结果相加即可得到每种环长的答案。

再考虑 G_2 中的基环内向树。不妨设基环内向树的环上的点依次为 1,2,\dots,m,其中对于每个不大于 n 的正整数 i 都存在一条从 (i \bmod m)+1 指向 i 的边;设挂着链的点依次为 p_1,p_2,\dots,p_k,挂在 p_i 上的链的长度为 l_i,则我们需要将 l_i 插回 p_ip_{(i \bmod k)+1} 之间。设 p_ip_{(i \bmod m)+1} 之间有 e 条边:

将每条链插回的方案数相乘即可得到这部分的结果,答案即为每个环和每棵基环内向树的方案数的乘积。

朴素实现的时间复杂度为 \mathcal O(n \log n),精细实现可以做到 \mathcal O(n)

const int N=1e5+5,mod=1e9+7;
int n,a[N],c[N],vis[N],ans=1,fac[N<<1],infac[N<<1];
vector <int> ve[N];
void add(int &a,int b){
    a+=b;
    if(a>=mod) a-=mod;
}
int ad(int a,int b){
    a+=b;
    if(a>=mod) a-=mod;
    return a;
}
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1) res=1ll*res*a%mod;
        b>>=1,a=1ll*a*a%mod; 
    }
    return res;
}
int C(int n,int m){
    return 1ll*fac[n]*infac[m]%mod*infac[n-m]%mod;
}
void dfs(int u,bool f,int len){
    f|=(ve[u].size()!=1);
    vis[u]=1,len++;
    if(vis[a[u]]){
        if(!f) return c[len]++,void();
        vector <int> s,p,l;
        s.pb(u),vis[u]=2;
        int v=a[u];
        while(v!=u) s.pb(v),vis[v]=2,v=a[v];
        int m=s.size();
        for(int i=0;i<m;i++){
            int u=s[i];
            if(ve[u].size()==2){
                if(vis[ve[u][0]]==2) v=ve[u][1];
                else v=ve[u][0];
                int cnt=0;
                while(1){
                    vis[v]=1,cnt++;
                    if(ve[v].size()==2) cout<<0<<endl,exit(0);
                    if(ve[v].size()==0) break;
                    v=ve[v][0];
                }
                p.pb(i),l.pb(cnt);
            }
        }
        int k=p.size();
        for(int i=0;i<k;i++){
            int e;
            if(i==0) e=p[i]+m-p[k-1];
            else e=p[i]-p[i-1];
            if(e>l[i]) add(ans,ans);
            if(e<l[i]) cout<<0<<endl,exit(0);
        }
        return;
    }
    dfs(a[u],f,len);
}
void solve(){
    cin>>n;
    fac[0]=infac[0]=1;
    for(int i=1;i<=n+n;i++) fac[i]=1ll*i*fac[i-1]%mod;
    infac[n+n]=ksm(fac[n+n],mod-2);
    for(int i=n+n-1;i>0;i--) infac[i]=1ll*(i+1)*infac[i+1]%mod;
    for(int i=1;i<=n;i++) cin>>a[i],ve[a[i]].pb(i);
    for(int i=1;i<=n;i++) if(ve[i].size()>2) cout<<0<<endl,exit(0);
    for(int i=1;i<=n;i++){
        if(vis[i]) continue;
        dfs(i,0,0);
    }
    for(int i=1;i<=n;i++){
        int sum=0;
        for(int j=0;j+j<=c[i];j++){
            int res=1;
            if(j!=0) res=1ll*C(c[i],j+j)*C(j+j,j)%mod*fac[j]%mod*ksm((mod+1)/2,j)%mod*ksm(i,j)%mod;
            if((i&1)&&i!=1) res=1ll*res*ksm(2,c[i]-j-j)%mod;
            add(sum,res);
        }
        ans=1ll*ans*sum%mod;
    }
    cout<<ans<<endl;
}