题解 P6276 【[USACO20OPEN]Exercise P】

· · 题解

一个置换的步数显然是所有环长的 \mathrm{lcm},又要求乘积,因此我们枚举每个质因子计算贡献. 考虑枚举一个质数 p,我们需要知道它的指数. 一个置换的贡献是所有环长在 p 处的次数的 \max,于是我们枚举 \max=d,计数有多少个置换的 \max\geq d. 令 t=p^d,那么现在需要求出的就是满足存在一个长为 t 的倍数的环的置换的个数,我们可以容斥为不存在的方案数. 那么现在一个环的 EGF 就是

C(z)=\sum_{i\geq 1}\frac{z^i}{i}[i\bmod t\neq 0]=\ln\left(\frac{1}{1-z}\right)-\frac{1}{t}\ln\left(\frac{1}{1-z^t}\right)

那么置换的个数就是

[n!z^n]\exp(C(z))=[n!z^n]\frac{(1-z^t)^{1/t}}{1-z}

注意乘以 \dfrac{1}{1-z} 相当于在做部分和,令 k=\left\lfloor\dfrac{n}{t}\right\rfloor,那么答案就是

[n!z^k](1-z)^{1/t-1}=n!(-1)^k\binom{1/t-1}{k}=\frac{n!(t-1)(2t-1)\cdots(kt-1)}{k!t^k}

现在的问题是因为这个在指数上,所以要对 M-1 取模,而这个相当于任意模数. 于是大概有两个做法,它们都要依赖一个复杂度上的结论:

\sum_{p\ is\ prime,k\geq 1}\left\lfloor\frac{n}{p^k}\right\rfloor=\Theta(n\log\log n)

这个可以用素数倒数和的界来简单证明:

\begin{aligned} &\sum_{p\ is\ prime,k\geq 1}\left\lfloor\frac{n}{p^k}\right\rfloor\\ \leq &\sum_{p\ is\ prime,k\geq 1}\frac{n}{p^k}\\ =&\ n\sum_{p\ is\ prime}\frac{1}{p-1}\\\leq&\ n\left(1+\sum_{p\ is\ prime}\frac{1}{p}\right)\\ =&\ \Theta(n\log\log n) \end{aligned}

那么我们现在考虑计算 \dfrac{n!}{k!t^k}\bmod (M-1), 这相当于是

\prod_{i=1}^ni^{[i\bmod t\neq 0]}

于是其中一个做法就是:上面这个乘积会被分成 \lfloor n/t\rfloor 段连续区间,我们使用数据结构来查询区间乘积. 那么这是经典的静态区间半群信息查询,在这里有 \Theta(n\log\log n) 次查询,于是我们使用 \Theta(n\log\log n)-\Theta(1) 的 Sqrt Tree 即可.

另一个稍微简单一点(?)的做法来自 EI,我们把 M-1 的质因子中 \leq n 的分解出来:

M-1=R\prod_i p_i^{\alpha_i}

\omega\leq 9 个质因子,我们从小到大枚举 t,这样 k 就是递减的,我们维护 \dfrac{n!}{k!} 以及其中和 p_i 互质的部分的乘积以及每个 p_i 的次数. 暴力加入的复杂度就是 \sum_{i}\lfloor n/p_i \rfloor=O(n\log\log n),写得不好(比如我)可能还要加个 n\omega. 对于一个和 p_i 都互质的 t,我们暴力 exgcd 计算 t^{-k}\bmod (M-1) 乘上去,exgcd 的执行次数是 n 以内质数幂的个数 =\Theta(n/\ln n)(只要对 >\sqrt{n}\leq \sqrt{n} 的质数分别考虑即可证明);对于其余的 t (不超过 \sum_i \log_{p_i}n 个)我们用之前记录的指数计算,这部分不是复杂度瓶颈. 最后还要对每个质数做快速幂,这部分复杂度 \Theta(\dfrac{n\log M}{\ln n}). 于是我们可以在 O(\sqrt{M}+n\log\log n+\dfrac{n\log M}{\ln n}) 复杂度内完成整道题.

