P12486 题解

· · 题解

学到了一个比较好的做法,遂记录一下。

问题转化

本题要求 \sum_p \prod_{i=1}^n (\min_{j=1}^m p_{j,i}),考虑其组合意义:对于每个由 m 个排列组成的排列组 p,为每一列 i 分配一个数 val_i 满足 \forall 1\leq j \leq m,val_i \leq p_{j,i},求所有排列组分配 val 的方案数之和。

p 计算 val 的方案需要容斥,较为复杂,考虑对每组 val,计算其对于多少排列组 p 是合法的。

考虑从小到大枚举 i,确定 val=i 的列并填入每个排列值为 i 的位置。注意到排列的 i 只能填入 val \leq i 的列,所以每次先插入列,恰好可以在插入后计算填入排列的 i 的方案,并且这两部分几乎是独立的。

特殊性质 q=0

考虑根据上述转化设计 dp,设 f_{i,j} 表示考虑到数值 i,已经有 j 列的 val 被确定的所有方案,其填完所有排列 \leq i 的位置的方案数之和。

枚举 i,转移分为两部分:

最终答案为 f_{n,n}。时间复杂度 O(n^3)

满分做法

确定值的存在,影响了对方案数的计算。对于 f_{i,j},其代表的所有方案的填法数不一定相同了。

下面称存在确定值的行为特殊行,列为特殊列,其余为普通行,列。

注意到 q\leq 10,说明存在确定值的行和列都不多,考虑状压。设 f_{i,j,S} 表示考虑到 i,已经有 j 个普通列和集合 S 中的特殊列的 val 被确定了,填完所有排列 \leq i 的位置的方案数之和。

转移依然可以分成两部分:

综上,总复杂度是 O(n^22^q(n+q)) 的。

Code

代码写的比较丑 QAQ

#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
#define mod 998244353
#define N 60
#define pii pair<int,int>
#define fi first
#define sc second
#define mp make_pair
#define int long long
#define il inline
int n,m,k,h,q,a[N],b[N],cnt[1<<10],f[N][N][1<<10],c[N][N];
int s[N],t[N],v[N],d[N]; pii p[N];
il void madd(int &x,int y){
    x=(x+y>=mod)?(x+y-mod):(x+y); return ;
}
il int qpow(int a,int b){
    int res=1;
    while(b){if(b&1) res=res*a%mod;a=a*a%mod,b>>=1;}
    return res;
}
signed main(){
    // freopen("a.in","r",stdin);
    // freopen("a.out","w",stdout);
    scanf("%lld%lld%lld",&n,&m,&q);
    for(int i=0;i<=n;++i) c[i][0]=1;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=i;++j) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
    for(int i=1;i<=q;++i){
        scanf("%lld%lld%lld",&p[i].fi,&p[i].sc,&v[i]);
        a[i]=p[i].sc,b[i]=p[i].fi,++s[v[i]];
    }
    sort(a+1,a+q+1),k=unique(a+1,a+q+1)-a-1;
    for(int i=1;i<=q;++i){
        p[i].sc=lower_bound(a+1,a+k+1,p[i].sc)-a-1;
        t[v[i]]|=(1<<p[i].sc);
    }
    sort(b+1,b+q+1),h=unique(b+1,b+q+1)-b-1;
    for(int i=1;i<=q;++i)
        p[i].fi=lower_bound(b+1,b+h+1,p[i].fi)-b;
    for(int S=1;S<(1<<k);++S) cnt[S]=cnt[S>>1]+(S&1);
    f[0][0][0]=1;
    for(int i=1;i<=n;++i){
        for(int j=0;j<=n-k;++j)
            for(int S=0;S<(1<<k);++S) if(f[i-1][j][S])
                for(int l=0;j+l<=n-k;++l)
                    f[i][j+l][S]=(f[i][j+l][S]+f[i-1][j][S]*c[n-k-j][l])%mod;
        for(int j=0;j<=n-k;++j)
            for(int l=0;l<k;++l)
                for(int S=0;S<(1<<k);++S)
                    if(S&(1<<l)) madd(f[i][j][S],f[i][j][S^(1<<l)]);
        for(int j=0;j<=n-k;++j)
            for(int S=0;S<(1<<k);++S){
                if((S&t[i])!=t[i]||j+cnt[S]<i) f[i][j][S]=0;
                if(!f[i][j][S]) continue;
                memset(d,0,(h+1)<<3);
                for(int l=1;l<=q;++l) if(S&(1<<p[l].sc)){
                    if(v[l]==i) d[p[l].fi]=-inf; if(v[l]>i) ++d[p[l].fi];
                }
                f[i][j][S]=f[i][j][S]*qpow(j+cnt[S]-i+1,m-h)%mod;
                for(int l=1;l<=h;++l) if(d[l]>=0)
                    f[i][j][S]=f[i][j][S]*(j+cnt[S]-i+1-d[l])%mod;
            }
    }
    printf("%lld\n",f[n][n-k][(1<<k)-1]);
    return 0;
}