CF1085G Beautiful Matrix 题解

· · 题解

好题。

首先用上“一般求小于某个东西的方案数”的策略:枚举前面那几行相等,然后枚举这一行不等。

我们运用它。我们可以钦定前 i 行相同,注意 i 的范围是 0n-1。如果我们能求出 f_i 表示使得第 i+1 行合法的方案数,那么答案就是:\sum_{i=0}^{n-1}f_iD_n^{n-i-1}。其中 D_n 表示长度为 n 的错排方案数。

接下来的问题是:如何求出 f_i

首先我们发现,第 i+1 行有如下限制:字典序小于原来的第 i+1 行,与原来的第 i 行没有重复的地方。

我们发现第 1 行只有第一个限制。所以我们可以通过计算排列的排名,得到 f_0。运用树状数组可以做到 O(n\log n),详情请见 P5367。

然后我们再次运用一开始提到的策略。我们钦定前 j-1 个数相同(接下来的讲解均在“前 i 行相同“的条件下进行)。那么,我们枚举第 j 个数是多少,然后看它是否合法。

你或许注意到了它是 O(n^3) 的。先别急!我们慢慢细说。

然后我们会剩下来后面的数。注意到后面的数,有些是相互重复的,有些是不重复的。

假设我们前面选择了 x 个数(也就是 j),有 y 组重复的数。请注意,这里的重复都指:存在 k,l 满足 a_{i,k}=a_{i+1,l}

那么,后面的方案数为 sz_{n-x,n-2x+y},表示后面有 n-x 的空位,其中有 n-2x+y 组重复数。

问题来了,怎么求 sz_{n,m} 呢?

我们发现,这可以容斥,那么 sz_{n,m}=\sum_{i=0}^{m}(-1)^i\binom{m}{i}(n-i)!,相信大家都能理解,不理解也没关系。

现在,我们成功做到了 O(n^3)

我们需要优化。

先给出我们的暴力代码,以便后面叙述:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
    inv[1]=1;
    jc[0]=ijc[0]=1;
    for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    for(int i=1;i<=n;i++){
        jc[i]=jc[i-1]*i%mod;
        ijc[i]=ijc[i-1]*inv[i]%mod;
    }
    D[1]=0,D[2]=1;
    for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
    return;
}
int C(int n,int m){
    if(n<m)return 0;
    int fz=jc[n];
    int fm=ijc[m]*ijc[n-m]%mod;
    return fz*fm%mod;
}
int n;
struct BIT{
    int t[N];
    int lowbit(int x){return x&-x;}
    void add(int x,int y){
        for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
        return;
    }
    int query(int x){
        int res=0;
        for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
        return res;
    }
}T;
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
    int ans=0;
    for(int i=n;i>=1;i--){
        ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
        T.add(a[1][i],1);
    }
    return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
    int z=1;
    while(b){
        if(b&1)z=z*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return z;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // freopen("B.in","r",stdin);
    // freopen("B.out","w",stdout);
    init(N-10);
    cin>>n;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            cin>>a[i][j];
        }
    }
//这里的 sz 与前文的不同。这里的意义就是,当我们前面枚举的状态是(x,y)时的方案数。后面会修改的,这里请不要重视!!!
    for(int x=1;x<=n;x++){
        for(int y=0;y<=x;y++){
            if(n-2*x+y<0)continue;
            for(int i=0;i<=n-2*x+y;i++){
                sz[x][y]+=qpow(i)*C(n-2*x+y,i)*jc[n-x-i]%mod;
                sz[x][y]=(sz[x][y]%mod+mod)%mod;
            }
        }
    }
    f[0]=rk();
    for(int i=1;i<n;i++){
        int x=0,y=0;
        unordered_map<int,int>M1,M2;
        for(int j=1;j<=n;j++){
            x++;
            for(int k=1;k<a[i+1][j];k++){
                if(M2[k])continue;
                if(a[i][j]==k)continue;
                int xx=x,yy=y;
                if(M1[k])yy++;
                if(M2[a[i][j]])yy++;
                f[i]=(f[i]+sz[xx][yy])%mod;
            }
            M1[a[i][j]]++;
            M2[a[i+1][j]]++;
            y+=M2[a[i][j]];
            y+=M1[a[i+1][j]];
        }
    }
    for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
    cout<<ans<<"\n";
    return 0;
}
/*

*/

