题解 P6276 【[USACO20OPEN]Exercise P】
一个置换的步数显然是所有环长的
那么置换的个数就是
注意乘以
现在的问题是因为这个在指数上,所以要对
这个可以用素数倒数和的界来简单证明:
那么我们现在考虑计算
于是其中一个做法就是:上面这个乘积会被分成
另一个稍微简单一点(?)的做法来自 EI,我们把
有
代码是第二种方法的,不过我觉得还能写得短很多(
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=10000;
int pm[10],lim[10],pm_cnt,is_modp[N],mod,n;
struct modprod
{
int nwn,res,all;
int cnt[10],pw[10][N];
void clear()
{
res=all=1,nwn=n;
for(int i=1;i<=pm_cnt;i++)
{
pw[i][0]=1;cnt[i]=0;
int s=0;
for(long long t=pm[i];t<=n;t*=pm[i])
s+=n/t;
for(int j=1;j<=s;j++)pw[i][j]=1ll*pw[i][j-1]*pm[i]%(mod-1);
}
}
int get(){return all;}
int get(int p,int k)
{
int ans=res;
for(int i=1;i<=pm_cnt;i++)
if(pm[i]==p)
ans=1ll*ans*pw[i][cnt[i]-k]%(mod-1);
else ans=1ll*ans*pw[i][cnt[i]]%(mod-1);
return ans;
}
void mul(int x)
{
all=1ll*all*x%(mod-1);
for(int i=1;i<=pm_cnt;i++)
{
while(x%pm[i]==0)++cnt[i],x/=pm[i];
}
res=1ll*res*x%(mod-1);
}
void maintain(int t)
{
while(nwn>t)mul(nwn--);
}
}a;
void factor(int x)
{
for(int i=2;i*i<=x&&i<=n;i++)
{
if(x%i==0)
{
pm[++pm_cnt]=i;
while(x%i==0)x/=i,++lim[pm_cnt];
is_modp[i]=1;
}
}
if(x!=1&&x<=n)
{
pm[++pm_cnt]=x,lim[pm_cnt]=1,is_modp[x]=1;
}
}
int ecnt[N],ispw[N],prime[N],p[N],pwcnt[N],prime_cnt;
int qpower(int a,int b)
{
int ans=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;
return ans;
}
int spower(int a,int b)
{
int ans=1;for(;b;b>>=1,a=1ll*a*a%(mod-1))if(b&1)ans=1ll*ans*a%(mod-1);
return ans;
}
void exgcd(int a,int b,int &x,int &y)
{
if(!b){x=1,y=0;return;}
exgcd(b,a%b,y,x);y-=a/b*x;
}
int Inv(int a)
{
int x,y;exgcd(a,mod-1,x,y);
return (x%(mod-1)+mod-1)%(mod-1);
}
void solve()
{
int fac=1;for(int i=1;i<=n;i++)fac=1ll*fac*i%(mod-1);
for(int i=2;i<=n;i++)
{
if(!p[i]){prime[++prime_cnt]=i;ispw[i]=i;pwcnt[i]=1;}
if(ispw[i])
{
int k=n/i,p=ispw[i];
a.maintain(k);
// cout<<i<<" "<<p<<" "<<pwcnt[i]*k<<endl;
int s=1;for(int j=1;j<=k;j++)s=1ll*s*(j*i-1)%(mod-1);
// cout<<i<<" "<<p<<" "<<s<<endl;
if(is_modp[p])s=1ll*s*a.get(p,pwcnt[i]*k)%(mod-1);
else s=1ll*s*a.get()%(mod-1)*spower(Inv(i),k)%(mod-1);
ecnt[p]=(ecnt[p]+fac-s)%(mod-1);//cout<<i<<" "<<p<<" "<<ecnt[p]<<endl;
}
for(int j=1;j<=prime_cnt&&i*prime[j]<=n;j++)
{
int x=i*prime[j];p[x]=1;
if(i%prime[j])ispw[x]=0;
else{ispw[x]=ispw[i],pwcnt[x]=pwcnt[i]+1;break;}
}
}
int ans=1;
for(int i=2;i<=n;i++)
if(!p[i])
{
// cout<<i<<" "<<ecnt[i]<<endl;
ans=1ll*ans*qpower(i,(ecnt[i]+(mod-1))%(mod-1))%mod;
}
cout<<ans<<endl;
}
int main()
{
scanf("%d%d",&n,&mod);
factor(mod-1);a.clear();
solve();
}