P8321 题解

· · 题解

非常值得一做的 dp 模型。

思路

考虑式子中有 \min将两个序列合并后操作,记录属于 a 序列还是 b 序列。

此时,这个问题变成了一个序列上的匹配问题,这种问题有个套路:

考虑 dp_{i,j} 表示看到第 i 小的,目前 a 序列有 j 个需要匹配后面的。

对于第 i 位,我们很容易统计在第 i 位之前有多少 a 序列的,b 序列的,因此存 j 足以让我们推出前面有多少匹配好的和有多少的 b 序列的需要匹配后面一个数。

此时我们考虑第 i 小的数是匹配一个前面的或留下来匹配后面的,容易发现匹配后面的转移时系数就是这个数本身(由于排序了,\min 一定是他),匹配前面的转移时系数则是另一个序列剩余的需要匹配后面一个数的个数。

复杂度 O(n^2),空间小心点,不滚动可以过,但正式考试一定要尽可能优化!

这道题有两个套路值得学习(欢迎各位补充类似题):

  1. 序列 \min/\max 计数尝试排序后解决。
  2. 序列匹配计数问题的 dp 套路。
  3. 含固定辅助值dp 可以尝试寻找状态中变量关系来简化状态。

诶怎么说好两个来了三个。

代码



#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int qp(int a,int b){
    int ans=1;
    while(b){
        if(b&1) ans=(ans*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return ans;
}
int fac[10005],inv[10005];
void init(){
    fac[0]=1;
    for(int i=1;i<=10000;i++) fac[i]=fac[i-1]*i%mod;
    inv[10000]=qp(fac[10000],mod-2);
    for(int i=9999;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
}
struct node{
    int pos,tp;
}x[10005]; 
bool cmp(node A,node B){
    return A.pos<B.pos;
}
signed dp[10005][5005];
signed main(){
    init();
    int n;
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>x[i].pos;
        x[i].tp=1;
    }
    for(int i=n+1;i<=2*n;i++){
        cin>>x[i].pos;
        x[i].tp=2;
    }
    sort(x+1,x+2*n+1,cmp);
    dp[0][0]=1;
    int tot1=0,tot2=0;
    for(int i=1;i<=2*n;i++){
        for(int j=0;j<=n;j++){
            int gt=tot1-j;
            if(gt>tot2) continue;
            int lft1=j,lft2=tot2-gt;
            if(x[i].tp==1){
                //get a bigger one
                dp[i][j+1]=((int)dp[i][j+1]+(int)dp[i-1][j]*(int)x[i].pos)%mod; 
                //get a smaller one
                if(lft2) dp[i][j]=((int)dp[i][j]+(int)dp[i-1][j]*(int)lft2)%mod;
            }
            else{
                //get a bigger one
                dp[i][j]=((int)dp[i][j]+(int)dp[i-1][j]*(int)x[i].pos)%mod;
                //get a smaller one
                if(lft1) dp[i][j-1]=((int)dp[i][j-1]+(int)dp[i-1][j]*(int)lft1)%mod;
            }
        }
        if(x[i].tp==1) tot1++;
        if(x[i].tp==2) tot2++;
    }
    cout<<(int)dp[2*n][0]*inv[n]%mod;
    return 0;
}```