我们先修改后面对 f_i 的处理。

我们发现,第 97 行(可以粘贴方便查看)可以单独提出来,提到枚举的外面。

那么,我们发现,yy 无非就两种状态:yy+1

我们可以统计几种情况需要加一,然后用总的个数减掉,就得出了不需要加一的方案。

我的实现有点丑陋,大家真的仅供参考,仅供参考……

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
    inv[1]=1;
    jc[0]=ijc[0]=1;
    for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    for(int i=1;i<=n;i++){
        jc[i]=jc[i-1]*i%mod;
        ijc[i]=ijc[i-1]*inv[i]%mod;
    }
    D[1]=0,D[2]=1;
    for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
    return;
}
int C(int n,int m){
    if(n<m)return 0;
    int fz=jc[n];
    int fm=ijc[m]*ijc[n-m]%mod;
    return fz*fm%mod;
}
int n;
struct BIT{
    int t[N];
    void clear(){
        for(int i=1;i<=n;i++)t[i]=0;
        return;
    }
    int lowbit(int x){return x&-x;}
    void add(int x,int y){
        for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
        return;
    }
    int query(int x){
        int res=0;
        for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
        return res;
    }
}T,T1,T2,T3;//T1是第i行的出现个数,T2是第i+1行的,T3就是两行都出现的
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
    int ans=0;
    for(int i=n;i>=1;i--){
        ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
        T.add(a[1][i],1);
    }
    return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
    int z=1;
    while(b){
        if(b&1)z=z*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return z;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // freopen("B.in","r",stdin);
    // freopen("B.out","w",stdout);
    init(N-10);
    cin>>n;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            cin>>a[i][j];
        }
    }
//这里改回了正确的sz(前面描述的表示状态)
    for(int x=0;x<=n;x++){
        for(int y=0;y<=n;y++){
            for(int i=0;i<=y;i++){
                sz[x][y]+=qpow(i)*C(y,i)*jc[x-i]%mod;
                sz[x][y]=(sz[x][y]%mod+mod)%mod;
            }
        }
    }
    f[0]=rk();
//打注释的代码块就是之前写的暴力
    for(int i=1;i<n;i++){
        int x=0,y=0;
        T1.clear();T2.clear();T3.clear();
        unordered_map<int,int>M1,M2;
        for(int j=1;j<=n;j++){
            x++;
            if(M2[a[i][j]])y++;
            int total=(a[i+1][j]-1)-T2.query(a[i+1][j]-1)-(a[i][j]<a[i+1][j]&&!M2[a[i][j]]);
            int toty_add=T1.query(a[i+1][j]-1)-T3.query(a[i+1][j]-1);//加上1,减去重复,就是+1的次数
            int toty=total-toty_add;//不加一的次数
            f[i]=(f[i]+sz[n-x][n-2*x+y]*toty%mod+sz[n-x][n-2*x+y+1]*toty_add%mod)%mod;
            // if(i==2&&j==2)cerr<<total<<"\n";
            // for(int k=1;k<a[i+1][j];k++){
            //  if(M2[k])continue;
            //  if(a[i][j]==k)continue;
            //  int xx=x,yy=y;
            //  if(M1[k])yy++;
            //  if(M2[a[i][j]])yy++;
            //  f[i]=(f[i]+sz[xx][yy])%mod;
            // }
            T1.add(a[i][j],1);
            M1[a[i][j]]++;
            T2.add(a[i+1][j],1);
            M2[a[i+1][j]]++;
            // y+=M2[a[i][j]];
            y+=M1[a[i+1][j]];
            if(M1[a[i][j]]&&M2[a[i][j]])T3.add(a[i][j],1);
            if(M1[a[i+1][j]]&&M2[a[i+1][j]])T3.add(a[i+1][j],1);
        }
    }
    for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
    cout<<ans<<"\n";
    return 0;
}
/*

*/

