题解 P4689 【[Ynoi2016]这是我自己的发明】
貌似是第一次完全脱离题解切掉一个黑题呢?
进入正题。
首先如果你的做题量够多的话一眼就会发现这个询问就是 P5268 [SNOI2017]一个简单的询问 搬到了子树上面。
然后子树转序列有一个很套路的方式就是用 dfs 序。
但是换根怎么搞呢?
其实你只需要稍微想一下就会知道换根是假的,对于任意一个节点
- 如果此时的根是不在
u 的子树内的,那么显然以此时的根计算的子树与以1 为根计算的子树是一样的。 - 如果此时的根是
u ,那么显然子树就是整棵树。 - 如果此时的根在
u 的一个孩子v 的子树里面,那么显然此时u 的子树就是整棵树去掉v 的子树。
显然上面三种情况都可以在 dfs 序上面找到对应的
所以换根是假的~
这时你可能会想到直接在 dfs 序上面做 P5268。
下面为了方便说明,我们记
我们在计算
然后如果根同时出现在
直接拆得到
抱歉你这样是过不掉的
然后你发现可以倍长 dfs 序列,然后就不用拆区间了,就只剩
抱歉莫队的复杂度是
mmp
所以只能去挖掘一下这个
-
f(a,b,c,d)=f(c,d,a,b) -
f(a,b,c,d)=f(a,b,1,d)-f(a,b,1,c-1)
前者显然,后者可以这样理解为如果要在
于是我们就可以转化这个式子。
首先考虑
这时要计算的是
利用上面的等式可以算出:
发现可以预处理
记
然后考虑都是两段询问的情况。
此时要计算:
化为下式:
发现依然只需要拆
所以所有询问都可以只拆
最终时间
然后就是注意莫队块长需要取
上代码~
#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int qread() {
register char c = getchar();
register int x = 0, f = 1;
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return x * f;
}
inline int Abs(const int& x) {return (x > 0 ? x : -x);}
inline int Max(const int& x, const int& y) {return (x > y ? x : y);}
inline int Min(const int& x, const int& y) {return (x < y ? x : y);}
const int N = 100005, M = 500005;
struct Edge {
int to, nxt;
Edge() {
nxt = -1;
}
};
Edge e[N << 1];
int n, hd[N], pnt, m, pos[N], qcnt, dfn[N], dep[N], post[N], a[N], b[N], S, fa[N][25], _time, cnt1[N], cnt2[N], atmp[N];
bool isquery[M];
long long ans[M], pref[N], curval; // pref[i] = f(1, i, 1, n)
struct Query {
int l, r, id;
Query(int l = 0, int r = 0, int id = 0) : l(l), r(r), id(id) {}
bool operator < (const Query& b) const {
return (pos[l] ^ pos[b.l] ? pos[l] < pos[b.l] : (pos[l] & 1 ? r < b.r : r > b.r));
}
};
Query qry[M << 2];
inline void AddEdge(int u, int v) {
e[++pnt].to = v;
e[pnt].nxt = hd[u];
hd[u] = pnt;
}
inline void Read() {
n = qread(); m = qread();
for (register int i = 1;i <= n;i++) b[i] = a[i] = qread();
for (register int i = 1;i < n;i++) {
register int u = qread(), v = qread();
AddEdge(u, v);
AddEdge(v, u);
}
}
inline void Dfs(int u) {
dfn[u] = ++_time;
atmp[_time] = a[u];
for (register int i = hd[u];~i;i = e[i].nxt) {
if (e[i].to != fa[u][0]) {
fa[e[i].to][0] = u;
dep[e[i].to] = dep[u] + 1;
Dfs(e[i].to);
}
}
post[u] = _time;
}
inline int Up(int u, int k) {
for (register int j = 20;j >= 0;j--) {
if ((k >> j) & 1) u = fa[u][j];
}
return u;
}
inline void AddQuery(int l, int r, int i) {
if (l < 1 || l > n) return;
if (r < 1 || r > n) return;
if (l > r) swap(l, r);
qry[++qcnt] = Query(l, r, i);
}
inline void Prefix() {
sort(b + 1, b + n + 1);
register int vtop = unique(b + 1, b + n + 1) - b - 1;
for (register int i = 1;i <= n;i++) a[i] = lower_bound(b + 1, b + vtop + 1, a[i]) - b;
for (register int j = 1;j <= 20;j++) {
for (register int i = 1;i <= n;i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
for (register int i = 1;i <= n;i++) cnt1[a[i]]++;
for (register int i = 1;i <= n;i++) pref[i] = pref[i - 1] + cnt1[a[i]];
register int rt = 1;
for (register int i = 1;i <= m;i++) {
register int opt = qread();
if (opt == 1) rt = qread();
else {
isquery[i] = 1;
register int u = qread(), v = qread(), type = 0;
if (dfn[u] <= dfn[rt] && dfn[rt] <= post[u]) type |= 2;
if (dfn[v] <= dfn[rt] && dfn[rt] <= post[v]) type |= 1;
//printf("T=%d\n", type);
if (type == 0) {
register int l1 = dfn[u], r1 = post[u], l2 = dfn[v], r2 = post[v];
//printf("l1=%d r1=%d l2=%d r2=%d\n", l1, r1, l2, r2);
AddQuery(r1, r2, i);
AddQuery(l1 - 1, r2, -i);
AddQuery(r1, l2 - 1, -i);
AddQuery(l1 - 1, l2 - 1, i);
} else if (type == 1) {
register int l1, r1, l2, r2;
l1 = dfn[u]; r1 = post[u];
if (v != rt) {
register int sv = Up(rt, dep[rt] - dep[v] - 1);
l2 = dfn[sv] - 1; r2 = post[sv] + 1;
} else {
l2 = 0;
r2 = 1;
}
/*
f(l1, r1, 1, l2) + f(l1, r1, r2, n)
= f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + f(l1, r1, 1, n) - f(l1, r1, 1, r2 - 1)
= f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + fpre(r1) - fpre(l1 - 1) - f(1, r1, 1, r2 - 1) + f(1, l1 - 1, 1, r2 - 1)
*/
ans[i] += pref[r1] - pref[l1 - 1];
AddQuery(r1, l2, i);
AddQuery(l1 - 1, l2, -i);
AddQuery(r1, r2 - 1, -i);
AddQuery(l1 - 1, r2 - 1, i);
} else if (type == 2) {
swap(u, v);
register int l1, r1, l2, r2;
l1 = dfn[u]; r1 = post[u];
if (v != rt) {
register int sv = Up(rt, dep[rt] - dep[v] - 1);
l2 = dfn[sv] - 1; r2 = post[sv] + 1;
} else {
l2 = 0;
r2 = 1;
}
/*
f(l1, r1, 1, l2) + f(l1, r1, r2, n)
= f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + f(l1, r1, 1, n) - f(l1, r1, 1, r2 - 1)
= f(1, r1, 1, l2) - f(1, l1 - 1, 1, l2) + fpre(r1) - fpre(l1 - 1) - f(1, r1, 1, r2 - 1) + f(1, l1 - 1, 1, r2 - 1)
*/
ans[i] += pref[r1] - pref[l1 - 1];
AddQuery(r1, l2, i);
AddQuery(l1 - 1, l2, -i);
AddQuery(r1, r2 - 1, -i);
AddQuery(l1 - 1, r2 - 1, i);
} else if (type == 3) {
register int l1, r1, l2, r2;
if (u != rt) {
register int su = Up(rt, dep[rt] - dep[u] - 1);
l1 = dfn[su] - 1; r1 = post[su] + 1;
} else {
l1 = 0;
r1 = 1;
}
if (v != rt) {
register int sv = Up(rt, dep[rt] - dep[v] - 1);
l2 = dfn[sv] - 1; r2 = post[sv] + 1;
} else {
l2 = 0;
r2 = 1;
}
ans[i] += pref[l1] + pref[l2] + pref[n] - pref[r1 - 1] - pref[r2 - 1];
AddQuery(l1, r2 - 1, -i);
AddQuery(l1, l2, i);
AddQuery(l2, r1 - 1, -i);
AddQuery(r1 - 1, r2 - 1, i);
}
}
}
}
inline void Addl(int x) {curval += cnt2[x]; cnt1[x]++;}
inline void Dell(int x) {curval -= cnt2[x]; cnt1[x]--;}
inline void Addr(int x) {curval += cnt1[x]; cnt2[x]++;}
inline void Delr(int x) {curval -= cnt1[x]; cnt2[x]--;}
inline void Solve() {
//for (register int i = 1;i <= n;i++) printf("%d ", a[i]);
//puts("");
S = (int)(1.0 * n / sqrt(qcnt));
for (register int i = 1;i <= n;i++) pos[i] = (i - 1) / S + 1;
sort(qry, qry + qcnt + 1);
memset(cnt1, 0, sizeof(cnt1));
memset(cnt2, 0, sizeof(cnt2));
register int lp = 0, rp = 0;
for (register int i = 1;i <= qcnt;i++) {
//printf("%d %d %d\n", qry[i].l, qry[i].r, qry[i].id);
while (lp < qry[i].l) Addl(a[++lp]);
while (rp < qry[i].r) Addr(a[++rp]);
while (lp > qry[i].l) Dell(a[lp--]);
while (rp > qry[i].r) Delr(a[rp--]);
if (qry[i].id > 0) ans[qry[i].id] += curval;
else ans[-qry[i].id] -= curval;
}
for (register int i = 1;i <= m;i++) {
if (isquery[i]) printf("%lld\n", ans[i]);
}
}
int main() {
memset(hd, -1, sizeof(hd));
Read();
dep[1] = 1; Dfs(1); memcpy(a, atmp, sizeof(a));
Prefix();
Solve();
#ifndef ONLINE_JUDGE
while (1);
#endif
return 0;
}