P11292 【MX-S6-T4】「KDOI-11」彩灯晚会 - Solution

· · 题解

给定 n 个结点 m 条边的 DAG,每个结点可能有 k 种颜色,并且给定正整数 l。对于一个染色方案(总共 k^n 种染色方案),其价值定义为每种颜色长度为 l 的链的数量的平方之和,求所有染色方案的价值之和。

平方直接处理很难,但是我们有一种非常经典的做法是算两遍,简单来说我们把题目所求转化为,对于一个染色方案,求两条链,使得这两条链同色且长度均为 l 的方案数。

如果我们只是枚举钦定一条链,考虑它自己的颜色,和其它没限制的点的颜色,那么方案数应该是 k^{n - l + 1}。但是现在我们钦定了两条链,这两条链是可能有交叉的。我们假设它们有 t 个点是相交的,方案数就是 k^{n - 2l + t + 1}

相交的点不一定是连续的,直接做那就是设 f_{i,\,x,\,y,\,l_1,\,l_2} 代表考虑了拓扑序下的前 i 个点,第一条路径末尾在 x 长度为 l_1,第二条路径末尾在 y 长度为 l_2 的贡献。转移就是枚举下一个点是不是同一个点(交叉点),如果是那么乘上 k,不然无事发生。复杂度 \Theta(n^3l^2)

如果令 h_t 代表相交恰好 t 个点的方案数,则答案为 \sum k^{n - 2l + t - 1} h_t,注意到恰好不很好做,现在考虑钦定一些点必须是相交的(其它点相交不相交任意)然后容斥。

f_{x,\,l_1,\,l_2,\,t},其中 x 是最后一个相交的结点,t 是相交的点数,l_1,\,l_2 的意义同上,我们要枚举下一个相交的点,还要枚举两个道路到下一个相交点各自走了多少,复杂度 \Theta(n^2l^5),你发现好像还更差了!

不过消去 l 应该是比消去 n 好做的。

转移系数中没有和 l_1,\,l_2 具体相关的项,所以可以各自转移,复杂度 \Theta(n^2l^4)

具体怎么做呢,f 的定义如上,设 c_{l,\,u,\,v} 代表 u 走到 v 恰好 l 条边的方案数,有转移:

f_{u,\,l_1,\,l_2,\,t} \times c_{p_1,\,u,\,v} \times c_{p_2,\,u,\,v} \to f_{v,\,l_1 + p_1,\,l_2 + p_2,\,t + 1}

这个转移实现中需要分步进行保证我们只是进行了两次一维卷积。

我们令 F_t = \sum\limits_{u} f_{u,\,l,\,l,\,t},此时 F_t 的意义是钦定了两条长度为 l 的路径 t 个点是相交的方案数。

G_t 是恰好 t 个点相交的方案数,形式是至少 t 个相交转恰好 t 个相交,自然是二项式反演。根据二项式反演,我们有:

\begin{aligned} & F_i = \sum_{i \le j \le l} \binom{j}{i} G_j \\ & \Rightarrow\\ & G_i = \sum_{j = i}^l \binom{j}{i} (-1)^{j - i} F_j \end{aligned}

答案显然为 \sum\limits_{i = 0}^{l} k^{n - 2l + i + 1} G_i

G_i 展开,并且把 i 有关的项合并到一块。

\begin{aligned} & \sum_{i = 0}^{l} k^{n - 2l + i + 1} G_i \\ & = \left(\sum_{i = 0}^{l} k^{n - 2l + i + 1} \right) \left( \sum_{j = i}^l \binom{j}{i} (-1)^{j - i} F_j \right) \\ & = \left( \sum_{i = 0}^{l} k^{n - 2l + 1} \right) \left( \sum_{j = i}^l \binom{j}{i} (-1)^{j - i} F_j k^i \right) \\ & = k^{n - 2l + 1} \left( \sum_{j = 0}^l F_j \right) \left( \sum_{i = 0}^j \binom{j}{i} (-1)^{j - i} k^i \right) \\ & = k^{n - 2l + 1} \sum_{j = 0}^{l} F_j(k - 1)^j \end{aligned}