我们就差最后一步了!

我们发现,容斥是没前途的,考虑朴素转移(真的有人是先想到容斥,才想到朴素转移吗……)。

有一个经典的套路:这种关于序列大小的 DP,可以通过“加一个新的数”来进行转移。

我们插入第 i 个数,有两种情况。

第一种:插入数之后,多一个重复的数对:我们在 m 个数中选择一个,故方案为 m\times sz_{n-1,m-1}

第二种:插入数之后,没有变化:我们在 n-m 个没限制的数中选一个,故方案为 (n-m)sz_{n-1,m}

所以,sz_{n,m}=m\times sz_{n-1,m-1}+(n-m)sz_{n-1,m}。时间复杂度 O(n^2\log n),瓶颈在于对于 f 的计算。

代码有点丑:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
    inv[1]=1;
    jc[0]=ijc[0]=1;
    for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    for(int i=1;i<=n;i++){
        jc[i]=jc[i-1]*i%mod;
        ijc[i]=ijc[i-1]*inv[i]%mod;
    }
    D[0]=1;
    D[1]=0,D[2]=1;
    for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
    return;
}
int C(int n,int m){
    if(n<m)return 0;
    int fz=jc[n];
    int fm=ijc[m]*ijc[n-m]%mod;
    return fz*fm%mod;
}
int n;
struct BIT{
    int t[N];
    void clear(){
        for(int i=1;i<=n;i++)t[i]=0;
        return;
    }
    int lowbit(int x){return x&-x;}
    void add(int x,int y){
        for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
        return;
    }
    int query(int x){
        int res=0;
        for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
        return res;
    }
}T,T1,T2,T3;
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
    int ans=0;
    for(int i=n;i>=1;i--){
        ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
        T.add(a[1][i],1);
    }
    return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
    int z=1;
    while(b){
        if(b&1)z=z*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return z;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // freopen("B.in","r",stdin);
    // freopen("B.out","w",stdout);
    init(N-10);
    cin>>n;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            cin>>a[i][j];
        }
    }
    for(int x=0;x<=n;x++){
        sz[x][x]=D[x];
        for(int y=0;y<x;y++){
            sz[x][y]=y*(y==0?0:sz[x-1][y-1])+(x-y)*sz[x-1][y];
            sz[x][y]%=mod;
        }
    }
    f[0]=rk();
    for(int i=1;i<n;i++){
        int x=0,y=0;
        T1.clear();T2.clear();T3.clear();
        unordered_map<int,int>M1,M2;
        for(int j=1;j<=n;j++){
            x++;
            if(M2[a[i][j]])y++;
            int total=(a[i+1][j]-1)-T2.query(a[i+1][j]-1)-(a[i][j]<a[i+1][j]&&!M2[a[i][j]]);
            int toty_add=T1.query(a[i+1][j]-1)-T3.query(a[i+1][j]-1);//加上1,减去重复,就是+1的次数
            int toty=total-toty_add;//不加一的次数
            f[i]=(f[i]+sz[n-x][n-2*x+y]*toty%mod+sz[n-x][n-2*x+y+1]*toty_add%mod)%mod;
            // if(i==2&&j==2)cerr<<total<<"\n";
            // for(int k=1;k<a[i+1][j];k++){
            //  if(M2[k])continue;
            //  if(a[i][j]==k)continue;
            //  int xx=x,yy=y;
            //  if(M1[k])yy++;
            //  if(M2[a[i][j]])yy++;
            //  f[i]=(f[i]+sz[xx][yy])%mod;
            // }
            T1.add(a[i][j],1);
            M1[a[i][j]]++;
            T2.add(a[i+1][j],1);
            M2[a[i+1][j]]++;
            // y+=M2[a[i][j]];
            y+=M1[a[i+1][j]];
            if(M1[a[i][j]]&&M2[a[i][j]])T3.add(a[i][j],1);
            if(M1[a[i+1][j]]&&M2[a[i+1][j]])T3.add(a[i+1][j],1);
        }
    }
    for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
    cout<<ans<<"\n";
    return 0;
}
/*

*/