题解:P13757 【MX-X17-T6】Selection

· · 题解

设该二维数组为 a。钦定前 k 个数组大于等于后 n-k 个数组,最后乘上 \begin{pmatrix}n\\k \end{pmatrix} 即可。

发现前 k 个数组大于等于后 n-k 个数组的充要条件是对于任意 1m 之间的 j,满足 \min\limits_{i=1}^{k} a_{i,j} \max\limits_{i=k+1}^{n} a_{i,j} 大。

发现大于不是很好维护,于是考虑总方案数减去等于的情况。设数组 b=\min\limits_{i=1}^{k} a_{i,j}。钦定前 k 个数组中有 x 个数组为 b,后 n-k 个数组有 y 个数组也为 b,那么剩余的 n-x-y 个数组的每一维是相对独立的。枚举 \min\max 的分界 i,贡献为 (\sum\limits_{i=1}^{v}i^{k-x}(v-i+1)^{n-k-y})^m。然后再乘上 \begin{pmatrix}k\\x \end{pmatrix} \begin{pmatrix}n-k\\y \end{pmatrix} 以及转移系数 (-1)^{x+y+1}。设 f_{x,y}=\begin{pmatrix}k\\x\end{pmatrix}\begin{pmatrix}n-k\\y\end{pmatrix}(\sum\limits_{i=1}^{v}i^{k-x}(v-i+1)^{n-k-y})^m (-1)^{x+y+1},那么答案就是总方案数减去 \sum\limits_{x=1}^k\sum\limits_{y=1}^{n-k} f_{x,y}

考虑如何计算总方案数。注意,如果我们直接把 (\sum\limits_{i=1}^{v}i^{k}(v-i+1)^{n-k})^m 当做总方案数,那么可能会出现某一个维度 j 上存在某个 \min\max 的分界点 c,使得 \min\limits_{i=1}^{k} a_{i,j} \geq c \geq c-1 \geq \max\limits_{i=k+1}^{n} a_{i,j} ,从而被 cc-1 重复计算导致算重的情况。所以这里要使用点边容斥,把每一个这样的 c 减掉,这样是能够做到补充不漏的。总方案数即为 ((\sum\limits_{i=1}^{v}i^{k}(v-i+1)^{n-k})-(\sum\limits_{i=1}^{v-1}i^{k}(v-i)^{n-k}))^m

这样,我们做到了 \Theta(n^2\log m+nv)。瓶颈在于求 \sum\limits_{i=1}^vi^{x}(v-i+1)^{y} 的形式。设其为 g_{x,y}。有:

\end{aligned}

所以我们只需要求出 x=0y=0 的情况。以下只讨论 y=0,因为 g_{x,y}=g_{y,x}