我们为什么要对于 f_{u,\,l,\,l,\,t} 记录 t 这一维,因为我们要根据 t 来计算容斥系数,但是现在计算出 F_t 之后,我们跟 t 有关的只剩下了 (k - 1)^t,那么我们显然可以直接把 (k - 1) 这个东西提前计算,乘入 f_{?,\,?,\,?,\,t} \to f_{?,\,?,\,?,\,t + 1} 的转移中。

注意到我们任何转移必定都是 t \to t + 1 的转移,因此 t 这一维现在变得完全没有必要,可以删去。

其实有一种理解方法是将被钦定的点看成 k - 1 种颜色,至于没被钦定又在链上,也可以看成一种新的颜色,因此其它仍然按 k 种颜色算。

总复杂度 \Theta(n^3l + n^2l^3),实际上可以把所有东西先 DFT 一遍,过程中只进行点乘来优化卷积,做到 \Theta(n^3l + n^2l^2) 的复杂度,但对于本题应该没有必要。

int n, k, L, M; vec<pii> e[305]; int q[355], hd = 1, tl = 0, rd[305], a[305], id[305], p;
ll c[305][305][22], cs[305][22], ct[305][22], f[305][22][22], g[305][22][22], ans = 0;
// c[l][u][v] u 到 v 长度为 l 的路径数量,cs[l][u] u 出发长度为 l 的路径数量,ct[l][v] 到 v 结束长度为 l 的路径数量
// f[u][l1][l2],t 这一维被省略
void solve() {
    scanf("%*d%d%d%d%d", &n, &k, &L, &M); --L;
    for (int i = 1, u, v, c; i <= M; i++) scanf("%d%d%d", &u, &v, &c), e[u].pb({ v, c }), ++rd[v];
    rep(i, 1, n) if (!rd[i]) q[++tl] = i;
    while (hd <= tl) {
        int u = q[hd++]; a[++p] = u; id[u] = p;
        for (auto [v, w] : e[u]) if (--rd[v] == 0) q[++tl] = v;
    }
    rep(u, 1, n) for (auto [v, w] : e[u]) add(c[id[u]][id[v]][1], w);
    rep(l, 2, L) rep(i, 1, n) rep(j, i + 1, n) if (c[i][j][l - 1]) rep(k, j + 1, n) if (c[j][k][1]) add(c[i][k][l], c[i][j][l - 1] * c[j][k][1]);
    rep(i, 1, n) cs[i][0] = ct[i][0] = 1;
    rep(l, 1, L) rep(i, 1, n) rep(j, 1, n) add(cs[i][l], c[i][j][l]), add(ct[j][l], c[i][j][l]);
    rep(i, 1, n) add(ans, cs[i][L]); ans = ans * ans % mod;
    rep(i, 1, n) rep(l1, 0, L) rep(l2, 0, L) g[i][l1][l2] = ct[i][l1] * ct[i][l2] % mod * (k - 1) % mod;
    rep(i, 1, n) {
        rep(j, i + 1, n) {
            rep(l1, 0, L) {
                rep(l2, 0, L) {
                    if (!g[i][l1][l2]) continue;
                    rep(p1, 0, L - l1) if (c[i][j][p1]) add(f[j][l1 + p1][l2], g[i][l1][l2] * c[i][j][p1]);
                }
            }
        }
        rep(j, i + 1, n) {
            rep(l1, 0, L) {
                rep(l2, 0, L) {
                    if (!f[j][l1][l2]) continue;
                    rep(p2, 0, L - l2) if (c[i][j][p2]) add(g[j][l1][l2 + p2], f[j][l1][l2] * (k - 1) % mod * c[i][j][p2]);
                    f[j][l1][l2] = 0;
                }
            }
        }
    }
    rep(i, 1, n) rep(l1, 0, L) rep(l2, 0, L) add(ans, g[i][l1][l2] * cs[i][L - l1] % mod * cs[i][L - l2]);
    int t = n - 2 * (L + 1) + 1;
    printf("%lld\n", ans * (t > 0 ? ksm(k, t) : ksm(ksm(k), -t)) % mod);
}