题解:P12819 [NERC 2021] Fancy Stack

· · 题解

简单题。

a 排序。记 t_i 为值 i 的出现次数。

直接统计 b 比较困难。 我们统计 a 有多少种排列方式能得到合法的 b。最终答案为 \frac{ans}{\prod t_i}

a 去重得到 a'。记 a' 长度为 cntpos_i 表示 a'_ia 中第一次出现的位置。

发现每层的限制与奇偶性有关,于是我们两层两层 dp。

f_{i,j} 表示考虑了前 2i 层,b_{2i}=a'_j 的方案数。考虑怎么向后转移。

由乘法原理得到转移:

f_{i+1,k} \larr (pos_j-2i)\cdot t_{a'_k} \cdot f_{i,j}

前缀和优化。复杂度 O(n^2)

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define eb emplace_back
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define rep(i,x,y) for(int i=(x);i<=(y);i++)
#define per(i,y,x) for(int i=(y);i>=(x);i--)
bool Memst;
namespace cyzz
{
    #define N 5005
    #define mod 998244353
    inline void Add(int &x,int y){x+=y;(x>=mod)&&(x-=mod);}
    int n,a[N],inv[N];
    int t[N],nxt[N];
    int cnt,pos[N];
    int f[N][N],pre[N][N];
    void init()
    {
        inv[0]=inv[1]=1;
        rep(i,2,n) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
        rep(i,2,n) inv[i]=1ll*inv[i-1]*inv[i]%mod;
    }
    void clr()
    {
        memset(t,0,n+1<<2);
        rep(i,0,n/2+1) rep(j,0,cnt+1) f[i][j]=pre[i][j]=0;
        cnt=0;
    }
    void solve()
    {
        scanf("%d",&n);init();
        rep(i,1,n) scanf("%d",&a[i]),t[a[i]]++;
        sort(a+1,a+n+1);
        int u=1;
        while(u<=n)
        {
            pos[++cnt]=u;
            while(a[u]==a[pos[cnt]]) u++;
        }
        rep(i,1,cnt)
            f[1][i]=1ll*t[a[pos[i]]]*(pos[i]-1)%mod;
        rep(i,1,n/2)
        {
            if(i>1)
            {
                rep(j,1,cnt)
                {
                    Add(pre[i][j],pre[i][j-1]);
                    f[i][j]=1ll*t[a[pos[j]]]*pre[i][j]%mod;
                }
            }
            if(i<n/2)
            {
                rep(j,1,cnt)
                    Add(pre[i+1][j+1],1ll*(pos[j]-2*i)%mod*f[i][j]%mod);
            }
        }
        int ans=f[n/2][cnt];
        rep(i,1,cnt) ans=1ll*ans*inv[t[a[pos[i]]]]%mod;
        printf("%d\n",ans);
        clr();
    }
    void MAIN()
    {
        int T;scanf("%d",&T);
        while(T--)  solve();
    }
}bool Memed;
int main()
{
    // freopen("in.in","r",stdin);
    // freopen("out.out","w",stdout);
    cyzz::MAIN();
    debug("%.2lfms %.2lfMB",1.0*clock()/CLOCKS_PER_SEC*1000,
        1.0*abs(&Memed-&Memst)/1024/1024);
}

数组定义的有些混乱,将就着看吧。