代码是第二种方法的,不过我觉得还能写得短很多(

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=10000;
int pm[10],lim[10],pm_cnt,is_modp[N],mod,n;
struct modprod
{
    int nwn,res,all;
    int cnt[10],pw[10][N];
    void clear()
    {
        res=all=1,nwn=n;
        for(int i=1;i<=pm_cnt;i++)
        {
            pw[i][0]=1;cnt[i]=0;
            int s=0;
            for(long long t=pm[i];t<=n;t*=pm[i])
                s+=n/t;
            for(int j=1;j<=s;j++)pw[i][j]=1ll*pw[i][j-1]*pm[i]%(mod-1);
        }
    }
    int get(){return all;}
    int get(int p,int k)
    {
        int ans=res;
        for(int i=1;i<=pm_cnt;i++)
            if(pm[i]==p)
                ans=1ll*ans*pw[i][cnt[i]-k]%(mod-1);
            else ans=1ll*ans*pw[i][cnt[i]]%(mod-1);
        return ans;
    }
    void mul(int x)
    {
        all=1ll*all*x%(mod-1);
        for(int i=1;i<=pm_cnt;i++)
        {
            while(x%pm[i]==0)++cnt[i],x/=pm[i];
        }
        res=1ll*res*x%(mod-1);
    }
    void maintain(int t)
    {
        while(nwn>t)mul(nwn--);
    }
}a;
void factor(int x)
{
    for(int i=2;i*i<=x&&i<=n;i++)
    {
        if(x%i==0)
        {
            pm[++pm_cnt]=i;
            while(x%i==0)x/=i,++lim[pm_cnt];
            is_modp[i]=1;
        }
    }
    if(x!=1&&x<=n)
    {
        pm[++pm_cnt]=x,lim[pm_cnt]=1,is_modp[x]=1;
    }
}
int ecnt[N],ispw[N],prime[N],p[N],pwcnt[N],prime_cnt;
int qpower(int a,int b)
{
    int ans=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;
    return ans;
}
int spower(int a,int b)
{
    int ans=1;for(;b;b>>=1,a=1ll*a*a%(mod-1))if(b&1)ans=1ll*ans*a%(mod-1);
    return ans;
}
void exgcd(int a,int b,int &x,int &y)
{
    if(!b){x=1,y=0;return;}
    exgcd(b,a%b,y,x);y-=a/b*x;
}
int Inv(int a)
{
    int x,y;exgcd(a,mod-1,x,y);
    return (x%(mod-1)+mod-1)%(mod-1);
}
void solve()
{
    int fac=1;for(int i=1;i<=n;i++)fac=1ll*fac*i%(mod-1);
    for(int i=2;i<=n;i++)
    {
        if(!p[i]){prime[++prime_cnt]=i;ispw[i]=i;pwcnt[i]=1;}
        if(ispw[i])
        {
            int k=n/i,p=ispw[i];
            a.maintain(k);
//            cout<<i<<" "<<p<<" "<<pwcnt[i]*k<<endl;
            int s=1;for(int j=1;j<=k;j++)s=1ll*s*(j*i-1)%(mod-1);
//            cout<<i<<" "<<p<<" "<<s<<endl;
            if(is_modp[p])s=1ll*s*a.get(p,pwcnt[i]*k)%(mod-1);
            else s=1ll*s*a.get()%(mod-1)*spower(Inv(i),k)%(mod-1);            
            ecnt[p]=(ecnt[p]+fac-s)%(mod-1);//cout<<i<<" "<<p<<" "<<ecnt[p]<<endl;
        }
        for(int j=1;j<=prime_cnt&&i*prime[j]<=n;j++)
        {
            int x=i*prime[j];p[x]=1;
            if(i%prime[j])ispw[x]=0;
            else{ispw[x]=ispw[i],pwcnt[x]=pwcnt[i]+1;break;}
        }
    }
    int ans=1;
    for(int i=2;i<=n;i++)
    if(!p[i])
    {
//        cout<<i<<" "<<ecnt[i]<<endl;
        ans=1ll*ans*qpower(i,(ecnt[i]+(mod-1))%(mod-1))%mod;
    }
    cout<<ans<<endl;
}
int main()
{
    scanf("%d%d",&n,&mod);
    factor(mod-1);a.clear();
    solve();
}