题解:P12445 [COTS 2025] 数好图 / Promet

· · 题解

注意到若把整张 DAG 分成三类点,一类是符合题目条件的点,一类是从 1 号点可以到达的点,一类是从 1 号点不可到达的点。那么合法点连边方法有 ① -> ②,③ -> ①,③ -> ②。他们之间相互独立,也就是不同类之间的连边合法性只与类的种类有关,故可以将他们分开考虑,再进行合并,以降低问题复杂度。

考虑全为 ① 类点,即只 k = n 的情况怎么做。容易地设计 O(n^3) 的 dp:f(i,j) 表示前 i 个点,有 j 个不同分支(即有 j 个出度为 0 的点),DAG 的方案数。枚举 k 表示前面连接了多少出度为 0 的点,f(i,j) = \sum f(i-1,j+k-1) \times C_{j+k-1}^{k} \times 2^{(i-1)-(j+k-1)}我会FFT但是难以优化。

所以难以想到用容斥原理的思路设计状态。相当于维护每个节点入度出度不为 0,保证入度不为 0 是好做的。我们主要要维护出度为 0 的点,那么设 f(i,j) 为前 i 个点,钦定 j 个点出度为 0,那么可以转移 f(i,j) = f(i-1,j) \times (2^{i-1-j}-1) + f(i-1,j-1) \times (2^{i-j}-1) ,然后通过容斥公式用 f(i) 计算出使用 i 个点,没有出度为 0 的点的 DAG 的方案数。时间复杂度 O(n^2)

若加入 ② 类点,就相当于在 ① 类点有 i 个的基础上,将方案数乘上 j 个 ② 类点自己连边和与 ① 类点连边的方案数,得到符合条件的点数为 i,总点数为 i+j 的答案。这直接 dp 就好了。加 ③ 类点同理。具体写法看代码。

这个故事告诉我们,对计数 dp 的优化束手无策的时候,也许可以试试容斥。

#include <bits/stdc++.h>
using namespace std;
int n, p;
int fac[2002], inv[2002], invfac[2002];
void init(){
    fac[0] = 1, inv[1] = 1, invfac[0] = 1;
    for(int i=1;i<=2000;i++) fac[i] = 1ll * fac[i-1] * i % p;
    for(int i=2;i<=2000;i++) inv[i] = (1ll * (-p/i) * inv[p%i] % p + p) % p;
    for(int i=1;i<=2000;i++) invfac[i] = 1ll * invfac[i-1] * inv[i] % p;
}
int C(int x,int y){
    if(x < y || x < 0 || y < 0) return 0;
    return 1ll * fac[x] * invfac[y] % p * invfac[x-y] % p;
}
int ksm(int s,int cnt){
    int ret = 1;
    while(cnt){
        if(cnt&1) ret = 1ll * ret * s % p;
        s = 1ll * s * s % p;
        cnt >>= 1;
    }
    return ret;
}

int f[2002][2002], g[2002][2002]; // 前 i 个数,钦定 j 个没有出度
int S[2002]; 

int dp1[2002][2002], dg1[2002][2002], dp2[2002][2002];
int Ans[2002];

int main(){
    scanf("%d%d",&n,&p);
    init();

    f[0][0] = 1, g[0][0] = 1;
    f[1][0] = 1, f[1][1] = 1, g[1][0] = 1;
    for(int i=2;i<=n;i++) {
        for(int j=0;j<=i;j++) {
            if(i != j) f[i][j] = 1ll * f[i-1][j] * (ksm(2, i-j-1) - 1) % p;
            if(i != j) g[i][j] = 1ll * f[i-1][j] * (ksm(2, i-j-1) - 1) % p;
            if(j > 0) f[i][j] = (1ll * f[i][j] + 1ll * f[i-1][j-1] * (ksm(2, i-j) - 1) % p) % p;
        }
    }
    for(int i=0;i<=n;i++) {
        for(int j=0;j<=i;j++){
            if(j & 1) S[i] = ((1ll * S[i] - g[i][j]) % p + p) % p;
            else S[i] = (1ll * S[i] + g[i][j]) % p;
        }
    }

    dp1[0][0] = 1;
    for(int i=1;i<=n;i++){
        for(int j=0;j<=i;j++){
            dp1[i][j] = dp1[i-1][j];
            dg1[i][j] = dp1[i-1][j];
            if(i != 1 && i != n) dp1[i][j] = (dp1[i][j] + 1ll * dp1[i-1][j-1] * (ksm(2, i-1)-1) % p) % p;
        }
    }

    dp2[0][0] = 1;
    for(int i=1;i<=n;i++){
        for(int j=0;j<=i;j++){
            dp2[i][j] = 1ll * dp2[i-1][j] * ksm(2, j) % p;
            if(i != 1 && i != n && j != 0) dp2[i][j] = (dp2[i][j] + 1ll * dp2[i-1][j-1] * ksm(2, j-1) % p) % p;
        }
    }

    //S[i] * dg1[k][k-i] * dp2[n][n-k]
    for(int i=0;i<=n;i++){
        for(int k=i;k<=n;k++){
            Ans[i] = (Ans[i] + 1ll * S[i] * dg1[k][k-i] % p * dp2[n][n-k] % p) % p;
        }
    }
    Ans[0] = Ans[2];

    for(int i=0;i<=n;i++) printf("%d ",Ans[i]);

    return 0;
}