题解 P7511

· · 题解

$\text{Solution}:

观察不等式 π_{i}<π_{π_{i}'}<π_{π_{π_{i}'}'}<...,由于 π' 是一个排列,所以最后不等关系一定构成了若干个有向环。可以发现,钦定环上一条边不满足条件,并断掉这条边使破环成链,就构成了一个长度为 m 的序列。问题转化为:有若干个环,每个环对应一个序列,在这个序列中有 b_{i} 个位置满足 π_{i}<π_{i+1},且 \sum b_{i}=k。求满足条件的总排列个数。

先考虑单个长度为 m 的环对答案贡献了 jπ_{i}<π_{π_{i}'} 的方案数:钦定一个位置 i,满足 π_{i}=1,令 i 这个位置为断点,就是在长度为 m-1 的序列中有 j-1 个位置满足 π_{i}<π_{i+1}。而 π_{i}=1 的位置有 m 种不同的选择。把所有环的贡献相乘就是答案,得到:

\binom{n}{a_{1},a_{2}...a_{m}}\sum\limits_{\sum b_{i}=k}\prod\limits_{i} a_{i}\left<a_{i}-1\atop b_{i}-1\right>,\sum\limits_{i=1}^{m}a_{i}=n

发现 \sum\limits_{b_{i}=k} 是一个背包形式,dp 求欧拉数然后暴力背包就可以做到 O(n^2)。注意我们给出的定义,需要对 a_{i}=1 的情况特判。下面给出一个 O(n^2) 的代码,可以得到 60 分:

#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=2010, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
    return s*w;
}
int n,K,p[N],book[N],T[N][N],fac[N+5],inv[N+5],ans,a[N],m,F[N][N];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
signed main()
{
    fac[0]=1;
    for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
    inv[N]=ksc(fac[N],Mod-2);
    for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
    n=read(), K=read();
    for(ri int i=1;i<=n;i++) p[i]=read();
    T[0][0]=1;
    for(ri int i=1;i<=n;i++)
    {
        T[i][0]=1;
        for(ri int j=1;j<i;j++)
            T[i][j]=(1ll*(j+1)*T[i-1][j]%Mod+1ll*(i-j)*T[i-1][j-1]%Mod)%Mod;
    }
    ans=fac[n];
    for(ri int i=1;i<=n;i++)
    {
        if(book[i]) continue;
        int x=i,len=0;
        while(!book[x])
        {
            book[x]=1;
            len++;
            x=p[x];
        }
        ans=1ll*ans*inv[len]%Mod;
        a[++m]=len;
    }
    F[0][0]=1;
    for(ri int i=1;i<=m;i++)
    {
        if(a[i]==1)
        {
            for(ri int j=0;j<=K;j++) F[i][j]=F[i-1][j];
            continue;
        }
        for(ri int j=1;j<a[i];j++)
        {
            for(ri int k=0;k+j<=K;k++)
            {
                F[i][j+k]=(F[i][j+k]+1ll*F[i-1][k]*a[i]%Mod*T[a[i]-1][j-1]%Mod)%Mod;
            }
        }
    }
    printf("%d\n",1ll*F[m][K]*ans%Mod);
    return 0;
}

参考 Karry5307 的日报,有:

\left<n\atop m\right>=\sum\limits_{k=0}^{m}\binom{n+1}{k}(m+1-k)^{n}(-1)^{k}

g_{i}=\binom{n+1}{i}(-1)^{i}h_{i}=(i+1)^{n},有:

\left<n\atop i\right>=f_{i}=\sum\limits_{j=0}^{i}g_{j}h_{i-j}

可以利用卷积做到在 O(n\log n) 的时间内求出一行的欧拉数。

现在记 F_{i} 表示已经选了 i 个小于号的方案数,枚举到第 i 个环时 G_{j}=a_{i}\times \left<a_{i}-1\atop j-1\right>,那么有:

F_{i}=\sum\limits_{j=0}^{i}F_{j}\times G_{i-j}

利用分治 \text{NTT} 即可解决。

\text{Code}:
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=550000, Mod=998244353;
inline int read()
{
    int s=0, w=1; ri char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
    return s*w;
}
int n,K,p[N],a[N],book[N],m,ans;
int fac[N+5],inv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return 1ll*fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
vector<int> F;
vector<int> A,B;
int rev[N],r[24][2];
inline void Init()
{
    fac[0]=1;
    for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
    inv[N]=ksc(fac[N],Mod-2);
    for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
    r[23][1]=ksc(3,119), r[23][0]=ksc(332748118,119);
    for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
}
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(vector<int> &s,int T,int type)
{
    for(ri int i=0;i<T;i++) if(i<rev[i]) swap(s[i],s[rev[i]]);
    for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
    {
        int wn=r[cnt][type];
        for(ri int j=0,mid=(i>>1);j<T;j+=i)
        {
            for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
            {
                int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
                s[j+k]=(x+y)%Mod;
                s[j+mid+k]=x-y;
                if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
            }
        }
    }
    if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void NTT(int n,int m,vector<int> &A,vector<int> &B)
{
    int len=n+m;
    int T=1;
    while(T<=len) T<<=1;
    Get_Rev(T);
    A.resize(T), B.resize(T);
    for(ri int i=n+1;i<T;i++) A[i]=0;
    for(ri int i=m+1;i<T;i++) B[i]=0;
    DFT(A,T,1), DFT(B,T,1);
    for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
    DFT(A,T,0);
    for(ri int i=n+m+1;i<T;i++) A[i]=0;
    A.erase(A.begin()+n+m+1,A.end());
    B.erase(B.begin()+m+1,B.end());
}
void Merge(int l,int r,vector<int> &F)
{
    if(l==r)
    {
        F.resize(a[l]+1), B.resize(a[l]+1);
        for(ri int i=0;i<=a[l];i++)
        {
            int tp=(i&1)?(-1):(1);
            int w1=(C(a[l],i)*tp+Mod)%Mod, w2=ksc(i,a[l]-1);
            F[i]=w1, B[i]=w2;
        }
        NTT(F.size()-1,B.size()-1,F,B);
        for(ri int i=0;i<=a[l];i++) F[i]=1ll*F[i]*a[l]%Mod;
        F.erase(F.begin()+a[l]+1,F.end());
        B.erase(B.begin(),B.end());
        return;
    }
    int mid=(l+r)/2;
    Merge(l,mid,F);
    vector<int> G;
    Merge(mid+1,r,G);
    NTT(F.size()-1,G.size()-1,F,G);
}
signed main()
{
    Init();
    n=read(), K=read();
    for(ri int i=1;i<=n;i++) p[i]=read();
    ans=fac[n];
    for(ri int i=1;i<=n;i++)
    {
        if(book[i]) continue;
        int x=i,len=0;
        while(!book[x])
        {
            book[x]=1;
            x=p[x], len++;
        }
        ans=1ll*ans*inv[len]%Mod;
        a[++m]=len;
    }
    Merge(1,m,F);
    printf("%d\n",1ll*F[K]*ans%Mod);
    return 0;
}