题解:P10773 [NOISG2021 Qualification] Truck

· · 题解

解题思路

首先,看到树上路径,第一眼想到树链剖分,本题解默认你已经掌握基于边的树链剖分。

如果我们倒着走,从终点走到起点,不难发现每走一段就是产生 D_i\cdot G 的费用并将 G 增加 T_i。假设此后都是倒着走。

由于在走的过程中 G 不断变化,故计算经过若干条边的费用时,需要引入额外费用 x 表示走的过程中由于 G 的变化产生的费用,则走过这些边的总代价为 D_{sum}\cdot G+x。可以得知依次走过两条链 a,b 的额外费用为 x_a+x_b+T_a\cdot D_b,而走过单段没有额外费用。实现时,由于一条链可以以两种方向走过,额外费用也要计算两份分别对应两种方向。

然后,通过树链剖分实现即可。

参考代码

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll Mod = 1e9 + 7;

map<pair<int, int>, int> rev;
int n, q, logn;
ll g;
int dfn[100024], hson[100024], top[100024], fa[24][100024], dfc;
int ecnt, dep[100024], ev[200024], ex[200024], es[100024], rdfn[100024];
ll et[200024], ed[200024], pt[100024], pd[100024];

// 第一轮 DFS
int dfsr1(int u, int f) {
    int w = 1, hw = -1, tmp;
    for (int i = es[u]; i; i = ex[i]) {
        int v = ev[i];
        if (v == f) continue;
        dep[v] = dep[u] + 1;
        fa[0][v] = u;
        pt[v] = et[i];
        pd[v] = ed[i];
        rev[{u, v}] = rev[{v, u}] = v;
        tmp = dfsr1(v, u);
        w += tmp;
        if (tmp > hw) {
            hw = tmp;
            hson[u] = v;
        }
    }
    return w;
}

// 第二轮 DFS
void dfsr2(int u, int f) {
    rdfn[dfc] = u;
    dfn[u] = dfc++;
    if (!hson[u]) return;
    top[hson[u]] = top[u];
    dfsr2(hson[u], u);
    for (int i = es[u]; i; i = ex[i]) {
        int v = ev[i];
        if (v == f || v == hson[u]) continue;
        top[v] = v;
        dfsr2(v, u);
    }
}

// 加边
inline void add(int u, int v, int d, int t) {
    ev[++ecnt] = v;
    ex[ecnt] = es[u];
    et[ecnt] = t;
    ed[ecnt] = d;
    es[u] = ecnt;
    ev[++ecnt] = u;
    ex[ecnt] = es[v];
    et[ecnt] = t;
    ed[ecnt] = d;
    es[v] = ecnt;
}

struct Node {
    ll t, d, px, nx;
} seg[400024];

// 合并答案
inline Node pushup(const Node l, const Node r) {
    return {(l.t + r.t) % Mod, (l.d + r.d) % Mod, (l.px + r.px + l.t * r.d) % Mod, (l.nx + r.nx + l.d * r.t) % Mod};
}

// 单点修改
void modify(int id, int l, int r, int u, ll v) {
    if (l == r) {
        seg[id].t = v;
        return;
    }
    int mid = l + r >> 1;
    if (u <= mid) modify(id << 1, l, mid, u, v);
    else modify(id << 1 | 1, mid + 1, r, u, v);
    seg[id] = pushup(seg[id << 1], seg[id << 1 | 1]);
}

// 区间查询
Node query(int id, int l, int r, int L, int R) {
    if (l > R || r < L) return {0, 0, 0, 0};
    if (L <= l && r <= R) return seg[id];
    int mid = l + r >> 1;
    return pushup(query(id << 1, l, mid, L, R), query(id << 1 | 1, mid + 1, r, L, R));
}

// 建树
void build(int id, int l, int r) {
    if (l == r) {
        seg[id] = {pt[rdfn[l]], pd[rdfn[l]], 0, 0};
        return;
    }
    int mid = l + r >> 1;
    build(id << 1, l, mid);
    build(id << 1 | 1, mid + 1, r);
    seg[id] = pushup(seg[id << 1], seg[id << 1 | 1]);
}

// 查询 LCA 下面的点
pair<int, int> sublca(int a, int b) {
    if (dep[a] > dep[b]) {
        for (int i = logn; i >= 0 && dep[a] > dep[b]; i--) {
            if (dep[a] - dep[b] > (1 << i)) a = fa[i][a];
        }
        if (fa[0][a] == b) {
            return {a, -1};
        }
        a = fa[0][a];
    }
    else if (dep[a] < dep[b]) {
        for (int i = logn; i >= 0 && dep[a] < dep[b]; i--) {
            if (dep[b] - dep[a] > (1 << i)) b = fa[i][b];
        }
        if (fa[0][b] == a) {
            return {-1, b};
        }
        b = fa[0][b];
    }
    if (a == b) return {-1, -1};
    for (int i = logn; i >= 0; i--) if (fa[i][a] != fa[i][b]) {
        a = fa[i][a];
        b = fa[i][b];
    }
    return {a, b};
}

// 查询链
Node qchain(int u, int d) {
    Node res = {0, 0, 0, 0};
    while (dep[top[d]] > dep[u]) {
        res = pushup(query(1, 1, n - 1, dfn[top[d]], dfn[d]), res);
        d = fa[0][top[d]];
    }
    res = pushup(query(1, 1, n - 1, dfn[u], dfn[d]), res);
    return res;
}

// 查询完整答案
ll qlca(int a, int b) {
    auto[x, y] = sublca(a, b);
    if (~x) {
        if (~y) {
            auto le = qchain(x, a), ri = qchain(y, b);
            return (ri.d * g + le.d * (g + ri.t) + ri.nx + le.px) % Mod;
        }
        else {
            auto le = qchain(x, a);
            return (le.d * g + le.px) % Mod;
        }
    }
    else {
        if (~y) {
            auto ri = qchain(y, b);
            return (ri.d * g + ri.nx) % Mod;
        }
        else return 0;
    }
}

int main() {
    scanf("%d %lld", &n, &g);
    int u, v, w;
    ll d, t;
    for (int i = 1; i < n; i++) {
        scanf("%d %d %lld %lld", &u, &v, &d, &t);
        add(u, v, d, t);
    }
    fa[0][1] = 1;
    top[1] = 1;
    dfsr1(1, -1);
    dfsr2(1, -1);
    build(1, 1, n - 1);
    logn = log2(n);
    for (int i = 1; i <= logn; i++) for (int j = 1; j <= n; j++) {
        fa[i][j] = fa[i - 1][fa[i - 1][j]];
    }
    scanf("%d", &q);
    while (q--) {
        scanf("%d %d %d", &u, &v, &w);
        if (u) {
            printf("%lld\n", qlca(v, w));
        }
        else {
            scanf("%lld", &t);
            int e = rev[{v, w}];
            modify(1, 1, n - 1, dfn[e], t);
        }
    }
    return 0;
}