题解 P4689 【[Ynoi2016]这是我自己的发明】
诗乃
2019-11-01 10:58:47
不难发现换根是假的。以1为根将树拍成dfs序,则:
>1、若x为当前根,则x的子树对应区间为整个dfs序列。
>2、若x在当前根的子树中,则x的子树对应区间为一段连续的区间。
>3、若当前根在x的子树中,则x的子树对应区间为整个序列挖去一段区间。
至此,可以将子树问题转化为序列上的区间问题。
设$ans(l1,r1,l2,r2)$为区间$[l1,r1]$与$[l2,r2]$的答案。
则有:
$$ans(l1,r1,l2,r2)=ans(1,r1,1,r2)+ans(1,l1-1,1,l2-1)-ans(1,r1,1,l1-1)-ans(1,r2,1,l2-1)$$
莫队处理即可。
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
const int MAXN = 500050;
void read(int &x) {
char ch; while(ch = getchar(), ch < '!'); x = ch - 48;
while(ch = getchar(), ch > '!') x = (x << 3) + (x << 1) + ch - 48;
}
void write(lint x) {if(x > 9) write(x/10); putchar(x%10+48);}
struct Edge {int to, next;} e[MAXN];
struct Qry {int l, r, id, f;} q[8000050];
int Q, m, in[MAXN], out[MAXN], idx, _dfn, st[MAXN][21], dfn[MAXN], dep[MAXN], fa[MAXN], h[MAXN], en, rt, cnt[MAXN][2];
int is_q[MAXN], top[MAXN], son[MAXN], siz[MAXN], tl[MAXN], tr[MAXN], o[MAXN], c[MAXN], a[MAXN], n, lg2[MAXN]; lint ans[MAXN], ima, _c, SIZ;
void addedge(int u, int v) {e[en] = (Edge) {v, h[u]}; h[u] = en++;}
bool cmp(Qry a, Qry b) {return a.l/SIZ == b.l/SIZ ? a.r < b.r : a.l < b.l;}
void dfs1(int u, int p) {
st[++idx][0] = u; in[u] = ++_dfn; dfn[u] = idx; c[_dfn] = a[u];
dep[u] = dep[p] + 1; fa[u] = p; siz[u] = 1;
for(int v, i = h[u]; ~i; i = e[i].next)
if((v = e[i].to) != p) {
dfs1(v, u); st[++idx][0] = u; siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
out[u] = _dfn;
}
void dfs2(int u, int f) {
top[u] = f; if(son[u]) dfs2(son[u], f);
for(int v, i = h[u]; ~i; i = e[i].next)
if((v = e[i].to) != fa[u] && v != son[u]) dfs2(v, v);
}
int find(int y, int x) {//找x在y的哪个子树内
int u; for(; top[x] != top[y]; u = top[x], x = fa[u]);
return x == y ? u : son[y];
}
void init() {
dep[1] = 1; dfs1(1, 1); dfs2(1, 1);
for(int i = 2; i <= (n << 1); ++i) lg2[i] = lg2[i>>1] + 1;
for(int j = 1; j <= 21; ++j)
for(int i = 1; i + (1 << j) - 1 <= (n << 1); ++i)
st[i][j] = dep[st[i][j-1]] < dep[st[i+(1<<(j-1))][j-1]] ? st[i][j-1] : st[i+(1<<(j-1))][j-1];
}
int LCA(int u, int v) {
int x = dfn[u], y = dfn[v];
if(x > y) swap(x, y);
int k = lg2[y-x+1];
return dep[st[x][k]] < dep[st[y-(1<<k)+1][k]] ? st[x][k] : st[y-(1<<k)+1][k];
}
void RC(int l1, int r1, int l2, int r2, int id) {//将区间询问转化为几个子问题的容斥
q[++m] = (Qry) {r1, r2, id, 1}; q[++m] = (Qry) {l1-1, l2-1, id, 1};
q[++m] = (Qry) {r1, l2-1, id, -1}; q[++m] = (Qry) {l1-1, r2, id, -1};
}
void devide(int x) { //处理x的子树对应的区间
if(x == rt) tl[++_c] = 1, tr[_c] = n;
else {
int z = LCA(x, rt);
if(z != x) tl[++_c] = in[x], tr[_c] = out[x];
else {
int y = find(x, rt);
if(1 <= in[y]-1) tl[++_c] = 1; tr[_c] = in[y]-1;
if(out[y]+1 <= n) tl[++_c] = out[y]+1, tr[_c] = n;
}
}
}
void build(int x, int y, int id) {//将两个子树问题转化为区间问题
_c = 0; devide(x); int mid = _c; devide(y);
for(int i = 1; i <= mid; ++i)
for(int j = mid+1; j <= _c; ++j)
RC(tl[i], tr[i], tl[j], tr[j], id);
}
void add(int x, int p) {ima += cnt[c[x]][p^1]; ++cnt[c[x]][p];}
void del(int x, int p) {ima -= cnt[c[x]][p^1]; --cnt[c[x]][p];}
int main() {
read(n); read(Q); SIZ = sqrt(n); memset(h, -1, sizeof h);
for(int i = 1; i <= n; ++i) read(a[i]), o[i] = a[i];
sort(o+1, o+n+1); int _n = unique(o+1, o+n+1)-o-1;
for(int i = 1; i <= n; ++i) a[i] = lower_bound(o+1, o+_n+1, a[i])-o;
for(int u, v, i = 1; i < n; ++i) read(u), read(v), addedge(u, v), addedge(v, u);
init(); rt = 1;
for(int opt, x, y, i = 1; i <= Q; ++i) {
read(opt); read(x);
if(opt == 1) rt = x;
else is_q[i] = 1, read(y), build(x, y, i);
}
for(int i = 1; i <= m; ++i) if(q[i].l > q[i].r) swap(q[i].l, q[i].r);
sort(q+1, q+m+1, cmp);
for(int L = 0, R = 0, i = 1; i <= m; ++i) {
int l = q[i].l, r = q[i].r;
while(L < l) add(++L, 0); while(L > l) del(L--, 0);
while(R < r) add(++R, 1); while(R > r) del(R--, 1);
ans[q[i].id] += ima * q[i].f;
}
for(int i = 1; i <= Q; ++i) if(is_q[i]) write(ans[i]), putchar('\n');
}
```