P11292 【MX-S6-T4】「KDOI-11」彩灯晚会 - Solution
给定
n 个结点m 条边的 DAG,每个结点可能有k 种颜色,并且给定正整数l 。对于一个染色方案(总共k^n 种染色方案),其价值定义为每种颜色长度为l 的链的数量的平方之和,求所有染色方案的价值之和。
平方直接处理很难,但是我们有一种非常经典的做法是算两遍,简单来说我们把题目所求转化为,对于一个染色方案,求两条链,使得这两条链同色且长度均为
如果我们只是枚举钦定一条链,考虑它自己的颜色,和其它没限制的点的颜色,那么方案数应该是
相交的点不一定是连续的,直接做那就是设
如果令
设
不过消去
转移系数中没有和
具体怎么做呢,
这个转移实现中需要分步进行保证我们只是进行了两次一维卷积。
我们令
令
答案显然为
把
我们为什么要对于
注意到我们任何转移必定都是
其实有一种理解方法是将被钦定的点看成
k - 1 种颜色,至于没被钦定又在链上,也可以看成一种新的颜色,因此其它仍然按k 种颜色算。
总复杂度
int n, k, L, M; vec<pii> e[305]; int q[355], hd = 1, tl = 0, rd[305], a[305], id[305], p;
ll c[305][305][22], cs[305][22], ct[305][22], f[305][22][22], g[305][22][22], ans = 0;
// c[l][u][v] u 到 v 长度为 l 的路径数量,cs[l][u] u 出发长度为 l 的路径数量,ct[l][v] 到 v 结束长度为 l 的路径数量
// f[u][l1][l2],t 这一维被省略
void solve() {
scanf("%*d%d%d%d%d", &n, &k, &L, &M); --L;
for (int i = 1, u, v, c; i <= M; i++) scanf("%d%d%d", &u, &v, &c), e[u].pb({ v, c }), ++rd[v];
rep(i, 1, n) if (!rd[i]) q[++tl] = i;
while (hd <= tl) {
int u = q[hd++]; a[++p] = u; id[u] = p;
for (auto [v, w] : e[u]) if (--rd[v] == 0) q[++tl] = v;
}
rep(u, 1, n) for (auto [v, w] : e[u]) add(c[id[u]][id[v]][1], w);
rep(l, 2, L) rep(i, 1, n) rep(j, i + 1, n) if (c[i][j][l - 1]) rep(k, j + 1, n) if (c[j][k][1]) add(c[i][k][l], c[i][j][l - 1] * c[j][k][1]);
rep(i, 1, n) cs[i][0] = ct[i][0] = 1;
rep(l, 1, L) rep(i, 1, n) rep(j, 1, n) add(cs[i][l], c[i][j][l]), add(ct[j][l], c[i][j][l]);
rep(i, 1, n) add(ans, cs[i][L]); ans = ans * ans % mod;
rep(i, 1, n) rep(l1, 0, L) rep(l2, 0, L) g[i][l1][l2] = ct[i][l1] * ct[i][l2] % mod * (k - 1) % mod;
rep(i, 1, n) {
rep(j, i + 1, n) {
rep(l1, 0, L) {
rep(l2, 0, L) {
if (!g[i][l1][l2]) continue;
rep(p1, 0, L - l1) if (c[i][j][p1]) add(f[j][l1 + p1][l2], g[i][l1][l2] * c[i][j][p1]);
}
}
}
rep(j, i + 1, n) {
rep(l1, 0, L) {
rep(l2, 0, L) {
if (!f[j][l1][l2]) continue;
rep(p2, 0, L - l2) if (c[i][j][p2]) add(g[j][l1][l2 + p2], f[j][l1][l2] * (k - 1) % mod * c[i][j][p2]);
f[j][l1][l2] = 0;
}
}
}
}
rep(i, 1, n) rep(l1, 0, L) rep(l2, 0, L) add(ans, g[i][l1][l2] * cs[i][L - l1] % mod * cs[i][L - l2]);
int t = n - 2 * (L + 1) + 1;
printf("%lld\n", ans * (t > 0 ? ksm(k, t) : ksm(ksm(k), -t)) % mod);
}