题解:P14364 [CSP-S 2025] 员工招聘 / employ(民间数据)

· · 题解

好像是正赛场切过的最牛的题。

考虑 dp。先考虑设 f_{i,j} 表示前 i 天寄了 j 个人的方案数。但是发现我们要往后填的时候,我们不知道此时还有哪些数能用,比较倒闭。而且这个也很难塞到状态里。

上述做法做不了的本质是,我们前面决策选了后面的哪个 c_x 会对后面每个时刻能选的数产生影响。因此我们考虑一个经典 trick:贡献延后计算。具体地,我们可以考虑不在加入一个人时就计算它的方案数,而是留到后面“有用”时再来计算。由于每个时刻一个人 k 能不能成功,本质上只与是否 c_x > j 有关,因此我们自然想到在 j 增加到第一次 c_x = j 时去考虑其贡献。

我们加入一维 kf_{i,j,k} 表示前 i 天寄了 j 个人,当前有 k 个满足 c_x>j 的人已经被面试过。k 的本质就是我们提前钦定了多少个位置,现在还剩这么多要在后面再填回去。

考虑转移:

f_{i,j,k} \to f_{i+1,j,k+1} f_{i,j,k} \times \binom{k+1}{l} \times \binom{cnt_{j+1}}{l} \times l! \to f_{i+1,j+1,k+1-l} f_{i,j,k} \times (pre_j-(i-k)) \times \binom{k}{l} \times \binom{cnt_{j+1}}{l} \times l! \to f_{i+1,j+1,k-l}

最后我们只需要枚举有几个人寄了即可,答案为

\sum_{j=0}^{n-m}f_{n,j,n-pre_j} \times (n-pre_j)!

还有一个问题:转移要枚举 l,这个不是 O(n^4) 的吗?仔细想想就会发现由于 l 不可能超过 cnt_{j+1},而对于固定的 ikcnt_{j+1} 的和不超过 n,所以其实是 O(n^3) 的。

记得滚动数组。

赛后回忆随手写的丑陋的代码,细节与题解略有出入,仅供参考:

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define pii pair<int,int>
#define fir first
#define sec second
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define pb push_back
const int inf=0x3f3f3f3f3f3f3f3f;
const int mod=998244353;
int n,m,dp[2][510][510],fac[510],C[510][510],P[510][510],ans;
string s;
int c[510],rm[510]; 
signed main()
{
    cin>>n>>m>>s;
    s=" "+s;
    for(int i=1;i<=n;i++)
    {
        int x;
        cin>>x;
        c[x]++;
        for(int j=0;j<x;j++)rm[j]++;
    }
    fac[0]=1;
    for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
    for(int i=0;i<=n;i++)
    {
        C[i][0]=1;
        for(int j=1;j<=i;j++)C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
        for(int j=0;j<=i;j++)P[i][j]=C[i][j]*fac[j]%mod;
    }
    dp[0][0][0]=1;
    for(int i=1;i<=n;i++)
    {
        memset(dp[i&1],0,sizeof(dp[i&1]));
        bool ps=s[i]-'0';

        for(int j=0;j<i;j++)if(j<=n-m)
            for(int k=0;k<i;k++)
            {
                int tx=c[j+1];
                int tmp=dp[i&1^1][j][k];
                if(!tmp)continue;
                if(ps)(dp[i&1][j][k+1]+=tmp)%=mod;
                int pr=(n-rm[j])-(i-1-k);
                if(pr)for(int l=min(k,tx);~l;l--)(dp[i&1][j+1][k-l]+=tmp*pr%mod*C[k][l]%mod*P[tx][l]%mod)%=mod;
                if(!ps)for(int l=min(k+1,tx);~l;l--)(dp[i&1][j+1][k+1-l]+=tmp%mod*C[k+1][l]%mod*P[tx][l]%mod)%=mod;
            }
    }
    for(int i=0;i<=n-m;i++)(ans+=dp[n&1][i][rm[i]]*fac[rm[i]]%mod)%=mod;
    cout<<ans<<endl;
    return 0;
}