题解 P4689 【[Ynoi2016]这是我自己的发明】

· · 题解

貌似是第一次完全脱离题解切掉一个黑题呢?

进入正题。

首先如果你的做题量够多的话一眼就会发现这个询问就是 P5268 [SNOI2017]一个简单的询问 搬到了子树上面。

然后子树转序列有一个很套路的方式就是用 dfs 序。

但是换根怎么搞呢?

其实你只需要稍微想一下就会知道换根是假的,对于任意一个节点 u,定义此时 u 的子树是按以 1 为根计算的,那么我们分情况讨论:

  1. 如果此时的根是不在 u 的子树内的,那么显然以此时的根计算的子树与以 1 为根计算的子树是一样的。
  2. 如果此时的根是 u,那么显然子树就是整棵树。
  3. 如果此时的根在 u 的一个孩子 v 的子树里面,那么显然此时 u 的子树就是整棵树去掉 v 的子树。

显然上面三种情况都可以在 dfs 序上面找到对应的 1\sim 2 段连续区间。

所以换根是假的~

这时你可能会想到直接在 dfs 序上面做 P5268。

下面为了方便说明,我们记 f(a, b, c, d)[a,b] 区间和 [c,d] 区间之间的查询结果,g(x,y)=f(1,x,1,y)

我们在计算 f(l_1,r_1,l_2,r_2) 的时候是拆了 4 个询问,变成 g(r_1,r_2)-g(l_1-1,r_2)-g(r_1,l_2-1)+g(l_1-1,l_2-1)

然后如果根同时出现在 uv 的子树里面,我们要算的是 f(1,l_1,1,l_2)+f(1,l_1,r_2,n)+f(r_1,n,r_2,n)+f(r_1,n,1,l_2)

直接拆得到 4\times 4=16 个询问然后做莫队!

抱歉你这样是过不掉的

然后你发现可以倍长 dfs 序列,然后就不用拆区间了,就只剩 4 倍的询问了!

抱歉莫队的复杂度是 \sout {O(n\sqrt m)} 的,你把 \sout n\sout 2 和把 \sout m\sout 4 木有区别

mmp

所以只能去挖掘一下这个 f 的性质,发现有如下两个性质:

前者显然,后者可以这样理解为如果要在 [1,d] 区间内选一个数,那么要么在 [c,d] 里面要么在 [1,c-1] 里面,而且两者的方案是可以加减的。

于是我们就可以转化这个式子。

首先考虑 u 是一段连续区间,v 是两段连续区间的情况。

这时要计算的是 f(l_1,r_1,1,l_2)+f(l_1,r_1,r_2,n)

利用上面的等式可以算出:

f(l_1,r_1,1,l_2)+f(l_1,r_1,r_2,n)=g(r_1,l_2)-g(l_1-1,l_2)+g(r_1,n)-g(l_1-1,n)-g(r_1,r_2-1)+g(l_1-1,r_2-1)

发现可以预处理 g(i,n),所以实际上这里需要拆的询问只有 4 个。

h(i)=g(i,n)

然后考虑都是两段询问的情况。

此时要计算:

f(1,l_1,1,l_2)+f(1,l_1,r_2,n)+f(r_1,n,1,l_2)+f(r_1,n,r_2,n)

化为下式:

g(l_1,l_2)+h(l_1)-g(l_1,r_2-1)+h(l_2)-g(l_2,r_1-1)+h(n)-h(r_1-1)-h(r_2-1)+g(r_1-1,r_2-1)

发现依然只需要拆 4 个询问。

所以所有询问都可以只拆 4 个,也就是说算法只带 2 倍时间常数和 4 倍空间常数,那么这道题就做完了~

最终时间 O(n\sqrt m),空间 O(n+m)

然后就是注意莫队块长需要取 \dfrac{n}{\sqrt m} 来达到最优复杂度,直接取 \sqrt n 会导致复杂度升高到 O((n+m)\sqrt n),以及注意查询节点是根的时候的一些细节。

上代码~

#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
    register char c = getchar();
    register int x = 0, f = 1;
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return x * f;
}

