P10879 对树链剖分的爱 题解

· · 题解

我们考虑对于一次操作的 u, v,不妨设 u\lt v,则 (v, f_v) 这条边一定被加上了 w,证明显然。

于是我们可以考虑这样一个做法:每次给 v 的答案加上 w,然后以相同的概率跳到 f_v\in [l_v, r_v] 的其中一个父亲,同时令 w \leftarrow \dfrac{w}{r_v - l_v + 1},此时该问题被我们递归到了一个规模更小的子问题。

发现这个形式十分像 dp,我们不妨将一次操作先挂在 g_{u, v} 上,m 次操作结束后,我们对 g 进行 dp,设当前状态为 g_{u, v}(u\lt v),枚举 v 的父亲 k\in [l_v, r_v],有转移:

g_{u, k} \leftarrow g_{u, k} + \dfrac{g_{u, v}}{r_v - l_v + 1}

这里同样钦定 u\lt k,若 u\gt k 需要将 g_{u, k} 改为 g_{k, u},时间复杂度 \mathcal{O}(n^3 + m)

最后点 u 的答案即为 \sum\limits_{u\gt v}g_{v, u}

我们考察这个转移形式,发现相当于是给 g 数组的一段 L 形区域加上了一个固定值,使用二维差分优化即可,时间复杂度 \mathcal{O}(n^2 + m)

结合代码可能比较好理解。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int N = 5e3 + 5, P = 998244353;

using Z = ModInt<P>;

int n, m;
int l[N], r[N];
Z g[N][N], inv[N], ans[N];

int main() {

    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n;  
    for (int i = 2; i <= n; i++) std::cin >> l[i] >> r[i];
    for (int i = 1; i <= n; i++) inv[i] = Z(1) / i;

    auto add = [&](int l1, int r1, int l2, int r2, Z val) -> void {
        g[r1][r2] += val;
        g[l1 - 1][r2] -= val;
        g[r1][l2 - 1] -= val;
        g[l1 - 1][l2 - 1] += val;
    };

    std::cin >> m;
    for (int i = 1; i <= m; i++) {
        int u, v, w; std::cin >> u >> v >> w;
        if (u == v) continue;
        if (u > v) std::swap(u, v);
        add(u, u, v, v, w);
    }

    for (int j = n; j >= 1; j--) {
        for (int i = n; i >= 1; i--) {
            g[i][j] += g[i + 1][j] + g[i][j + 1] - g[i + 1][j + 1]; 

            // k\in [l, r], i
            // k <= i, g[k][i] += g[i][j] * inv[r[j] - l[j] + 1];
            // k > i,  g[i][k] += g[i][j] * inv[r[j] - l[j] + 1];

            if (i < j) {
                if (i >= l[j]) add(l[j], std::min(r[j], i), i, i, g[i][j] * inv[r[j] - l[j] + 1]);
                if (i <= r[j]) add(i, i, std::max(l[j], i), r[j], g[i][j] * inv[r[j] - l[j] + 1]);                  
            }
        }
    }

    for (int i = 1; i <= n; i++) for (int j = i + 1; j <= n; j++) ans[j] += g[i][j];
    for (int i = 2; i <= n; i++) std::cout << ans[i] << " ";

    return 0;
}