题解:CF1787G Colorful Tree Again

· · 题解

大家好,这里是更更更更更更更更更更劣且更更更更更更更更更更难写的 O(n^{\frac{5}{3}}) 做法!

首先一个暴力的做法是,对每条路径维护 cnt_i,sum_i 表示内部有多少点被摧毁和其边权和,查询即求 \max_{cnt_i=0} sum_i,每次修改暴力更新经过当前点的路径,这样做是 O(n^2) 的。

考虑优化,对于每条路径我们取出它的两端点 x,y,并记 z=\operatorname{lca}(x,y),那么路径 x,y 经过点 u,当且仅当 xyu 子树内,且 uz 子树内。

于是对树进行 dfs,得到每个点的 dfs 序 dfn_u 和其子树大小 siz_u,那么上述条件等价于 (dfn_x \in [dfn_u,dfn_u+siz_u-1] \vee dfn_y \in [dfn_u,dfn_u+siz_u-1]) \wedge dfn_u \in [dfn_z,dfn_z+siz_z-1],第三个条件不太好处理,由于在树上,考虑其等价描述为 z=uz 不在 u 子树内,即 dfn_z \notin [dfn_u\red{+1},dfn_u+siz_u-1]

到这里就很清晰了,我们令点 (dfn_x,dfn_y,dfn_z) 表示路径 (x,y),每次修改点 u 时,若点 (x_i,y_i,z_i) 满足 (x_i \in [dfn_u,dfn_u+siz_u-1] \vee y_i \in [dfn_u,dfn_u+siz_u-1]) \wedge z_i \notin [dfn_u\red{+1},dfn_u+siz_u-1],就令 cnt_i \larr cnt_i \pm 1

使用 3D tree 维护点集,由于修改性质保证了任意时刻 cnt_i \ge 0,于是 3D tree 上每个结点维护 mnv_j,mxv_j,表示子树内 \min cnt_i,和子树内 \max_{cnt_i=mnv_j} sum_i,修改时打标记,在根节点查询并检查 mnv_{root} 是否为零即可,复杂度 O(n^{\frac{5}{3}})

void mainSolve()
{
    int n, q;
    read(n, q);
    vector <vector <array <int, 3>>> gra(n);
    for (int i : viota(1, n))
    {
        int u, v, w, c;
        read(u, v, w, c), --u, --v, --c;
        gra[u].push_back({v, w, c});
        gra[v].push_back({u, w, c});
    }

    int tim = 0;
    vector <int> dfn(n, -1), siz(n);
    vector <vector <array <int, 2>>> deg(n);
    vector <LL> sum(n);
    auto dfs = [&] (this auto&& self, int u) -> void
    {
        dfn[u] = tim ++, siz[u] = 1;
        rsort(gra[u], [] (auto a, auto b) {return a[2] < b[2];});
        gra[u].push_back({0, 0, -1});
        for (int j = 0; int i : viota(0u, gra[u].size()))
            if (gra[u][i][2] != gra[u][j][2])
                deg[gra[u][j][2]].push_back({u, i - j}), j = i;
        for (auto [v, w, c] : gra[u])
            if (dfn[v] < 0) sum[c] += w, self(v), siz[u] += siz[v];
    };

    int tot = 0, rot = 0;
    vector <array <LL, 4>> pnt;
    vector <array <LL, 3>> mnp, mxp;
    vector <array <LL, 2>> val;
    vector <int> tag;
    #define p (l + r & ~(r - l > 1))
    #define m (l + r >> 1)
    #define lp (l + m & ~(m - l > 1))
    #define rp (m + r & ~(r - m > 1))
    auto build = [&] (this auto&& self, int l, int r, int d) -> void
    {
        if (r - l == 1)
        {
            auto [x, y, z, w] = pnt[l];
            mnp[p] = mxp[p] = {x, y, z}, val[p] = {0, w};
            return;
        }
        ranges::nth_element(pnt | vtake(r) | vdrop(l), pnt.begin() + m,
            [&d] (auto a, auto b) {return a[d] < b[d];});
        d = d < 2 ? d + 1 : 0, self(l, m, d), self(m, r, d);
        for (int i : {0, 1, 2})
        {
            mnp[p][i] = min(mnp[lp][i], mnp[rp][i]);
            mxp[p][i] = max(mxp[lp][i], mxp[rp][i]);
        }
        val[p][1] = max(val[lp][1], val[rp][1]);
    };
    auto update = [&] (int u, int v) {val[u][0] += v, tag[u] += v;};
    auto modify = [&] (this auto&& self, int L, int R, int v, int l, int r) -> void
    {
        if (((mxp[p][0] < L || mnp[p][0] > R)
          && (mxp[p][1] < L || mnp[p][1] > R))
          || (mnp[p][2] > L && mxp[p][2] <= R))
            return;
        if (((mnp[p][0] >= L && mxp[p][0] <= R)
          || (mnp[p][1] >= L && mxp[p][1] <= R))
          && (mxp[p][2] <= L || mnp[p][2] > R))
            return update(p, v);
        if (tag[p]) update(lp, tag[p]), update(rp, tag[p]), tag[p] = 0;
        self(L, R, v, l, m), self(L, R, v, m, r);
        if (val[lp][0] == val[rp][0]) val[p] = {val[lp][0], max(val[lp][1], val[rp][1])};
        else val[p] = val[lp][0] < val[rp][0] ? val[lp] : val[rp];
    };
    #undef p
    #undef m
    #undef lp
    #undef rp

    dfs(0);
    for (int i : viota(0, n))
    {
        int cnt1 = 0, cnt2 = 0;
        int x = 0, y = 0, z = 0;
        for (auto [u, c] : deg[i])
            if (c == 1) ++cnt1 == 1 ? x = u : y = u;
            else if (c == 2) z = ++cnt2 > 1 && dfn[u] > dfn[z] ? z : u;
        if (cnt1 != 2 || cnt1 + cnt2 != deg[i].size()) continue;
        if (!cnt2 || min(dfn[x], dfn[y]) < dfn[z]) z = dfn[x] < dfn[y] ? x : y;
        pnt.push_back({dfn[x], dfn[y], dfn[z], sum[i]}), ++tot;
    }
    if (tot)
    {
        mnp.resize(tot * 2), mxp.resize(tot * 2);
        val.resize(tot * 2), tag.resize(tot * 2);
        build(0, tot, 0), rot = tot ^ (tot & 1);
    }

    while (q --)
    {
        int o, x; read(o, x), --x;
        if (!tot) {write(0); continue;}
        modify(dfn[x], dfn[x] + siz[x] - 1, o ? -1 : 1, 0, tot);
        write(val[rot][0] ? 0 : val[rot][1]);
    }
}