inline int Abs(const int& x) {return (x > 0 ? x : -x);}
inline int Max(const int& x, const int& y) {return (x > y ? x : y);}
inline int Min(const int& x, const int& y) {return (x < y ? x : y);}

const int N = 100005, M = 500005;

struct Edge {
    int to, nxt;
    Edge() {
        nxt = -1;
    }
};
Edge e[N << 1];
int n, hd[N], pnt, m, pos[N], qcnt, dfn[N], dep[N], post[N], a[N], b[N], S, fa[N][25], _time, cnt1[N], cnt2[N], atmp[N];
bool isquery[M];
long long ans[M], pref[N], curval; // pref[i] = f(1, i, 1, n)
struct Query {
    int l, r, id;
    Query(int l = 0, int r = 0, int id = 0) : l(l), r(r), id(id) {}
    bool operator < (const Query& b) const {
        return (pos[l] ^ pos[b.l] ? pos[l] < pos[b.l] : (pos[l] & 1 ? r < b.r : r > b.r));
    }
};
Query qry[M << 2];

inline void AddEdge(int u, int v) {
    e[++pnt].to = v;
    e[pnt].nxt = hd[u];
    hd[u] = pnt;
}

inline void Read() {
    n = qread(); m = qread();
    for (register int i = 1;i <= n;i++) b[i] = a[i] = qread();
    for (register int i = 1;i < n;i++) {
        register int u = qread(), v = qread();
        AddEdge(u, v);
        AddEdge(v, u);
    }
}

inline void Dfs(int u) {
    dfn[u] = ++_time;
    atmp[_time] = a[u];
    for (register int i = hd[u];~i;i = e[i].nxt) {
        if (e[i].to != fa[u][0]) {
            fa[e[i].to][0] = u;
            dep[e[i].to] = dep[u] + 1;
            Dfs(e[i].to);
        }
    }
    post[u] = _time;
}

inline int Up(int u, int k) {
    for (register int j = 20;j >= 0;j--) {
        if ((k >> j) & 1) u = fa[u][j];
    }
    return u;
}

inline void AddQuery(int l, int r, int i) {
    if (l < 1 || l > n) return;
    if (r < 1 || r > n) return;
    if (l > r) swap(l, r);
    qry[++qcnt] = Query(l, r, i);
}

