题解:P4242 树上的毒瘤

· · 题解

题目描述见洛谷 P4242。

很不错的题。Tag:虚树、树链剖分、换根相关。 本文不讲解上述前置知识。

看到树上路径的颜色覆盖与颜色段查询,容易联想到树链剖分。树上路径的颜色段数量不难统计:先用树链剖分把树拍到序列上,再用线段树对每个区间维护颜色段数以及左右端点的颜色即可;合并两个区间时,若左区间右端的颜色与右区间左端的颜色相同,则段数减一。

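看到每次询问只涉及一个大小为 $m$ 的点集 $S$,不难联想到虚树。对于每次询问,建立点集 $S$ 的虚树。钦定 $1$ 号点为根,对每个 $s \in S$ 用树链剖分求出路径 $1 \to s$ 的颜色段数并加和,即可在 $O(m \log^2 n)$ 的时间内得到 $1$ 号点的答案 $W_1$。但题目要求点集 $S$ 中每个点的答案,这启发我们做换根。具体地,记录虚树上每个子树内有多少个点属于 $S$(记为 $\mathrm{cnt}_v$),然后换根其实和求“$S$ 中每个点到 $u$ 的路径长度之和”是一样的,只是路径长度改成了路径上的颜色段数。设虚树边 $u \to v$ 对应的原树路径上有 $c$ 个颜色段,由于 $v$ 同时属于拼接的两段路径、其所在颜色段会被重复计数一次,故 $W_v = W_u - \mathrm{cnt}_v (c-1) + (m - \mathrm{cnt}_v)(c-1)$。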

变量重名挂得早,封装就是好!(虚树的邻接表数组与原树同名,因此单独封装进结构体 `Virtual_Tree`。)代码中含有少量被注释掉的调试“魔怔”代码,可自行取消注释查看。

#include <bits/stdc++.h>
//#include <windows.h>
using namespace std;
// Array capacities: N vertices (with slack), M directed edge slots (each undirected edge is stored twice).
const int N = 100000 + 5;
const int M = 2 * N;
// n = vertex count, q = operation count, a[v] = initial colour of vertex v.
int n, q, a[N];
// Original tree as a linked adjacency list: h[u] = first edge of u (-1 = none),
// e[i] = edge target, ne[i] = next edge of the same source, idx = next free slot.
int h[N], e[M], ne[M], idx;
// Appends the directed edge a -> b to the front of a's list.
inline void add(int a, int b) {
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
    ++idx;
}

// Colour summary of a vertex sequence read in one direction:
//   a = colour of the first vertex, b = colour of the last vertex,
//   c = number of maximal runs of equal colour ("colour segments").
struct Path { int a, b, c; };

// Concatenates x followed by y. When x ends with the colour y starts with,
// the two boundary runs fuse into a single segment.
Path merge(Path x, Path y) {
    Path joined;
    joined.a = x.a;
    joined.b = y.b;
    joined.c = x.c + y.c;
    if (x.b == y.a) joined.c -= 1;
    return joined;
}

// Forward declarations of the segment-tree operations defined further below; the
// heavy-light path helpers that come first need them.
// change: assign colour d to every dfn position in [l, r] (u is the current node; callers pass 1).
void change(int u, int l, int r, int d);
// query: colour summary of dfn positions [l, r], read in increasing dfn order (callers pass u = 1).
Path query(int u, int l, int r);

//----------------------------------------------------------------------------

// Results of the first heavy-light pass: depth, parent, subtree size and heavy child
// (son[u] == 0 while u has no child).
int dep[N], fa[N], sz[N], son[N];
// First heavy-light pass from u, whose parent is `father` (0 for the root): fills
// dep/fa/sz and picks the child with the largest subtree as u's heavy child.
void dfs(int u, int father) {
    fa[u] = father;
    dep[u] = dep[father] + 1;
    sz[u] = 1;
    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (child == father) continue;
        dfs(child, u);
        sz[u] += sz[child];
        // A strictly larger subtree wins; on ties the first child seen is kept.
        if (son[u] == 0 || sz[child] > sz[son[u]]) son[u] = child;
    }
}
// Results of the second pass: chain head, DFS entry index, and its inverse (id[dfn[u]] == u).
int top[N], dfn[N], id[N], tot;
// Second heavy-light pass: numbers vertices so that every heavy chain occupies a
// contiguous dfn range, visiting the heavy child first. t is the head of u's chain.
void dfs2(int u, int t) {
    top[u] = t;
    dfn[u] = ++tot;
    id[tot] = u;
    if (son[u] != 0) dfs2(son[u], t);  // the heavy child extends the current chain
    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int v = e[edge];
        if (v != fa[u] && v != son[u]) dfs2(v, v);  // every light child heads a new chain
    }
}

// Lowest common ancestor via heavy-light chains: repeatedly lift the endpoint whose
// chain head is deeper until both share a chain, then take the shallower endpoint.
inline int lca(int a, int b) {
    while (top[a] != top[b]) {
        if (dep[top[a]] >= dep[top[b]]) a = fa[top[a]];
        else b = fa[top[b]];
    }
    return dep[a] <= dep[b] ? a : b;
}

// Assigns colour d to every vertex on the tree path u..v (u and v arbitrary).
void change_path(int u, int v, int d) {
    while (top[u] != top[v]) {
        // Lift whichever endpoint sits on the chain with the deeper head.
        int& low = dep[top[u]] >= dep[top[v]] ? u : v;
        change(1, dfn[top[low]], dfn[low], d);
        low = fa[top[low]];
    }
    // Both endpoints share a chain now; the shallower one has the smaller dfn.
    if (dep[u] <= dep[v]) change(1, dfn[u], dfn[v], d);
    else change(1, dfn[v], dfn[u], d);
}
// Colour summary of the vertical path from u down to v (both endpoints included),
// oriented top-to-bottom: .a is u's colour, .b is v's colour, .c the segment count.
// Precondition: u is an ancestor of v (or u == v); only v is lifted.
Path query_path(int u, int v) {
    // Pieces are collected bottom-up, so each new (higher) piece is prepended.
    // An explicit flag marks "nothing collected yet"; the previous -1 colour
    // sentinel would misfire if -1 ever appeared as a colour value.
    bool empty = true;
    Path res = {0, 0, 0};
    while (top[u] != top[v]) {
        Path val = query(1, dfn[top[v]], dfn[v]);
        res = empty ? val : merge(val, res);
        empty = false;
        v = fa[top[v]];
    }
    // Final piece: u and v now lie on the same chain with u on top.
    Path val = query(1, dfn[u], dfn[v]);
    return empty ? val : merge(val, res);
}
//----------------------------------------------------------------------------

// Segment-tree node over dfn positions; tr[1] is the root (main builds it over [1, n]).
struct Tree {
    int l, r;  // the dfn interval [l, r] this node covers
    int cov;   // pending "assign this colour to the whole interval" tag; 0 = no tag
               // NOTE(review): this relies on colours never being 0 -- confirm against the constraints
    Path res;  // colour summary of the interval in increasing dfn order
} tr[N << 2];
inline void pushup(int u) { tr[u].res = merge(tr[u << 1].res, tr[u << 1 | 1].res); }
inline void update(int u, int d) { tr[u].cov = d, tr[u].res = (Path){d, d, 1}; }
inline void pushdown(int u) { if (tr[u].cov) update(u << 1, tr[u].cov), update(u << 1 | 1, tr[u].cov), tr[u].cov = 0; }
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r, tr[u].cov = 0;
    if (l == r) return tr[u].res = (Path){a[id[l]], a[id[l]], 1}, void();
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
void change(int u, int l, int r, int d) {
    if (tr[u].l >= l && tr[u].r <= r) return update(u, d), void();
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid) change(u << 1, l, r, d);
    if (r > mid) change(u << 1 | 1, l, r, d);
    pushup(u);
}
// Colour summary of dfn positions [l, r] (intersected with node u's interval), in dfn order.
Path query(int u, int l, int r) {
    const int lo = tr[u].l, hi = tr[u].r;
    if (l <= lo && hi <= r) return tr[u].res;  // node lies entirely inside the range
    pushdown(u);
    const int mid = (lo + hi) >> 1;
    if (r <= mid) return query(2 * u, l, r);      // range is entirely in the left half
    if (mid < l) return query(2 * u + 1, l, r);   // range is entirely in the right half
    const Path lhs = query(2 * u, l, r);
    const Path rhs = query(2 * u + 1, l, r);
    return merge(lhs, rhs);
}

// Debug helper: qwq[v] receives the current colour of vertex v after print(1).
int qwq[N];
// Pushes all pending tags down to the leaves under u and copies each leaf's colour out.
void print(int u) {
    if (tr[u].l != tr[u].r) {
        pushdown(u);
        print(u << 1);
        print(u << 1 | 1);
        return;
    }
    qwq[id[tr[u].l]] = tr[u].res.a;
}

//----------------------------------------------------------------------------

// Virtual tree of the current query set (query vertices plus their LCAs), rooted at 1.
// It keeps its own adjacency arrays, whose names clash with the original tree's
// globals; wrapping them in a struct keeps the two apart.
struct Virtual_Tree {
    int m;                       // size of the current query set
    int h[N], e[M], ne[M], idx;  // child lists; h[u] == -1 means u has no children
    Path w[M];                   // fa -> son: colour summary of the original path parent..child
    // Appends child b to a's list with edge summary c.
    inline void add(int a, int b, Path c) { e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++; }
    int cnt[N];                  // after dfs1: number of query vertices in u's virtual subtree
    long long Ans[N], ans;       // Ans[u]: answer for u after dfs2; ans: running value during dfs2

    // Turns the per-vertex marks in cnt[] (1 for a query vertex, else 0) into subtree sums.
    void dfs1(int u) {
        for (int i = h[u]; ~i; i = ne[i]) {
            dfs1(e[i]);
            cnt[u] += cnt[e[i]];
        }
    }

    // Rerooting pass. On entry `ans` must hold the answer for u: the sum, over query
    // vertices x, of the number of colour segments on path u..x (the caller seeds it
    // for the root). Moving to a child v whose path u..v has c segments, every x in
    // v's subtree becomes c - 1 segments closer and every other x c - 1 segments
    // farther; v lies on both halves of the path, so its run is counted only once.
    void dfs2(int u) {
        Ans[u] = ans;
        const long long pre = ans;
        for (int i = h[u]; ~i; i = ne[i]) {
            const int v = e[i];
            const long long in = cnt[v];         // query vertices inside v's subtree
            const long long out = m - cnt[v];    // query vertices outside it
            const long long step = w[i].c - 1;
            ans -= in * step;
            ans += out * step;
            dfs2(v);
            ans = pre;  // undo before trying the next child
        }
    }

    // Resets every vertex reachable from u so the struct can be reused for the next query.
    void clr(int u) {
        cnt[u] = Ans[u] = 0;
        for (int i = h[u]; ~i; i = ne[i]) clr(e[i]);
        h[u] = -1;
    }
} VT;

int arr[N], qry[N], m;
int stk[N], tt;
inline bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
inline void insert(int u, int v) {  // ÇÕ¶¨ u ÊÇ v µÄ׿ÏÈ 
    Path w = query_path(u, v);
//  cout << "Virtual Tree Add edge: " << u << ' ' << v << '\t' << w.a << ' ' << w.b << ' ' << w.c << endl;
    VT.add(u, v, w);
}
// Builds the virtual tree of the current query set arr[1..m], rooted at vertex 1.
// Side effects: marks VT.cnt[x] = 1 for every query vertex x, sets VT.ans to the
// answer for vertex 1 (sum over query vertices x of the colour-segment count on the
// path 1..x), and sorts arr[] in place by dfn (the caller keeps input order in qry[]).
void build_VTree() {
    VT.ans = 0ll, VT.idx = 0;
    sort(arr + 1, arr + 1 + m, cmp);
    for (int i = 1; i <= m; i++) VT.cnt[arr[i]] = 1, VT.ans += query_path(1, arr[i]).c/*, cout << "Init: " << arr[i] << ' ' << query_path(1, arr[i]).c << endl*/;
    // Stack construction: stk[1..tt] is the current root-to-leaf chain of the virtual
    // tree, starting with the root. stk[0] is never written (stays 0) and dfn[0] == 0,
    // so it acts as a sentinel that stops the pop loop at the bottom of the stack.
    stk[tt = 1] = 1;
    for (int i = 1; i <= m; i++) {
        // Vertex 1 is already on the stack as the root.
        if (arr[i] == 1) continue;
        int u = arr[i], l = lca(arr[i], stk[tt]);
        // Pop (linking each popped vertex to the one below it) while stk[tt-1] is deeper in dfn order than l.
        while (dfn[l] < dfn[stk[tt - 1]]) insert(stk[tt - 1], stk[tt]), tt--;
        // Now dfn[stk[tt-1]] <= dfn[l] <= dfn[stk[tt]], giving three cases:
        //   l == stk[tt]            : u hangs below the top, just push it;
        //   l strictly in between   : l is a new branching vertex -- link l -> stk[tt], replace the top with l, push u;
        //   l == stk[tt-1]          : link stk[tt-1] -> stk[tt], replace the top with u.
        if (l == stk[tt]) stk[++tt] = u;
        else if (dfn[stk[tt - 1]] < dfn[l] && dfn[l] < dfn[stk[tt]]) insert(l, stk[tt]), tt--, stk[++tt] = l, stk[++tt] = u;
        else if (l == stk[tt - 1]) insert(stk[tt - 1], stk[tt]), tt--, stk[++tt] = u;
    }
    // Link whatever remains on the final chain.
    while (tt > 1) insert(stk[tt - 1], stk[tt]), tt--;
}

// Reads the tree and the vertex colours, builds the heavy-light decomposition and the
// segment tree, then processes q operations:
//   op == 1 : read u v d and colour the path u..v with d;
//   otherwise: read m query vertices and print, for each of them in input order, the
//              sum of colour-segment counts from it to every query vertex.
int main() {
    scanf("%d%d", &n, &q);
    for (int v = 1; v <= n; ++v) {
        h[v] = -1;     // original tree: empty adjacency list
        VT.h[v] = -1;  // virtual tree: no children yet
    }
    idx = 0;
    VT.idx = 0;
    for (int v = 1; v <= n; ++v) scanf("%d", &a[v]);
    for (int k = 1; k < n; ++k) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, 0);
    dfs2(1, 1);
    build(1, 1, n);

    while (q--) {
        int op;
        scanf("%d", &op);
        if (op == 1) {
            int u, v, d;
            scanf("%d%d%d", &u, &v, &d);
            change_path(u, v, d);
            continue;
        }
        scanf("%d", &m);
        VT.m = m;
        for (int k = 1; k <= m; ++k) {
            scanf("%d", &arr[k]);
            qry[k] = arr[k];  // build_VTree reorders arr[], so remember the input order
        }
        build_VTree();
        VT.dfs1(1);
        VT.dfs2(1);
        for (int k = 1; k <= m; ++k) printf("%lld ", VT.Ans[qry[k]]);
        puts("");
        VT.clr(1);
    }
    return 0;
}