题解 P5642 【 人造情感(emotion)】

· · 题解

膜拜 EI 鸽鸽!/崇拜

给一个 O(m \log^2n) 的树剖做法,不过跑得好像比一些 1log 的要快一些(大雾)。

考虑对于给定的集合 S 我们如何求出 W(S)。这其实是一个经典的问题,这里就直接给出做法了。我们设:

对于路径 (u,v) 我们将它存到 \mathrm{lca}(u,v) 对应的 vector 里,有转移:

g_u = \sum_{v \in son_u} f_v f_u = g_u + \max\{w_i + \sum_{j \in \mathrm{path}(u_i,v_i)} c_j\} \ \cup \ \{0\}

对于 \sum_{j \in \mathrm{path}(u_i,v_i)} c_j 可以使用树状数组优化,具体来说可以将链求和差分成点到根求和,单点加、点到根求和相当于子树加、单点求值,可以直接树状数组。对初始的路径集合 U 做一遍这个 DP,那么 W(U) 的值即为 f_1

接着,我们考虑如何求出 W(U \ \cup \ \{u,v,w+1\})。显然,若 \mathrm{lca}(u,v) = 1,那么答案即为 f_1 +\sum_{j \in \mathrm{path}(u_,v_)} c_j\ + (w+1)。但如果 \mathrm{lca}(u,v) 不为 1,那么 f_1 所对应的路径集合中可能存在某条路径从上面穿过了 \mathrm{lca}(u,v),我们必须想办法消除此类路径对答案造成的影响。于是我们可以另设 h_u 表示不存在一条路径经过 (u,fa_u) 的情况下 f_1 的最大值。那么有转移:

\begin{cases} h_v \leftarrow h_u + c_u, & \text{$v \in son_u$}\\ h_v \leftarrow h_u + \sum_{j \in \mathrm{path}(u_i,v_i)} c_j + w_i,& \text{$\mathrm{lca}(u_i,v_i) = u, v \notin \mathrm{path}(u_i,v_i),fa_v \in \mathrm{path}(u_i,v_i)$} \end{cases}

对这两种转移的解释:

第一种转移直接做即可;但对于第二种转移,如果我们暴力枚举所有路径进行判断,时间复杂度就达到了惊人的 O(nm \log n),因此我们希望转移的过程中不会枚举到不合法的路径。考虑某一条路径 (u_i,v_i),我们可以在 \mathrm{lca}(u_i,v_i) = u 处给 (u_i,v_i) 上的所有点打上 h_u + \sum_{j \in \mathrm{path}(u_i,v_i)} c_j + w_i 的 tag,转移的时候就直接把 tag 下传到儿子取 \max,但同时不能传到路径上的点。

先考虑一个暴力的维护,我们对每一个点开一个 vector 维护在它身上的 tag,每个 tag 都是一个形如 (w,a,b) 的三元组,表示权值为 w,并且不能传到 a,b 两个儿子(如果不存在这么多限制就将剩下的位置赋值为 0,显然只有在 \mathrm{lca}a,b 才可能都不为 0)。下传标记时先把所有 tag 按权值从大到小排序,对于每个儿子找到第一个可以转移的 tag 进行转移,显然每个 tag 最多只会被扫到 2 遍,因此复杂度为 O(标记数)

但可惜的是,即使避免了多余的枚举,标记数量仍然可能很多,时间复杂度并没有实质上的优化。还能再给力点吗?

对给定的树进行轻重链剖分,这时我们会惊喜地发现,大多数 tag 都形如 (w,son_u,0)son_u 表示 u 的重儿子),只有在重链链底,即轻重链切换的地方需要单独考虑。当然,路径端点和 \mathrm{lca} 上的 tag 也是特殊的。于是我们可以维护每个点上 (w,son_u,0) 的 tag 的最大 w 值,这只需要使用线段树支持区间取 \max,单点查询。而对于特殊的 tag 就按暴力的方式进行维护即可。每条路径至多贡献 O(\log n) 个 tag,因此 tag 总数是 O(m \log n) 的。

