P6177 Count on a tree II/【模板】树分块 题解
这题调了六个来小时调自闭了…准确来说还是因为我树分块板子不够完善。
写完之后发现题解区没有我的做法,甚至
思路是 Top_Cluster + 树上前缀和,离散化后预处理出每个界点到其他每个点的答案以及前缀值计数,然后散块内暴力。
其实 critnos 大佬的题解中有和我类似的思路,但是他用的是不够优美的树上撒点分块,而且目测是口胡,这题实现的时候细节特别多。
不妨设 $u$ 的最近界点祖先 $f_u$ 的深度不小于 $v$ 的最近界点祖先 $f_v$ 的深度（否则交换 $u$ 和 $v$）。

- 若 $u$ 和 $v$ 属于同一个簇，我们特判一下，直接暴力跳 LCA，时间复杂度 $O(\sqrt{n})$。
- 否则求出 $u$ 和 $v$ 的 LCA $w$，显然这个点是一个界点。如果 $w$ 和 $v$ 属于同一个簇，那么我们先记录 $u$ 和 $v$ 所属的簇之间的答案，然后分别从 $u$、$v$、$\text{dwbn}_v$（指 $v$ 所在簇的下界点）往上跳，时间复杂度 $O(\sqrt{n})$。
- 否则让 $u$ 和 $v$ 都跳到自己的上界点，并统计两个上界点之间的答案。
然后用 Top_Cluster 求 LCA 可能很多人没有写过,下面放一下我的代码,可以自己理解一下(这里就体现出 Top_Cluster 的优美之处了:两个簇之间只会有一个交点,所以不用另外再写 LCA)。
// LCA of u and v via the cluster structure. In each step the deeper endpoint
// moves up: a boundary node jumps straight to CLfa[u] (the upper boundary of
// the cluster whose lower boundary is u) unless the other endpoint sits in
// that very cluster (dwbn[v] == u); every other node moves to its parent.
int get_lca(int u, int v) {
    while (u != v) {
        if (dep[u] < dep[v]) swap(u, v);
        const bool can_jump = id[u] && dwbn[v] != u;
        u = can_jump ? CLfa[u] : fa[u];
    }
    return u;
}
最后是总代码,写了将近 300 行,没有卡常效率还不错,不过 cache miss 比较严重。调整下循环顺序和代码逻辑有希望跑进最优解(不过我调疯了不想卡了…)。
#include<cstdio>
#include<cctype>
#include<cmath>
#include<ctime>
#include<vector>
#include<algorithm>
#include<initializer_list>
using namespace std;
namespace Main {
#define vi vector<int>
#define ld long double
#define rep(i, l, r) for(int i(l), END##i(r); i <= END##i; ++ i)
#define per(i, r, l) for(int i(r), END##i(l); i >= END##i; -- i)
/// Lower x to y when y is strictly smaller (same result as x = min(x, y)).
template<class T>
inline void cmin(T& x, const T& y) {
    if (y < x) x = y;
}
/// Raise x to y when y is strictly larger (same result as x = max(x, y)).
template<class T>
inline void cmax(T& x, const T& y) {
    if (x < y) x = y;
}
namespace Fast_OI {
// Buffered I/O: [p1, p2) is the unread part of the input window `buf`,
// [obuf, p3) is output not yet handed to stdout.
char buf[1000000], *p1 = buf, *p2 = buf, obuf[1000000], *p3 = obuf;
// Next input byte, or EOF once stdin is exhausted (refills buf with fread).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++)
// Append one byte to obuf, writing the full buffer to stdout first if needed.
#define putchar(x) (p3-obuf<1000000?*p3++=x:(fwrite(obuf,1,p3-obuf,stdout),p3=obuf,*p3++=x))
// Locale-free digit test that is defined for every int. std::isdigit is UB for
// negative values other than EOF, which plain-char input bytes can produce.
inline bool is_digit(int c) { return c >= '0' && c <= '9'; }
/**
 * Read the next decimal integer from stdin.
 * Non-digit characters before the number are skipped; a '-' among them makes
 * the result negative. Returns 0 if the input ends before any digit is seen
 * (previously this case spun forever on EOF).
 */
int read() {
    int x = 0; bool f = 1; int c = getchar();
    while (!is_digit(c)) {
        if (c == EOF) return 0;
        if (c == '-') f = 0;
        c = getchar();
    }
    while (is_digit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
    return f ? x : -x;
}
/// Append `str` to the output buffer, followed by '\n' unless nw_line is false.
void puts(const char* str, bool nw_line = 1) {
    while (*str != '\0')
        putchar(*str), ++ str;
    if (nw_line) putchar('\n');
}
/**
 * Append x in decimal to the output buffer (no separator).
 * The magnitude is computed in unsigned arithmetic so INT_MIN prints
 * correctly; negating it as an int would overflow.
 */
void write(int x) {
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        mag = 0u - mag;
    }
    char digits[10]; int len = 0;  // an unsigned int has at most 10 digits
    do {
        digits[len ++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) {
        const char ch = digits[-- len];  // keep side effects out of the macro
        putchar(ch);
    }
}
/// Hand all buffered output to stdout and empty the buffer, so calling this
/// again cannot emit the same bytes twice.
void flush() { fwrite(obuf, 1, p3 - obuf, stdout); p3 = obuf; }
} using namespace Fast_OI;
// N: capacity for nodes/values; B: cluster size threshold used by
// Top_Cluster::dfs; Bnum: number of per-boundary-node table rows reserved
// (NOTE(review): the 6 * N / B bound is the author's sizing, not derived here).
const int N = 4e4 + 10, B = 260;
const int Bnum = 6 * N / B + 5;
int n, Q, val[N];
// Adjacency list: head[u] is the newest edge index out of u (0 = none),
// e[i].to its target and e[i].nex the next edge index from the same node.
struct edge {
    int to, nex;
}e[N << 1]; int idx, head[N];
/// Prepend the directed edge u -> v to u's adjacency list.
inline void add(int u, int v) {
    const int slot = ++ idx;
    e[slot] = edge{v, head[u]};
    head[u] = slot;
}
/// Insert the undirected edge {u, v} as the two directed edges u -> v, v -> u.
inline void add_edge(int u, int v) {
    add(u, v);
    add(v, u);
}
void Dscrt() {
vi tmp;
rep(i, 1, n) tmp.emplace_back(val[i]);
sort(tmp.begin(), tmp.end());
tmp.erase(unique(tmp.begin(), tmp.end()), tmp.end());
rep(i, 1, n)
val[i] = lower_bound(tmp.begin(), tmp.end(), val[i]) - tmp.begin() + 1;
}
int fa[N], dep[N];
void dfs_prew(int u) {
dep[u] = dep[fa[u]] + 1;
for (int i = head[u]; i; i = e[i].nex) {
int v = e[i].to;
if (v == fa[u]) continue;
fa[v] = u; dfs_prew(v);
}
}
/**
 * Top-cluster tree partition with per-boundary-node path tables, used to
 * count distinct colours on a tree path online.
 *
 * Reads the file-level globals n, val[] (colours, in 1..n after Dscrt),
 * e[]/head[] (adjacency), fa[]/dep[] (filled by dfs_prew) and the cluster
 * size threshold B. Call Build() once, then Query() any number of times.
 */
class Top_Cluster {
public:
    // CLfa[x]: for the lower boundary x of a cluster, that cluster's upper
    //          boundary node (stays 0 for the root).
    // nrd[x]:  filled by New_CL; not read by Query().
    int CLfa[N], nrd[N];
    // upbn[x] / dwbn[x]: upper / lower boundary node of the cluster x was
    //                    assigned to.
    int upbn[N], dwbn[N];
    // BNct: number of entries in BN. id[x]: index of x in BN, 0 if x is not
    //       in BN (assigned in Build()).
    int BNct, id[N];
    // BN: boundary nodes. After Build(), BN[0] is a dummy 0 and BN[1..BNct]
    //     is the reverse of the post-order in which dfs() found them, so
    //     ancestors come before descendants.
    // clpt[dw]: member nodes of the cluster whose lower boundary is dw.
    vi BN, clpt[N];
private :
    // Nodes gathered for the cluster currently being closed.
    vi cur_cl;
    /**
     * Close the cluster made of the nodes in cur_cl, with upper boundary `up`
     * and lower boundary `dw`, then empty cur_cl.
     * dw == 0 means the group holds no boundary node. Its last collected node
     * is then promoted to lower boundary and appended to BN. A non-zero dw is
     * either a boundary node already appended to BN by dfs(), or the root
     * passed by Build(), so it is not appended again.
     */
    void New_CL(int up, int dw) {
        if (cur_cl.empty()) return;
        if (!dw) {
            dw = cur_cl.back();
            // The former guard `if (!id[dw])` was always true, because id[] is
            // only filled in Build(). It re-appended every boundary node (and
            // the root), roughly doubling BNct and the O(n)-per-row Prew() work.
            ++ BNct, BN.emplace_back(dw);
        }
        if (up ^ dw) CLfa[dw] = up;
        nrd[up] = up;
        // Nodes on the path from dw up to (excluding) up map to themselves.
        for (int u = dw; u ^ up; u = fa[u])
            nrd[u] = u;
        for (const int& u : cur_cl) {
            upbn[u] = up, dwbn[u] = dw;
            clpt[dw].emplace_back(u);
            // Copy nrd from the closest ancestor (or self) that already has one.
            int y = u; while (!nrd[y]) y = fa[y];
            nrd[u] = nrd[y];
        }
        cur_cl.clear();
    }
    // stc[1..stctop]: visited nodes not yet assigned to a cluster, in DFS
    // preorder. rec_top[u]: stctop on entry to dfs(u), i.e. u's own slot
    // (0 for the root, which is never pushed).
    int stc[N], stctop;
    int rec_top[N];
    // udfct[u]: non-boundary nodes of u's subtree not yet assigned to a
    //           cluster, counting u (set to 0 once u becomes a boundary node).
    // lwBN[u]:  a boundary node inside u's subtree (taken from the last child
    //           that had one), u itself once u is a boundary node, else 0.
    int udfct[N], lwBN[N];
    /**
     * Post-order pass that picks boundary nodes and forms clusters.
     * u becomes a boundary node when its unclustered part exceeds B nodes,
     * when two or more children carry a boundary node, or when u is the root.
     * The children's pending stack segments are then grouped greedily into
     * clusters with upper boundary u. A new group starts when adding the next
     * child would exceed B nodes, or when both the current group and the next
     * child already contain a boundary node.
     */
    void dfs(int u) {
        rec_top[u] = stctop;
        udfct[u] = 1;
        int BN_cnt = 0;
        for (int i = head[u]; i; i = e[i].nex) {
            int v = e[i].to;
            if (v == fa[u]) continue;
            stc[++ stctop] = v;
            dfs(v); udfct[u] += udfct[v];
            if (lwBN[v]) lwBN[u] = lwBN[v], ++ BN_cnt;
        }
        if (udfct[u] > B || BN_cnt > 1 || !fa[u]) {
            udfct[u] = 0; lwBN[u] = u;
            int p = rec_top[u] + 1, cnt = 0, cur_down = 0;
            // Move the stack entries before v's slot (all remaining ones when
            // v == 0) into cur_cl and close them as one cluster.
            auto reset = [&](const int& v) -> void {
                while (p <= stctop && (!v || p < rec_top[v]))
                    cur_cl.emplace_back(stc[p ++]);
                New_CL(u, cur_down); cnt = cur_down = 0;
            };
            for (int i = head[u]; i; i = e[i].nex) {
                int v = e[i].to;
                if (v == fa[u]) continue;
                if (cnt + udfct[v] > B || (cur_down && lwBN[v]))
                    reset(v);
                cnt += udfct[v]; if (lwBN[v]) cur_down = lwBN[v];
            }
            reset(0); ++ BNct; BN.emplace_back(u);
            // Drop u's subtree from the stack; u itself stays in its slot.
            stctop = rec_top[u];
        }
    }
    // sum[i][c]: occurrences of colour c on the path from BN[i] to the root
    //            (row 0 stays all zeros).
    int sum[Bnum][N];
    // res[i][x]: number of distinct colours on the path from BN[i] to x.
    int res[Bnum][N];
    // Scratch state for dfs_calc: distinct-colour count and per-colour counts
    // along the current DFS path.
    int tmpct, tmpsum[N];
    // DFS over the whole tree from u (row idx's boundary node when called from
    // Prew). Stores in res[idx][x] the distinct-colour count of the path
    // from the start node to x.
    void dfs_calc(int u, int idx, int fath) {
        if (!tmpsum[val[u]] ++) ++ tmpct;
        res[idx][u] = tmpct;
        for (int i = head[u]; i; i = e[i].nex) {
            int v = e[i].to;
            if (v == fath) continue;
            dfs_calc(v, idx, u);
        }
        if (!-- tmpsum[val[u]]) -- tmpct;
    }
    /**
     * Fill sum[i] and res[i] for every BN[i], spending O(n) time per entry.
     * Relies on BN listing ancestors first, so the row of the nearest
     * ancestor in BN is complete before row i is built.
     */
    void Prew() {
        rep(i, 1, BNct) {
            int u = BN[i];
            sum[i][val[u]] = 1;
            // Count colours strictly between BN[i] and its nearest ancestor in
            // BN, then add that ancestor's row (row 0 if there is none).
            for (u = fa[u]; u && !id[u]; u = fa[u])
                ++ sum[i][val[u]];
            rep(j, 1, n)
                sum[i][j] += sum[id[u]][j];
            dfs_calc(BN[i], i, 0);
        }
    }
    // Nearest ancestor of u (u included) that has an entry in BN.
    int find_nrBN(int u) {
        while (!id[u]) u = fa[u];
        return u;
    }
    // LCA of u and v. The deeper endpoint moves up each step. A node in BN
    // jumps to CLfa[u] unless the other endpoint lies in the cluster whose
    // lower boundary it is (dwbn[v] == u); otherwise it moves to its parent.
    int get_lca(int u, int v) {
        while (u ^ v) {
            if (dep[u] < dep[v]) swap(u, v);
            if (id[u] && dwbn[v] != u) u = CLfa[u];
            else u = fa[u];
        } return u;
    }
public :
    /**
     * Build the decomposition rooted at rt and precompute all tables.
     * Assumes rt is the root used by dfs_prew (fa[rt] == 0), so dfs() makes it
     * a boundary node.
     */
    void Build(int rt = 1) {
        dfs(rt);
        // The root is never pushed on the dfs() stack, so it belongs to no
        // cluster yet; give it a cluster of its own.
        cur_cl.emplace_back(rt);
        New_CL(rt, rt);
        BN.emplace_back(0);
        reverse(BN.begin(), BN.end());
        rep(i, 1, BNct) id[BN[i]] = i;
        Prew();
    }
    /**
     * Number of distinct colours on the tree path between u and v.
     * fu / fv are the nearest ancestors (inclusive) of u / v that are in BN,
     * ordered so that dep[fu] >= dep[fv]. There are three cases:
     *  1. fu == fv: walk both endpoints up to their LCA directly.
     *  2. fv is an ancestor of fu: start from res for fu -> v, then add the
     *     colours met on u -> fu that are neither marked in apl nor counted
     *     between fu and dwv by the prefix rows.
     *  3. otherwise: start from res for fu -> fv, then add the colours met on
     *     u -> fu and v -> fv that do not occur on the fu..fv path.
     * apl[] marks colours already accounted for. Every case clears the marks
     * it set before returning.
     */
    int Query(int u, int v) {
        static bool apl[N]; int ans = 0;
        int fu = find_nrBN(u), fv = find_nrBN(v);
        if (dep[fu] < dep[fv]) swap(u, v), swap(fu, fv);
        if (fu == fv) {
            // Case 1: brute-force walk to the LCA, counting first-seen colours.
            int ans = 0;
            int recu = u, recv = v;
            while (u ^ v) {
                if (dep[u] < dep[v])
                    swap(u, v);
                if (!apl[val[u]]) apl[val[u]] = 1, ++ ans;
                u = fa[u];
            } if (!apl[val[u]]) ++ ans;
            // Clear the marks along both walks (the LCA itself was not marked).
            u = recu, v = recv;
            while (u ^ v) {
                if (dep[u] < dep[v])
                    swap(u, v);
                apl[val[u]] = 0;
                u = fa[u];
            }
            return ans;
        }
        int fw = get_lca(fu, fv);
        if (fv == fw) {
            // Case 2: fv is an ancestor of fu.
            int dwv = dwbn[v];
            int lca = get_lca(u, v);
            if (lca == upbn[v]) dwv = upbn[v];
            int ans = res[id[fu]][v];
            // Mark the colours on v -> lca, dwv -> lca and lca itself.
            int t = v;
            while (t ^ lca) apl[val[t]] = 1, t = fa[t];
            t = dwv;
            while (t ^ lca) apl[val[t]] = 1, t = fa[t];
            apl[val[lca]] = 1;
            // Count colours on u -> fu (fu excluded) that are unmarked and have
            // zero occurrences in sum[fu] - sum[dwv].
            t = u;
            while (t ^ fu) {
                if (!apl[val[t]] && sum[id[fu]][val[t]] - sum[id[dwv]][val[t]] == 0)
                    ++ ans, apl[val[t]] = 1;
                t = fa[t];
            }
            // Undo every mark set above.
            t = v;
            while (t ^ lca) apl[val[t]] = 0, t = fa[t];
            t = dwv;
            while (t ^ lca) apl[val[t]] = 0, t = fa[t];
            apl[val[lca]] = 0;
            t = u;
            while (t ^ fu) apl[val[t]] = 0, t = fa[t];
            return ans;
        }
        // Case 3: fu and fv lie on different branches below fw. The prefix
        // rows cover the fu..fv path except fw, whose colour is marked here.
        apl[val[fw]] = 1;
        ans = res[id[fu]][fv];
        int recu = u, recv = v;
        // Walk u -> fu, then (after the swap) v -> fv.
        rep(i, 0, 1) {
            while (u != fu) {
                if (!apl[val[u]] && !(sum[id[fu]][val[u]] + sum[id[fv]][val[u]] - 2 * sum[id[fw]][val[u]]))
                    apl[val[u]] = 1, ++ ans;
                u = fa[u];
            } swap(u, v); swap(fu, fv);
        }
        // Clear the marks along both walks and the one on fw's colour.
        u = recu, v = recv;
        rep(i, 0, 1) {
            while (u != fu) apl[val[u]] = 0, u = fa[u];
            swap(u, v); swap(fu, fv);
        } apl[val[fw]] = 0;
        return ans;
    }
}tcl;
void ERoRain() {
n = read(), Q = read();
rep(i, 1, n) val[i] = read();
Dscrt(); //rep(i, 1, n) printf ("%d ", val[i]); puts("val");
rep(i, 1, n - 1) {
int u = read(), v = read();
add_edge(u, v);
} dfs_prew(1);
tcl.Build();
int lstans = 0;
while (Q --) {
int u = read(), v = read();
u ^= lstans;
write(lstans = tcl.Query(u, v)), puts("");
// fprintf (stderr, "getans : %d\n", lstans);
}
}
signed main() {
ld start_time = clock();
int T = 1;
while (T --) ERoRain();
flush();
fprintf(stderr, "Time : %Lfs\n", (clock() - start_time) / CLOCKS_PER_SEC);
return 0;
}
} // namespace Main
/**
 * Process entry point.
 * Redirecting stdin/stdout to a.in / a.out is a local-testing convenience, so
 * it is only compiled in when LOCAL is defined. Done unconditionally, it makes
 * the program ignore the judge's stdin/stdout. A failed redirection exits
 * with status 1 instead of reading from a closed stream.
 */
signed main() {
#ifdef LOCAL
    if (!freopen("a.in", "r", stdin)) return 1;
    if (!freopen("a.out", "w", stdout)) return 1;
#endif
    Main::main();
    return 0;
}