题解:AT_arc207_a [ARC207A] Affinity for Artifacts

· · 题解

题目可以转化为求合法的 1n 的排列 p 满足 \sum_{i=1}^n{a_{p_i}-\min(a_{p_i},i)}\le X 的方案数。

把上面的式子变换成 \sum_{i=1}^n{a_i}-X\le \sum_{i=1}^n{\min(a_{p_i},i)},那么式子左侧是定值,而右侧的最大值为 n\times (n-1)/2

这样我们就能发现当 \sum_{i=1}^n{a_i}-X<0 时,对于任意的排列 p,均能满足条件,即答案为 n!;如果 \sum_{i=1}^n{a_i}-X>n\times (n-1)/2,就没有合法方案了,因此答案为 0

所以问题就可以转化为求满足 \sum_{i=1}^n{\min(a_{p_i},i)}\ge\sum_{i=1}^n{a_i}-Xp 的方案数,而这个问题又可以看作是一个包含 1n 的序列(我们称为 2 序列)与 a 序列(我们称为 1 序列)之间两两匹配的问题。

对于这种问题,我们可以考虑将两个序列中的数放进一个序列 b 中,每个数都有个属性,代表它是原先哪个序列的,对于最小值的操作我们就可以将数列从大到小排序,那么一个匹配的最小值就是两个数中靠后的那个。

这样的一个序列就很好处理了,考虑 dp。设当前考虑到前 i 个数,之中有 L1 序列的点,有 R2 序列的点,有 j1 序列点未被匹配,则有 R-L+j2 序列的点未匹配,当前匹配而产生的代价为 s。则有 dp_{i,j,s} 表示在上面的条件下的方案数。

接下来考虑状态转移。分两种情况:

  1. i 个点是 1 序列的。有 dp_{i,j,s}=dp_{i-1,j-1,s}+dp_{i-1,j,s-b_i}\times (R-L+j+1),代表当前点是否匹配。
  2. i 个点是 2 序列的。有 dp_{i,j,s}=dp_{i-1,j,s}+dp_{i-1,j+1,s-b_i}\times (j+1)

最后的答案就是 \sum_{i=\sum_{j=1}^n{a_j}-X}^{n\times (n-1)/2}{dp_{n,0,i}},初始化 dp_{0,0,0}=1,时间复杂度是 O(n^4) 级别的。

#include<bits/stdc++.h>
using namespace std;
const long long mod=998244353;
long long n,m,a[110],idx;
long long dp[210][210][5010],ans;
struct node{
    long long num,op;
}b[210];
bool cmp(node x,node y){
    return x.num>y.num;
}
int main(){
    scanf("%lld%lld",&n,&m);
    long long sum=0;
    for(int i=1;i<=n;i++)
        scanf("%lld",&a[i]),sum+=a[i],b[++idx].num=a[i],b[idx].op=1;
    sum=sum-m;
    if(sum<0){
        ans=1;
        for(int i=1;i<=n;i++)
            ans*=i,ans%=mod;
        printf("%lld",ans);
        return 0;
    }
    if(sum>n*(n-1)/2){
        printf("0");
        return 0;
    }
    for(int i=1;i<=n;i++)b[++idx].num=i-1,b[idx].op=2;
    sort(b+1,b+2*n+1,cmp);
    long long L=0,R=0;
    dp[0][0][0]=1;
    for(int i=1;i<=2*n;i++){
        if(b[i].op==1)L++;
        else R++;
        for(int j=0;j<=L;j++){
            long long k=R-L+j;
            for(int s=0;s<=n*(n-1)/2;s++){
                if(b[i].op==1){
                    if(j>0)dp[i][j][s]+=dp[i-1][j-1][s],dp[i][j][s]%=mod;
                    if(s>=b[i].num)dp[i][j][s]+=dp[i-1][j][s-b[i].num]*(k+1),dp[i][j][s]%=mod;
                }
                else{
                    dp[i][j][s]+=dp[i-1][j][s],dp[i][j][s]%=mod;
                    if(s>=b[i].num)dp[i][j][s]+=dp[i-1][j+1][s-b[i].num]*(j+1),dp[i][j][s]%=mod;
                }
            }
        }
    }
    for(int i=sum;i<=n*(n-1)/2;i++){
        ans+=dp[2*n][0][i],ans%=mod;
    }
    printf("%lld",ans);
    return 0;
}