题解 P2166 【Gty的超级妹子树】

· · 题解

下面是一个不基于树分块的做法(树分块对块数无法保证,可以用菊花图卡掉):

虽然这个做法也可以被卡,但必须针对块长(一条链,平分第一块,前一半2操作,后一半3操作,然后询问)。

前置知识:Gty的妹子树

上面这题O(n^{1.5}\log n)的解法:对树的DFS序建划分树(线段树的一种,运用归并排序的方法在各个节点维护对应区间内元素排好序的结果,可以O(logn)询问区间k大,不支持修改),对每个询问先找到划分树上的结果,然后暴力讨论每个修改对答案的影响,积累到O(\sqrt{n})个询问后暴力重建划分树。

如无法理解上面这段话请出门左转该题题解。

现在我们讨论3号操作对答案的影响。(假设当前询问节点为u)(先忽略1、2操作)

1.这一操作的节点不是u的儿子。它对答案没有影响。O(1)

2.这一操作的节点在u为根的子树中,并且u到该节点的路径上没有其他节点被3号操作影响。那么该节点为根的子树中各个节点都不应当被计入答案(但实际上都被计入了),只要把这一部分的答案去掉即可(在划分树上询问)。O(\log n)

3.这一操作的节点在u为根的子树中,但u到该节点的路径上还有其他节点被3操作影响。那么这一节点对答案的影响已被去除,无需修改。O(1)

上面有两个问题:询问是否在u为根的子树中(解法:维护每个节点向上跳2^i后的节点)、询问某路径(两端点间有祖先——后代关系)上是否有节点被3操作影响(解法:先转化为对被3操作影响的节点进行计数,树上差分转化为对根到某一节点被3操作影响的节点进行计数,则每个3操作对以它为根的子树中的节点有贡献,可以在划分树中节点多开一个变量维护)

1号操作对答案的影响:

1.这一操作的节点不是u的儿子。它对答案没有影响。O(1)

2.这一操作的节点在u为根的子树中,并且u到该节点的路径上没有节点被3号操作影响。那么该节点可能对答案有影响,记录修改前后的权值不难判断。O(1)

3.这一操作的节点在u为根的子树中,但u到该节点的路径上还有其他节点被3操作影响。那么这一节点对答案的影响已被去除,无需修改。O(1)

2号操作对答案的影响:

1.这一操作的节点不是u的儿子。它对答案没有影响。O(1)

2.这一操作的节点在u为根的子树中,并且u到该节点的路径上没有节点被3号操作影响。那么该节点可能对答案有影响,不难判断。O(1)

3.这一操作的节点在u为根的子树中,但u到该节点的路径上还有其他节点被3操作影响。那么这一节点对答案的影响已被去除,无需修改。O(1)

下面有一个重要的东西:

维护根到2号操作中新建节点的路径上被3操作影响的节点数(开一个数组记录):

1.新建了一个节点。那么新建的节点可以继承其父亲的答案。

2.产生了被3操作影响的节点。暴力扫描新建的节点,判断是否被影响。

以下是复杂度分析:

1操作:记录修改其前后的权值O(1),总计O(n)

2操作:基础维护O(1),维护向上跳2^i后的节点O(\log n),继承父亲的被影响节点数O(\log n),总计O(n\log n)

3操作:修改影响到的节点(原树中的节点)O(\log n),(新建的节点)O(\sqrt{n}\log n),总计O(n^{1.5}\log n)

询问:可以发现为O(n^{1.5}\log n)

重建:单次O(n\log n),总计O(n^{1.5}\log n)

因此总复杂度仍为O(n^{1.5}\log n)

实现细节:

看着不长,写完以后发现和树套树差不多……

准备对拍吧(去掉重建部分对拍一次,每次重建再对拍一次)

块长取大一点,恰好\sqrt{n}容易被卡常(如果知道为什么请私信)