求出 h_u 后,显然有 f(u,v) = f_1 - h_{\mathrm{lca}(u,v)} - \sum_{j \in \mathrm{path}(u,v)} c_j。考虑如何对所有 f(u,v) 求和,显然我们只需计算每一项的贡献,即对每个点 u 求出有多少路径的 \mathrm{lca} 恰为 u,以及有多少路径穿过 u,简单容斥即可。

注意操作顺序,先打 tag 后转移。代码写得有点丑,调试懒得删了就这样将就看看?

const int MN = 4e5 + 5;
const int Mod = 998244353;

// #define dbg

int N, M, ans;
vector <int> G[MN];

struct BIT {
    int c[MN];
    inline int lowbit(int x) {
        return x & (-x);
    }
    inline void Modify(int x, int k) {
        for (int i = x; i <= N; i += lowbit(i))
            c[i] += k;
    }
    inline void Modify(int l, int r, int k) {
        Modify(l, k), Modify(r + 1, -k);
    }
    inline int Query(int x) {
        if (!x)
            return 0;
        int res = 0;
        for (int i = x; i; i -= lowbit(i))
            res += c[i];
        return res;
    }
} T;

int siz[MN], son[MN], dep[MN], fa[MN], fat[MN][25];
inline void DFS0(int u, int pr) {
    siz[u] = 1, fat[u][0] = pr;
    dep[u] = dep[fa[u] = pr] + 1;
    for (int i = 1; i <= 20; i++)
        fat[u][i] = fat[fat[u][i - 1]][i - 1];
    for (int v : G[u]) 
        if (v != pr) {
            DFS0(v, u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]])
                son[u] = v;
        }
}
int top[MN], dfn[MN], dfc, st[MN], ed[MN];
inline void DFS1(int u, int tp) {
    top[u] = tp;
    dfn[u] = st[u] = ++dfc;
    if (son[u]) {
        DFS1(son[u], tp);
        for (int v : G[u])
            if (v != fa[u] && v != son[u]) 
                DFS1(v, v);
    }
    ed[u] = dfc;
}
inline int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    return x;
}
inline int climb(int x, int y) {
    for (int i = 20; ~i; i--)
        if (dep[fat[x][i]] > dep[y]) x = fat[x][i];
    return x;
}

#define PI3 pair < pii, int >
#define mp3(x, y, z) mp(mp(x, y), z)

int f[MN], g[MN], c[MN];
vector <PI3> vr[MN];
inline int Qry(int x, int y, int r) {
    return T.Query(st[x]) + T.Query(st[y]) - T.Query(st[fa[r]]) - T.Query(st[r]);
}
inline void DFS2(int u) {
    for (int v : G[u])
        if (v != fa[u]) 
            DFS2(v), g[u] += f[v];
    int del = 0;
    for (auto [p, w] : vr[u]) {
        del = max(del, Qry(p.fi, p.se, u) + w);
#ifdef dbg
        printf("Qry(%lld %lld) = %lld\n", p.fi, p.se, Qry(p.fi, p.se, u));
#endif
    }
    f[u] = g[u] + del;
    c[u] = g[u] - f[u];
    T.Modify(st[u], ed[u], c[u]);
}

const int MS = MN << 2;
#define ls o << 1
#define rs o << 1 | 1
#define mid ((l + r) >> 1)
#define LS ls, l, mid
#define RS rs, mid + 1, r
int tr[MS], tg[MS];
inline void Mdf(int o, int l, int r, int L, int R, int v) {
    if (r < L || l > R || L > R) return;
    if (L <= l && R >= r) return tg[o] = max(tg[o], v), void();
    Mdf(LS, L, R, v), Mdf(RS, L, R, v);
}   
inline int Qry(int o, int l, int r, int p) {
    if (l == r) return tg[o];
    return max(tg[o], p <= mid ? Qry(LS, p) : Qry(RS, p));
}