inline void Prefix() {
    sort(b + 1, b + n + 1);
    register int vtop = unique(b + 1, b + n + 1) - b - 1;
    for (register int i = 1;i <= n;i++) a[i] = lower_bound(b + 1, b + vtop + 1, a[i]) - b;
    for (register int j = 1;j <= 20;j++) {
        for (register int i = 1;i <= n;i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
    }
    for (register int i = 1;i <= n;i++) cnt1[a[i]]++;
    for (register int i = 1;i <= n;i++) pref[i] = pref[i - 1] + cnt1[a[i]];
    register int rt = 1;
    for (register int i = 1;i <= m;i++) {
        register int opt = qread();
        if (opt == 1) rt = qread();
        else {
            isquery[i] = 1;
            register int u = qread(), v = qread(), type = 0;
            if (dfn[u] <= dfn[rt] && dfn[rt] <= post[u]) type |= 2;
            if (dfn[v] <= dfn[rt] && dfn[rt] <= post[v]) type |= 1;
            //printf("T=%d\n", type);
            if (type == 0) {
                register int l1 = dfn[u], r1 = post[u], l2 = dfn[v], r2 = post[v];
                //printf("l1=%d r1=%d l2=%d r2=%d\n", l1, r1, l2, r2);
                AddQuery(r1, r2, i);
                AddQuery(l1 - 1, r2, -i);
                AddQuery(r1, l2 - 1, -i);
                AddQuery(l1 - 1, l2 - 1, i);
            } else if (type == 1) {
                register int l1, r1, l2, r2;
                l1 = dfn[u]; r1 = post[u];
                if (v != rt) {
                    register int sv = Up(rt, dep[rt] - dep[v] - 1);
                    l2 = dfn[sv] - 1; r2 = post[sv] + 1;
                } else {
                    l2 = 0;
                    r2 = 1;
                }
                /*
                  f(l1, r1, 1, l2) + f(l1, r1, r2, n)
                = f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + f(l1, r1, 1, n) - f(l1, r1, 1, r2 - 1)
                = f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + fpre(r1) - fpre(l1 - 1) - f(1, r1, 1, r2 - 1) + f(1, l1 - 1, 1, r2 - 1)
                */
                ans[i] += pref[r1] - pref[l1 - 1];
                AddQuery(r1, l2, i);
                AddQuery(l1 - 1, l2, -i);
                AddQuery(r1, r2 - 1, -i);
                AddQuery(l1 - 1, r2 - 1, i);
            } else if (type == 2) {
                swap(u, v);
                register int l1, r1, l2, r2;
                l1 = dfn[u]; r1 = post[u];
                if (v != rt) {
                    register int sv = Up(rt, dep[rt] - dep[v] - 1);
                    l2 = dfn[sv] - 1; r2 = post[sv] + 1;
                } else {
                    l2 = 0;
                    r2 = 1;
                }
                /*
                  f(l1, r1, 1, l2) + f(l1, r1, r2, n)
                = f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + f(l1, r1, 1, n) - f(l1, r1, 1, r2 - 1)
                = f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + fpre(r1) - fpre(l1 - 1) - f(1, r1, 1, r2 - 1) + f(1, l1 - 1, 1, r2 - 1)
                */
                ans[i] += pref[r1] - pref[l1 - 1];
                AddQuery(r1, l2, i);
                AddQuery(l1 - 1, l2, -i);
                AddQuery(r1, r2 - 1, -i);
                AddQuery(l1 - 1, r2 - 1, i);
            } else if (type == 3) {
                register int l1, r1, l2, r2;
                if (u != rt) {
                    register int su = Up(rt, dep[rt] - dep[u] - 1);
                    l1 = dfn[su] - 1; r1 = post[su] + 1;
                } else {
                    l1 = 0;
                    r1 = 1;
                }
                if (v != rt) {
                    register int sv = Up(rt, dep[rt] - dep[v] - 1);
                    l2 = dfn[sv] - 1; r2 = post[sv] + 1;
                } else {
                    l2 = 0;
                    r2 = 1;
                }
                ans[i] += pref[l1] + pref[l2] + pref[n] - pref[r1 - 1] - pref[r2 - 1];
                AddQuery(l1, r2 - 1, -i);
                AddQuery(l1, l2, i);
                AddQuery(l2, r1 - 1, -i);
                AddQuery(r1 - 1, r2 - 1, i);
            }
        }
    }
}

inline void Addl(int x) {curval += cnt2[x]; cnt1[x]++;}
inline void Dell(int x) {curval -= cnt2[x]; cnt1[x]--;}
inline void Addr(int x) {curval += cnt1[x]; cnt2[x]++;}
inline void Delr(int x) {curval -= cnt1[x]; cnt2[x]--;}

inline void Solve() {
    //for (register int i = 1;i <= n;i++) printf("%d ", a[i]);
    //puts("");
    S = (int)(1.0 * n / sqrt(qcnt));
    for (register int i = 1;i <= n;i++) pos[i] = (i - 1) / S + 1;
    sort(qry, qry + qcnt + 1);
    memset(cnt1, 0, sizeof(cnt1));
    memset(cnt2, 0, sizeof(cnt2));
    register int lp = 0, rp = 0;
    for (register int i = 1;i <= qcnt;i++) {
        //printf("%d %d %d\n", qry[i].l, qry[i].r, qry[i].id);
        while (lp < qry[i].l) Addl(a[++lp]);
        while (rp < qry[i].r) Addr(a[++rp]);
        while (lp > qry[i].l) Dell(a[lp--]);
        while (rp > qry[i].r) Delr(a[rp--]);
        if (qry[i].id > 0) ans[qry[i].id] += curval;
        else ans[-qry[i].id] -= curval;
    }
    for (register int i = 1;i <= m;i++) {
        if (isquery[i]) printf("%lld\n", ans[i]);
    }
}

int main() {
    memset(hd, -1, sizeof(hd));
    Read();
    dep[1] = 1; Dfs(1); memcpy(a, atmp, sizeof(a));
    Prefix();
    Solve();
    #ifndef ONLINE_JUDGE
    while (1);
    #endif
    return 0;
}