#include<bits/stdc++.h>
using namespace std;
int read() {
    char c=getchar();while(!isdigit(c))c=getchar();
    int num=0;while(isdigit(c))num=num*10+c-'0',c=getchar();
    return num;
}
void write(int num){if(num>=10)write(num/10);putchar(num%10+'0');}
int head[200001], ver[400001], nxt[400001], sz;
void addedge(int u, int v) {
    ver[++sz] = v, nxt[sz] = head[u], head[u] = sz;
    ver[++sz] = u, nxt[sz] = head[v], head[v] = sz;
}
int fa[200001], dep[200001];
void dfs(int x) {
    for (int i = head[x]; i; i = nxt[i])
        if (fa[x] != ver[i]) {
            fa[ver[i]] = x;
            dep[ver[i]] = dep[x] + 1;
            dfs(ver[i]);
        }
}
int n;
int seq[200001], pt[200001];
int top;
int size[200001];
int w[200001];
void getsq(int x) {
    seq[++top] = x, pt[x] = top;
    size[x] = 1;
    for (int i = head[x]; i; i = nxt[i])
        if (fa[ver[i]] == x) {
            getsq(ver[i]);
            size[x] += size[ver[i]];
        }
}
int bf;
vector<int> sq[800001];
int dat[800001];
int cnt[800001];
int out[100001];
void erase(int p, int l, int r) {
    sq[p].clear();
    dat[p] = cnt[p] = 0;
    if (l == r) return;
    int mid = (l + r) / 2;
    erase(p * 2, l, mid);
    erase(p * 2 + 1, mid + 1, r);
}
void pushdown(int p, int l, int r) {
    int mid = (l + r) / 2;
    cnt[p*2] += dat[p]*(mid-l+1);
    cnt[p*2+1] += dat[p]*(r-mid);
    dat[p*2] += dat[p];
    dat[p*2+1] += dat[p];
    dat[p] = 0;
}
void init(int p, int l, int r) {
    if (l == r) {sq[p].push_back(w[seq[l]]);return;}
    int mid = (l + r) / 2;
    init(p * 2, l, mid);
    init(p * 2 + 1, mid + 1, r);
    int p1 = 0, p2 = 0;
    while (p1 < sq[p*2].size() && p2 < sq[p*2+1].size())
        if (sq[p*2][p1]<sq[p*2+1][p2]) sq[p].push_back(sq[p*2][p1++]);
        else sq[p].push_back(sq[p*2+1][p2++]);
    while (p1 < sq[p*2].size()) sq[p].push_back(sq[p*2][p1++]);
    while (p2 < sq[p*2+1].size()) sq[p].push_back(sq[p*2+1][p2++]);
}
void insert(int p, int l, int r, int l0, int r0) {
    if (l0 >= l && r0 <= r) {
        ++dat[p];
        cnt[p] += r0 - l0 + 1;
        return;
    }
    pushdown(p, l0, r0);
    int mid = (l0 + r0) / 2;
    if (l <= mid) insert(p * 2, l, r, l0, mid);
    if (r > mid) insert(p * 2 + 1, l, r, mid + 1, r0);
    cnt[p] = cnt[p*2] + cnt[p*2+1];
}
int query(int p, int l0, int r0, int u) {
    if (l0 == r0) return cnt[p];
    pushdown(p, l0, r0);
    int mid = (l0 + r0) / 2;
    if (u <= mid) return query(p * 2, l0, mid, u);
    else return query(p * 2 + 1, mid + 1, r0, u);
}
int query(int p, int l, int r, int l0, int r0, int k) {
    if (l0 >= l && r0 <= r) return sq[p].end() - upper_bound(sq[p].begin(), sq[p].end(), k);
    int ans = 0, mid = (l0 + r0) / 2;
    if (l <= mid) ans += query(p * 2, l, r, l0, mid, k);
    if (r > mid) ans += query(p * 2 + 1, l, r, mid + 1, r0, k);
    return ans;
}
int qu(int u, int maxn) {
    if (u > maxn) return out[u];
    return query(1, 1, maxn, pt[u]);
}
int f[200001][20];
void build() {
    top = 0;
    for (int i = 1; i <= n; i++) if (!fa[i]) getsq(i);
    erase(1, 1, bf);
    memset(f, 0, sizeof(f));
    for (int i = 1; i <= n; i++) f[i][0] = fa[i];
    for (int j = 1; j <= 18; j++)
        for (int i = 1; i <= n; i++)
            f[i][j] = f[f[i][j-1]][j-1];
    init(1, 1, n);
}
inline int jmp(int x, int d) {
    for (int i = 18; i >= 0; i--)
        if (dep[f[x][i]] >= d)
            x = f[x][i];
    return x;
}
struct cmd {
    int op, u, x, x0;
}c[100001];
int main() {
    n = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        addedge(u, v);
    }
    dep[1] = 1;
    dfs(1);
    for (int i = 1; i <= n; i++) w[i] = read();
    bf = n;
    build();
    int m = read();
    int s = sqrt(m) * 4;
    int q = 1;
    int lastans = 0;
    for (int i = 1; i <= m; i++) {
        int op = read();
        switch (op) {
            case 0: {
                int u = read() ^ lastans, x = read() ^ lastans;
                int ans = 0;
                if (u <= bf) ans = query(1, pt[u], pt[u]+size[u]-1, 1, bf, x);
                for (int j = 1; j < q; j++) {
                    if (c[j].op == 1) {
                        if (jmp(c[j].u, dep[u]) != u || qu(c[j].u, bf) > qu(u, bf)) continue;
                        if (c[j].x > x) ++ans;
                        if (c[j].x0 > x) --ans;
                    }
                    if (c[j].op == 2) {
                        if (jmp(c[j].u, dep[u]) != u || qu(c[j].u, bf) > qu(u, bf)) continue;
                        if (c[j].x > x) ++ans;
                    }
                    if (c[j].op == 3) {
                        if (c[j].u > bf || jmp(c[j].u, dep[u]) != u || qu(fa[c[j].u], bf) != qu(u, bf)) continue;
                        ans -= query(1, pt[c[j].u], pt[c[j].u]+size[c[j].u]-1, 1, bf, x);       
                    }
                }
                write(lastans = ans);
                putchar('\n');
                break;
            }
            case 1: {
                int u = read() ^ lastans, x = read() ^ lastans;
                c[q].op = op, c[q].u = u, c[q].x = x, c[q].x0 = w[u];
                w[u] = x;
                ++q;
                break;
            }
            case 2: {
                int u = read() ^ lastans, x = read() ^ lastans;
                ++n;
                fa[n] = u, w[n] = x, dep[n] = dep[u] + 1;
                f[n][0] = u;
                addedge(n, u);
                for (int i = 1; i <= 18; i++) f[n][i] = f[f[n][i-1]][i-1];
                out[n] = qu(u, bf);
                c[q].op = op, c[q].u = n, c[q].x = x;
                ++q;
                break;
            }
            case 3: {
                int u = read() ^ lastans;
                if (u <= bf) insert(1, pt[u], pt[u]+size[u]-1, 1, bf);
                for (int j = bf + 1; j <= n; j++)
                    if (jmp(j, dep[u]) == u) ++out[j];
                c[q].op = op, c[q].u = u;
                ++q;
                break;
            }
        }
        if (q > s) {
            for (int j = 1; j <= q; j++)
                if (c[j].op == 3) fa[c[j].u] = 0;
            q = 1;
            bf = n;
            build();
        }
    }
}