发现 g_{x,0}=\sum\limits_{i=1}^vi^{x} 是一个求幂和的形式。设 y_{i}=\sum\limits_{j=1}^i j^{x},则 g_{x,0}=y_v。由拉格朗日插值定理有 y_v=\sum\limits_{i=1}^{n+1} y_i \sum\limits_{j=1,j\neq i}^{n+1}\frac{v-j}{i-j}。发现每一个 i\sum\limits_{j=1,j\neq i}^{n+1}\frac{v-j}{i-j} 都是确定了的,可以预处理求解;1n+1 也可以直接求解。所以可以 \Theta(n^2) 求解每一个 g_{x,0}。于是,我们就做到了 \Theta(n^2\log m)

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/rope>
#define rep(i,l,r) for(int i=(l),i##end=(r);i<=i##end;++i)
#define per(i,r,l) for(int i=(r),i##end=(l);i>=i##end;--i)
#define ll long long
#define int ll
#define double long double
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define popcnt __builtin_popcount
#define rbtree(way) tree<way,null_type,less<way>,rb_tree_tag,tree_order_statistics_node_update>
using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
class IO{
    char ibuf[1<<16],obuf[1<<16],*ipl=ibuf,*ipr=ibuf,*op=obuf;
    public:
    ~IO(){write();}
    inline void read(){ipl==ipr?ipr=(ipl=ibuf)+fread(ibuf,1,1<<15,stdin):0;}
    inline void write(){fwrite(obuf,1,op-obuf,stdout),op=obuf;}
    inline char getchar(){return (read(),ipl!=ipr)?*ipl++:EOF;}
    inline void putchar(char c){*op++=c;if(op-obuf>(1<<15)) write();}
    template<typename V>IO&operator>>(V&v){
        int s=1;char c=getchar();v=0;
        for(;!isdigit(c);c=getchar()) if(c=='-') s=-s;
        for(;isdigit(c);c=getchar()) v=(v<<1)+(v<<3)+(c^48);
        return v*=s,*this;
    }
    inline IO&operator<<(char c){return putchar(c),*this;}
    template<typename V>IO&operator<<(V v){
        if(!v) putchar('0');
        if(v<0) putchar('-'),v=-v;
        char stk[100],*top=stk;
        for(;v;v/=10) *++top=v%10+'0';
        while(top!=stk) putchar(*top--);
        return *this;
    }
}io;
#define cin io
#define cout io
#define IOS (ios::sync_with_stdio(0),cin.tie(0),cout.tie(0))
const int maxn=4000+10,maxm=1e6+10,mod=1e9+7,inf=INT_MAX;
inline int ksm(int x,int k,int mod=mod){
    int ans=1;
    for(x%=mod;k;k>>=1,x=x*x%mod) if(k&1) ans=ans*x%mod;
    return ans;
}

int T,n,m,k,v,ans,pi[maxn],spi[maxn],g[maxn];
int f[maxn][maxn],fct[maxn],ifct[maxn],inv[maxn],comb[maxn][maxn];

inline void makef(){
    rep(i,1,n+3){
        pi[i]=1,spi[i]=i,g[i]=ifct[i-1]*ifct[n+3-i]%mod,i+1&1?g[i]=mod-g[i]:0;
        rep(j,1,n+3) if(j!=i) (g[i]*=j+mod-v)%=mod;
    }
    rep(i,0,n+1){
        int sum=0;
        rep(j,1,n+3) (sum+=spi[j]*g[j])%=mod;
        f[0][i]=f[i][0]=sum;
        rep(j,1,n+3) (pi[j]*=j)%=mod,spi[j]=(spi[j-1]+pi[j])%mod;
    }
    rep(x,1,n) rep(y,1,n) f[x][y]=((v+1)*f[x-1][y]+mod-f[x-1][y+1])%mod;
}

signed submain(){
    cin>>n>>m>>k>>v;
    --v,makef(),ans=mod-f[k][n-k],++v,makef(),ans=ksm(ans+f[k][n-k],m);
    rep(i,1,k) rep(j,1,n-k){
        int nw=comb[i][k-i]*comb[j][n-k-j]%mod*ksm(f[k-i][n-k-j],m)%mod;
        i+j&1?(ans+=nw)%=mod:(ans+=mod-nw)%=mod;
    }
    cout<<ans*comb[k][n-k]%mod<<'\n';
    return 0;
}

signed main(){
    cin>>T,comb[0][0]=fct[0]=1;
    rep(i,1,4005) fct[i]=fct[i-1]*i%mod;
    ifct[4005]=ksm(fct[4005],mod-2);
    per(i,4005,1) ifct[i-1]=ifct[i]*i%mod,inv[i]=ifct[i]*fct[i-1]%mod;
    rep(x,0,4000) rep(y,0,4000) if(x||y){
        if(x) (comb[x][y]+=comb[x-1][y])%=mod;
        if(y) (comb[x][y]+=comb[x][y-1])%=mod;
    }
    rep(o,1,T) submain();
    return 0;
}/*
2
5 1 3 3
10 4 7 2

*/