CF1805G 题解

· · 题解

CF1805G 题解

题目大意

定义一个 n\times n 矩阵是好的,当且仅当:

  • 矩阵的每一行都是一个 1\sim n 的排列。
  • 每一对竖直相邻的元素不同。

给定一个好的矩阵,求有多少好矩阵字典序比它小。

数据范围:n\le 2000

思路分析

考虑枚举第一个产生不同的位置 (i,j),那么后 n-i 行的答案就是 (P_n)^{n-i},其中 P_nn 阶错排数。

然后考虑这一行的后 n-j 个位置,观察可以发现每个可以被填的值有两种情况:在 P_{i-1,(j,n]} 中出现过或没出现过,出现过就相当于有一个位置不能填,否则相当于任何位置都可以填。

不妨设 f_{i,j} 表示 i 个元素,有 j 个元素有位置限制的方案数,考虑一个无限制元素的填法,容易得到转移:f_{i,j}=j\times f_{i-1,j-1}+(i-j)\times f_{i-1,j},边界条件是 f_{i,i}=P_i​。

注意 i=1 的情况需要特判。

那么接下来的问题是如何动态维护填 (i,j) 的过程以及后面有位置限制的元素个数,先考虑 (i,j) 的总方案数,相当于 P_{i,[j,n]}<P_{i,j} 的元素个数,树状数组即可。

注意到填 P_{i,j} 至多让有限制的元素个数减少 1,因此我们只关心 P_{i,j} 填的是不是有限制的元素。

先维护 P_{i,[j,n]} 中有限制的元素集合 K,那么相当于求 K<P_{i,j} 的元素个数,树状数组维护即可。

时间复杂度 \mathcal O(n^2\log n)

代码呈现

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAXN=2005,MOD=998244353;
ll f[MAXN][MAXN],pw[MAXN],ans=0;
int n,a[MAXN][MAXN];
bool lst[MAXN],cur[MAXN];
struct FenwickTree {
    int s,tr[MAXN];
    void init() { for(int x=1;x<=n;++x) tr[x]=x&-x; }
    void del(int x) { for(;x<=n;x+=x&-x) --tr[x]; }
    int qry(int x) { for(s=0;x;x&=x-1) s+=tr[x]; return s; }
}   rem,lim;
signed main() {
    scanf("%d",&n);
    for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) scanf("%d",&a[i][j]);
    f[0][0]=1;
    for(int i=1;i<=n;++i) for(int j=0;j<=i;++j) {
        if(i==j) {
            f[i][i]=(i-1)*(f[i-1][i-1]+(i>1?f[i-2][i-2]:0))%MOD;
        } else {
            f[i][j]=((j?j*f[i-1][j-1]:0)+(i-j)*f[i-1][j])%MOD;
        }
    }
    pw[0]=1;
    for(int i=1;i<=n;++i) pw[i]=pw[i-1]*f[n][n]%MOD;
    rem.init();
    for(int i=1;i<n;++i) {
        ans=(ans+rem.qry(a[1][i]-1)*f[n-i][0]%MOD*pw[n-1])%MOD;
        rem.del(a[1][i]);
    }
    for(int i=2;i<=n;++i) {
        rem.init(),lim.init();
        fill(lst+1,lst+n+1,1);
        fill(cur+1,cur+n+1,1);
        for(int j=1;j<n;++j) {
            if(cur[a[i-1][j]]) lim.del(a[i-1][j]);
            lst[a[i-1][j]]=0;
            if(j>1) {
                if(lst[a[i][j-1]]) lim.del(a[i][j-1]);
                cur[a[i][j-1]]=0;
            }
            int p=lim.qry(n),x=rem.qry(a[i][j]-1),y=lim.qry(a[i][j]-1);
            if(a[i-1][j]<a[i][j]&&cur[a[i-1][j]]) --x;
            ans=(ans+(x-y)*f[n-j][p]%MOD*pw[n-i])%MOD;
            if(p) ans=(ans+y*f[n-j][p-1]%MOD*pw[n-i])%MOD;
            rem.del(a[i][j]);
        }
    }
    printf("%lld\n",ans);
    return 0;
}