题解 P7511
观察不等式
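先考虑单个长度为 $len$ 的环(即代码中的 `a[i]`)。
发现:环上恰有 $j$ 个位置满足该不等式的方案数为 $len\cdot\left\langle{len-1\atop j-1}\right\rangle$,其中 $\left\langle{n\atop k}\right\rangle$ 为欧拉数(对应代码中的 `a[i]*T[a[i]-1][j-1]`);长度为 $1$ 的环贡献恒为 $0$。对所有环做背包合并,最后乘上 $\dfrac{n!}{\prod_i len_i!}$,即可得到 $O(nK)$ 的做法: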
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std;
// N: capacity used for every fixed-size array of this program.
const int N = 2010;
// Mod: prime modulus applied to all arithmetic.
const int Mod = 998244353;
// Reads the next decimal integer from stdin and returns it.
//
// Every character before the first digit is skipped; if a '-' appears anywhere
// in that skipped run the result is negated. Digits are accumulated in an int,
// so values outside the int range are not supported.
//
// Returns 0 when the input ends before any digit is found. The character is
// kept in an int (not a char) so that EOF can be told apart from real input;
// without that check the skip loop would spin forever on exhausted input.
inline int read()
{
    int value = 0, sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Global state shared by main (zero-initialised as namespace-scope objects).
int n, K;                 // the two integers read first by main
int p[N];                 // p[1..n]: the input permutation
int book[N];              // book[x] != 0 once x has been visited during cycle extraction
int T[N][N];              // table filled by the recurrence in main
int fac[N + 5];           // fac[i] = i! mod Mod
int inv[N + 5];           // inv[i] = inverse of i! mod Mod
int ans;                  // running multiplier: n! times inverse factorials of the cycle lengths
int a[N], m;              // a[1..m]: lengths of the cycles of p
int F[N][N];              // DP table: F[i][j] after processing the first i cycles
// Square-and-multiply modular power: returns x raised to p, reduced modulo Mod.
// p is consumed bit by bit from the least significant end; p == 0 yields 1.
inline int ksc(int x,int p)
{
    int result = 1;
    while (p)
    {
        if (p & 1) result = 1ll * result * x % Mod;
        x = 1ll * x * x % Mod;
        p >>= 1;
    }
    return result;
}
// Entry point of the O(n*K) dynamic-programming solution.
//
// Input (stdin): n, K, then the permutation p[1..n].
// Output (stdout): one line with
//     F[m][K] * n! / (len_1! * ... * len_m!)   (mod Mod),
// where len_1..len_m are the cycle lengths of p and F is the knapsack-style
// table built below.
//
// NOTE(review): all arrays have N entries per dimension, so n must be smaller
// than N; that bound is assumed from the problem limits and not validated.
signed main()
{
    // fac[i] = i!, inv[i] = (i!)^{-1} modulo Mod, for 0 <= i <= N.
    fac[0] = 1;
    for (int i = 1; i <= N; i++) fac[i] = 1ll * fac[i - 1] * i % Mod;
    inv[N] = ksc(fac[N], Mod - 2);
    for (int i = N; i; i--) inv[i - 1] = 1ll * inv[i] * i % Mod;

    n = read(), K = read();
    for (int i = 1; i <= n; i++) p[i] = read();

    // A cycle of length len adds at most len-1 to the second DP index, so the
    // total can never exceed n: the count for K outside [0, n] is 0. Returning
    // here also keeps K from indexing outside F.
    if (K < 0 || K > n)
    {
        printf("%d\n", 0);
        return 0;
    }

    // T[i][j] follows the Eulerian-number recurrence
    //   T[i][j] = (j+1)*T[i-1][j] + (i-j)*T[i-1][j-1].
    T[0][0] = 1;
    for (int i = 1; i <= n; i++)
    {
        T[i][0] = 1;
        for (int j = 1; j < i; j++)
            T[i][j] = (1ll * (j + 1) * T[i - 1][j] % Mod + 1ll * (i - j) * T[i - 1][j - 1] % Mod) % Mod;
    }

    // Split p into cycles; collect their lengths in a[1..m] and fold the
    // inverse factorial of every length into ans (which starts as n!).
    ans = fac[n];
    for (int i = 1; i <= n; i++)
    {
        if (book[i]) continue;
        int x = i, len = 0;
        while (!book[x])
        {
            book[x] = 1;
            len++;
            x = p[x];
        }
        ans = 1ll * ans * inv[len] % Mod;
        a[++m] = len;
    }

    // F[i][t]: weighted count after the first i cycles with total t.
    F[0][0] = 1;
    for (int i = 1; i <= m; i++)
    {
        if (a[i] == 1)
        {
            // A fixed point contributes nothing: copy the previous row.
            for (int j = 0; j <= K; j++) F[i][j] = F[i - 1][j];
            continue;
        }
        // A cycle of length a[i] contributes j in [1, a[i]-1] with weight
        // a[i] * T[a[i]-1][j-1].
        for (int j = 1; j < a[i]; j++)
        {
            for (int k = 0; k + j <= K; k++)
            {
                F[i][j + k] = (F[i][j + k] + 1ll * F[i - 1][k] * a[i] % Mod * T[a[i] - 1][j - 1] % Mod) % Mod;
            }
        }
    }

    // The product is a long long; narrow it explicitly so it matches "%d".
    printf("%d\n", static_cast<int>(1ll * F[m][K] * ans % Mod));
    return 0;
}
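参考 Karry5307 的日报,有欧拉数的通项公式:$\left\langle{n\atop k}\right\rangle=\sum_{i=0}^{k+1}(-1)^i\binom{n+1}{i}(k+1-i)^n$。
设 $f_i=(-1)^i\binom{len}{i}$,$g_i=i^{\,len-1}$,则 $\left\langle{len-1\atop j-1}\right\rangle=\sum_{i=0}^{j}f_i\,g_{j-i}$。
可以利用卷积做到在 $O(len\log len)$ 的时间内求出一个环对应的所有欧拉数。
现在记第 $i$ 个环的多项式为 $G_i(x)=len_i\sum_{j}\left\langle{len_i-1\atop j-1}\right\rangle x^j$(长度为 $1$ 的环有 $G_i(x)=1$)。
利用分治 NTT 求出 $\prod_i G_i(x)$,答案即为 $[x^K]\prod_i G_i(x)\cdot\dfrac{n!}{\prod_i len_i!}$,总复杂度 $O(n\log^2 n)$: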
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std;
// N: capacity used for every fixed-size array of this program.
const int N = 550000;
// Mod: prime modulus applied to all arithmetic (also the NTT modulus).
const int Mod = 998244353;
// Reads the next decimal integer from stdin and returns it.
//
// Non-digit characters in front of the number are discarded; seeing a '-'
// among them makes the result negative. The value is accumulated in an int,
// so inputs beyond the int range are not supported.
//
// If stdin is exhausted before a digit shows up, 0 is returned. getchar()'s
// result is held in an int so EOF stays distinguishable; with a char the
// skip loop could never terminate once the input ran out.
inline int read()
{
    int value = 0, sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Global state shared by the functions below (zero-initialised).
int n, K;          // the two integers read first by main
int p[N];          // p[1..n]: the input permutation
int a[N], m;       // a[1..m]: lengths of the cycles of p
int book[N];       // book[x] != 0 once x has been visited during cycle extraction
int ans;           // n! times the inverse factorials of all cycle lengths
int fac[N + 5];    // fac[i] = i! mod Mod
int inv[N + 5];    // inv[i] = inverse of i! mod Mod
// Modular exponentiation by repeated squaring: returns x to the power p, modulo Mod.
// The exponent is scanned from its lowest bit upward; an exponent of 0 gives 1.
inline int ksc(int x,int p)
{
    int acc = 1;
    while (p != 0)
    {
        if (p & 1) acc = 1ll * acc * x % Mod;
        x = 1ll * x * x % Mod;
        p >>= 1;
    }
    return acc;
}
// Binomial coefficient "x choose y" modulo Mod, from the fac/inv tables.
// Returns 0 unless 0 <= y <= x.
inline int C(int x,int y)
{
    const bool inRange = (x >= 0) && (y >= 0) && (y <= x);
    if (!inRange) return 0;
    const long long numerator = fac[x];
    return numerator * inv[x - y] % Mod * inv[y] % Mod;
}
// Polynomial buffers and NTT tables.
vector<int> F;     // final product polynomial, filled by Merge(1, m, F) in main
vector<int> A;     // not referenced by the functions shown here (NTT's parameter A shadows it)
vector<int> B;     // scratch operand used by Merge at its leaves
int rev[N];        // index permutation prepared by Get_Rev for the current transform size
int r[24][2];      // r[k][1] / r[k][0]: forward / inverse roots for butterfly width 2^k
// One-time precomputation:
//   * fac[i] = i! and inv[i] = (i!)^{-1} modulo Mod for 0 <= i <= N;
//   * r[k][1] and r[k][0] for k = 23..0, each obtained by squaring the entry
//     one level above. Mod - 1 = 119 * 2^23 and 332748118 * 3 == 1 (mod Mod),
//     so r[23][1] = 3^119 and r[23][0] is its inverse.
inline void Init()
{
    fac[0] = 1;
    for (int i = 0; i < N; i++)
        fac[i + 1] = 1ll * fac[i] * (i + 1) % Mod;

    // Invert the largest factorial once, then walk downwards.
    inv[N] = ksc(fac[N], Mod - 2);
    for (int i = N - 1; i >= 0; i--)
        inv[i] = 1ll * inv[i + 1] * (i + 1) % Mod;

    r[23][1] = ksc(3, 119);
    r[23][0] = ksc(332748118, 119);
    for (int level = 22; level >= 0; level--)
    {
        r[level][0] = 1ll * r[level + 1][0] * r[level + 1][0] % Mod;
        r[level][1] = 1ll * r[level + 1][1] * r[level + 1][1] % Mod;
    }
}
// Fills rev[0..T-1] for a transform of size T: rev[i] is built from
// rev[i/2] shifted right by one, with the top bit (T/2) set when i is odd.
inline void Get_Rev(int T)
{
    const int topBit = T >> 1;
    for (int i = 0; i < T; i++)
    {
        int mirrored = rev[i >> 1] >> 1;
        if (i & 1) mirrored |= topBit;
        rev[i] = mirrored;
    }
}
// In-place iterative number-theoretic transform on s[0..T-1].
//
// s    : coefficient vector with at least T entries.
// T    : transform size; rev[] must already hold the table built by Get_Rev(T).
// type : 1 uses the roots r[k][1]; 0 uses r[k][0] and additionally multiplies
//        every entry by the modular inverse of T at the end.
inline void DFT(vector<int> &s,int T,int type)
{
    // Reorder according to rev[]; each pair is swapped exactly once.
    for (int i = 0; i < T; i++)
    {
        if (i < rev[i]) swap(s[i], s[rev[i]]);
    }

    int level = 1;
    for (int width = 2; width <= T; width <<= 1, level++)
    {
        const int half = width >> 1;
        const int root = r[level][type];
        for (int start = 0; start < T; start += width)
        {
            int w = 1;
            for (int k = 0; k < half; k++)
            {
                const int even = s[start + k];
                const int odd = 1ll * w * s[start + half + k] % Mod;
                s[start + k] = (even + odd) % Mod;
                int diff = even - odd;
                if (diff < 0) diff += Mod;
                s[start + half + k] = diff;
                w = 1ll * w * root % Mod;
            }
        }
    }

    if (type == 0)
    {
        const int scale = ksc(T, Mod - 2);
        for (int i = 0; i < T; i++) s[i] = 1ll * s[i] * scale % Mod;
    }
}
// Multiplies two polynomials modulo Mod.
//
// n, m : degrees of A and B (coefficients A[0..n], B[0..m]).
// A    : on return holds the product, exactly n+m+1 coefficients.
// B    : is overwritten with its transformed values and cut back to m+1
//        entries, so its contents are no longer the original polynomial.
inline void NTT(int n,int m,vector<int> &A,vector<int> &B)
{
    // Smallest power of two strictly greater than the product's degree.
    int T = 1;
    while (T <= n + m) T <<= 1;
    Get_Rev(T);

    A.resize(T);
    B.resize(T);
    // Clear anything beyond the stated degrees before transforming.
    for (int i = n + 1; i < T; i++) A[i] = 0;
    for (int i = m + 1; i < T; i++) B[i] = 0;

    DFT(A, T, 1);
    DFT(B, T, 1);
    for (int i = 0; i < T; i++)
    {
        A[i] = 1ll * A[i] * B[i] % Mod;
    }
    DFT(A, T, 0);

    // Drop the padding again.
    A.resize(n + m + 1);
    B.resize(m + 1);
}
void Merge(int l,int r,vector<int> &F)
{
if(l==r)
{
F.resize(a[l]+1), B.resize(a[l]+1);
for(ri int i=0;i<=a[l];i++)
{
int tp=(i&1)?(-1):(1);
int w1=(C(a[l],i)*tp+Mod)%Mod, w2=ksc(i,a[l]-1);
F[i]=w1, B[i]=w2;
}
NTT(F.size()-1,B.size()-1,F,B);
for(ri int i=0;i<=a[l];i++) F[i]=1ll*F[i]*a[l]%Mod;
F.erase(F.begin()+a[l]+1,F.end());
B.erase(B.begin(),B.end());
return;
}
int mid=(l+r)/2;
Merge(l,mid,F);
vector<int> G;
Merge(mid+1,r,G);
NTT(F.size()-1,G.size()-1,F,G);
}
// Entry point of the divide-and-conquer NTT solution.
//
// Input (stdin): n, K, then the permutation p[1..n].
// Output (stdout): one line with
//     [x^K] (product of the per-cycle polynomials) * n! / (len_1! * ... * len_m!)
// modulo Mod, where len_1..len_m are the cycle lengths of p.
//
// NOTE(review): the fixed-size arrays have N entries, so n must stay below N;
// that bound is assumed from the problem limits and not validated.
signed main()
{
    Init();
    n = read(), K = read();
    for (int i = 1; i <= n; i++) p[i] = read();

    // Split p into cycles; collect their lengths in a[1..m] and fold the
    // inverse factorial of every length into ans (which starts as n!).
    ans = fac[n];
    for (int i = 1; i <= n; i++)
    {
        if (book[i]) continue;
        int x = i, len = 0;
        while (!book[x])
        {
            book[x] = 1;
            x = p[x], len++;
        }
        ans = 1ll * ans * inv[len] % Mod;
        a[++m] = len;
    }

    // Merge(1, 0, ...) would recurse without end, so with no cycles the
    // product is the constant polynomial 1.
    if (m > 0) Merge(1, m, F);
    else F.assign(1, 1);

    // Coefficients outside the polynomial are zero; checking first keeps K
    // from indexing outside the vector.
    if (K < 0 || K >= static_cast<int>(F.size()))
    {
        printf("%d\n", 0);
        return 0;
    }

    // The product is a long long; narrow it explicitly so it matches "%d".
    printf("%d\n", static_cast<int>(1ll * F[K] * ans % Mod));
    return 0;
}