struct Dat {
    int w, a, b;
    Dat() {}
    Dat(int W, int X, int Y) : w(W), a(X), b(Y) {}
    inline bool operator < (const Dat &p) const {
        return w > p.w;
    }
};
int h[MN];
vector <Dat> tag[MN];
inline void Cov(int x, int y, int w, int r) {
    int u = x, v = y;
    if (dep[x] < dep[y])
        swap(x, y);
    tag[x].pb(Dat(w, 0, 0));
    if (y != r)
        tag[y].pb(Dat(w, 0, 0));
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) 
            swap(x, y);
        Mdf(1, 1, N, dfn[top[x]], dfn[fa[x]], w);
        if (fa[top[x]] != r)
            tag[fa[top[x]]].pb(Dat(w, top[x], 0));
        x = fa[top[x]];
    }
    u = climb(u, r);
    v = climb(v, r);
    tag[r].pb(Dat(w, u, v));
    if (x == y)
        return;
    if (dep[x] < dep[y]) 
        swap(x, y);
    int p = climb(x, y);
    Mdf(1, 1, N, dfn[p], dfn[fa[x]], w);
}
inline void DFS3(int u) {
    for (auto [p, w] : vr[u]) {
        int x = p.fi, y = p.se;
        Cov(x, y, h[u] + Qry(x, y, u) + w, u);
#ifdef dbg
        printf("%lld : pushtag (%lld, %lld, %lld)\n", u, x, y, h[u] + Qry(x, y, u) + w);
#endif
    }
    sort(tag[u].begin(), tag[u].end());
    int w = Qry(1, 1, N, dfn[u]);
    for (int v : G[u]) 
        if (v != fa[u]) {
            h[v] = max(h[u] + c[u], 0);
            if (v != son[u]) 
                h[v] = max(h[v], w);
            for (auto tg : tag[u])
                if (tg.a != v && tg.b != v) {
                    h[v] = max(h[v], tg.w);
                    break;
                }
            DFS3(v);
        }
}

signed main(void) {
    N = read(), M = read();
    for (int i = 1; i < N; i++) {
        int u = read(), v = read();
        G[u].pb(v), G[v].pb(u);
    }
    DFS0(1, 0);
    DFS1(1, 1);
    for (int i = 1; i <= M; i++) {
        int x = read(), y = read(), w = read();
        int r = LCA(x, y);
#ifdef dbg
    printf("LCA(%lld, %lld) = %lld\n", x, y, r);
#endif
        vr[r].pb(mp3(x, y, w));
    }
    DFS2(1);
    h[1] = f[1];
#ifdef dbg
    puts("----------------");
#endif
    DFS3(1);
#ifdef dbg
    puts("----------------");
    printf("maxW(U) = %lld\n", f[1]);
    for (int i = 1; i <= N; i++)
        printf("%lld : f = %lld, g = %lld, c = %lld, h = %lld\n", i, f[i], g[i], c[i], h[i]);
    puts("----------------");
#endif
    ans = f[1] % Mod * N % Mod * N % Mod;
    for (int i = 1; i <= N; i++)
        c[i] = Mod + c[i] % Mod, h[i] %= Mod;
    for (int u = 1; u <= N; u++) {
        int Coef = N * N % Mod;
        Dec(Coef, (N - siz[u]) * (N - siz[u]) % Mod);
        int w = 0;
        for (int v : G[u])  
            if (v != fa[u]) Add(w, siz[v] * siz[v] % Mod);
        Dec(Coef, w);
        Dec(ans, Coef * c[u] % Mod);
        Coef = siz[u] * siz[u] % Mod;
        Dec(Coef, w);
        Dec(ans, Coef * h[u] % Mod);
    }
    printf("%lld\n", ans);
    return 0; 
}