题解 CF995F 【Cowmpany Cowmpensation】

· · 题解

题意

给定一棵树,你要给每个点定一个 [1,D] 的权值,满足每个点权值不大于父亲。问有多少种方案,答案模 10^9+7

题解

由于我太菜不会插值,只好用容斥来做了。

我们考虑一个DP,用 f_{i,j} 表示在节点 i 的子树中,令节点 i 的权值为 j 的合法方案数。记 s_{i,j} = \sum_{k=1}^j f_{i,k} ,由于不同子树之间互相独立,很容易导出转移方程:

f_{i,j} = \prod_{fa_v = i} s[v][j]

但是,由于 d \leq 10^9 ,直接上这种做法是肯定不行的。观察发现,题目中 n \leq 3000 ,相比之下特别小,所以考虑转换一种思路。实际上,总共只会有至多 n 种本质不同的权值,所以我们可以将它们离散化,然后再计算 f_{i,j} 。此时, f_{i,j}j 的意义就变为节点 i 的权值为离散化之后的第 j 大的权值。

计算出 f_{i,j} 之后,下一步考虑如何计算答案。离散化之后,我们应该计算出恰好用了 i 种权值的方案,记为 g_i 。显然,直接统计会导致重复计数,例如样例一中的 (2,2,2) 就会被同时计入 g_1g_2 。所以,我们可以考虑容斥。

由定义可知, f_{1,i} 指最大权值为 i 的方案数,完全包含了 g_i 。所以, g_i 中我们要删去的不合法方案一定是包含权值 i 的一些方案,其权值种数小于 i 。考虑权值种数恰为 j 的方案数 g_j ,则它的贡献为 g_i \times C_{i-1}^{i-j}
其中,系数 C_{i-1}^{j-1} 表示除了权值 i 必须选中之外,余下的 i-1 种权值中选出 j-1 种权值,组成一种权值种数恰为 j 的权值选取方案的方案数。也就是,对于 g_j 所包含的所有方案数,它们都将在 f_{1,i} 中出现 C_{i-1}^{j-1} 次。

接着,我们就可以推出最终的容斥方程和答案:

g_i = f_{1,i} - \sum_{j=1}^{i-1} C_{i-1}^{i-j} \times g_j Ans = \sum_{i=1}^{\min(n,d)} C_{d}^{i} \times g_i

这里,由于 i 远小于 d ,所以我们可以直接暴力计算 C_{d}^{i} ,得出答案。

代码

#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3005;
const int mod=1e9+7;
int to[MN],nxt[MN],h[MN],cnt;
inline void ins(int s,int t){
    to[++cnt]=t;nxt[cnt]=h[s];h[s]=cnt;
}
int n,m,f[MN][MN],s[MN][MN],g[MN],C[MN][MN],inv[MN];
void dfs(int st){
    for(reg int i=h[st];i;i=nxt[i]){
        dfs(to[i]);
        for(reg int j=1;j<=n;j++)
            f[st][j]=1ll*f[st][j]*s[to[i]][j]%mod;
    }
    for(reg int i=1;i<=n;i++)
        s[st][i]=(s[st][i-1]+f[st][i])%mod;
}
int main(){
    scanf("%d%d",&n,&m);
    for(reg int i=2,x;i<=n;i++)scanf("%d",&x),ins(x,i);
    for(reg int i=1;i<=n;i++)for(reg int j=1;j<=n;j++)f[i][j]=1;
    dfs(1);C[0][0]=C[1][0]=C[1][1]=1;
    for(reg int i=2;i<=n;i++){
        C[i][0]=1;
        for(reg int j=1;j<=i;j++)
            C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
    }
    inv[0]=inv[1]=1;
    for(reg int i=2;i<=n;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    for(reg int i=2;i<=n;i++)inv[i]=1ll*inv[i-1]*inv[i]%mod;
    for(reg int i=1;i<=n;i++){
        g[i]=f[1][i];
        for(reg int j=1;j<i;j++)
            g[i]=(g[i]-1ll*C[i-1][i-j]*g[j]%mod+mod)%mod;
    }
    reg int Ans=0;
    for(reg int i=1;i<=min(n,m);i++){
        reg int Res=inv[i];
        for(reg int j=m-i+1;j<=m;j++)Res=1ll*Res*j%mod;
        Ans=(Ans+1ll*Res*g[i]%mod)%mod;
    }
    printf("%d\n",Ans);
    return 0;
}