题解 P4004 【Hello world!】
这题初一看没有什么思路,后来仔细思考了一下,发现可以乱搞!
就是在lca上乱搞!考虑一种分块!我们记k为跳的步数,按k的余数分情况考虑
然后均摊一下即可!我也不造数据为什么这么强
// Returns the node the walk moves to after standing on x when climbing
// towards the root in strides of k.
//  - k > m  : the stride is long, so simply lift x by k.
//  - k <= m : start from find(fa(x)), the representative of x's parent in
//             the f[] structure, then lift further so that the total
//             distance climbed from x is a multiple of k.
int solve(int x, int k) {
    if (k > m) {
        return up(x, k);
    }
    const int anchor = find(fa(x));
    const int climbed = depth[x] - depth[anchor];
    // Extra steps needed to round `climbed` up to the next multiple of k.
    const int extra = (k - climbed % k) % k;
    return up(anchor, extra);
}
// Returns the node on the x..y path that lies k edges away from y, where f
// is lca(x, y). If y's own branch is at least k edges long the node is an
// ancestor of y; otherwise it is reached by lifting x.
int solve2(int x, int y, int f, int k) {
    const int yBranch = depth[y] - depth[f];
    if (yBranch >= k) {
        return up(y, k);
    }
    const int pathLen = depth[x] + depth[y] - 2 * depth[f];
    return up(x, pathLen - k);
}
这里是重点!如果大于k剪枝!
然后分块处理询问!
// Applies upd() to the stops of a walk from x to y with stride k.
// If the path length is not a multiple of k, y itself is updated first and
// then replaced by the node (path length mod k) edges before it, so that the
// remaining path length is a multiple of k.
void update(int x, int y, int k) {
    int top = lca(x, y);
    const int pathLen = depth[x] + depth[y] - 2 * depth[top];
    const int tail = pathLen % k;
    if (tail != 0) {
        upd(y);
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }
    // x's branch handles the lca itself (>=); y's branch stops above it (>).
    for (; depth[x] >= depth[top]; x = solve(x, k)) {
        upd(x);
    }
    for (; depth[y] > depth[top]; y = solve(y, k)) {
        upd(y);
    }
}
// Computes the sum of a[] over the stops of a walk from x to y with stride k.
// Every stop is first counted as 1 through the "length / k + 1" term; the
// stops the loops actually land on then contribute their remaining a[] - 1.
// If the path length is not a multiple of k, a[y] is added separately and y
// is replaced by the node (path length mod k) edges before it.
ll query(int x, int y, int k) {
    int top = lca(x, y);
    const int pathLen = depth[x] + depth[y] - 2 * depth[top];
    const int tail = pathLen % k;
    ll total = 0;
    if (tail != 0) {
        total += a[y];
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }
    total += (depth[x] + depth[y] - 2 * depth[top]) / k + 1;
    // x's branch handles the lca itself (>=); y's branch stops above it (>).
    for (; depth[x] >= depth[top]; x = solve(x, k)) {
        total += a[x] - 1;
    }
    for (; depth[y] > depth[top]; y = solve(y, k)) {
        total += a[y] - 1;
    }
    return total;
}
这样还是可以均摊!
只是一个按k均摊,一个按sqrt(n)均摊!
然后我们只需求下lca
同时要会用二分(倍增)思想求up(x,k)
就切了,这么简单吧!
代码:
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for node values and query sums.
typedef long long ll;
// Capacity of the per-node arrays.
const int N = 50005;
// Capacity of the edge arrays: two directed entries per undirected edge.
const int M = N * 2;
// Width of the per-node jump table nx[][].
const int K = 255;
// Reads the next decimal integer from stdin.
// Skips every non-digit character before the number; if a '-' is seen while
// skipping, the result is negated. Returns 0 when the input ends before any
// digit is found.
long long rd() {
    // Keep getchar()'s int result: narrowing it to char loses EOF (the skip
    // loop would then never terminate at end of input) and may hand isdigit
    // a negative value, which is undefined behaviour.
    int c = getchar();
    long long sign = 1;
    while (c != EOF && !isdigit(c)) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    long long value = 0;
    while (c != EOF && isdigit(c)) {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// n: number of nodes; m: stride threshold (set to sqrt(n) in main);
// q: number of remaining operations.
int n, m, q;
// Adjacency lists: fir[u] is the first edge id of u, nxt[e] the next edge id
// in the same list (0 terminates it), to[e] the endpoint of edge e.
int ecnt = 0;
int fir[N];
int to[M];
int nxt[M];
// fa[u][i]: ancestor of u that is 2^i edges above it (0 above the root).
int fa[N][16];
// nx[u][i]: ancestor of u that is exactly i edges above it, for i <= m.
int nx[N][K];
// depth[u]: depth of u, with the root at depth 1 and the sentinel 0 at depth 0.
int depth[N];
// f[u]: parent pointer of the disjoint-set structure walked by find().
int f[N];
// a[u]: current value stored on node u.
ll a[N];
// Shorthand for the direct parent of x.
#define fa(x) fa[x][0]
// Adds the directed edge u -> v at the front of u's adjacency list.
void ae(int u, int v) {
    const int id = ++ecnt;
    to[id] = v;
    nxt[id] = fir[u];
    fir[u] = id;
}
// Depth-first traversal from u (whose parent is f) that fills, for every
// node in u's subtree: depth[], the binary-lifting table fa[][] and the
// exact-step ancestor table nx[][] for steps 0..m.
void dfs(int u, int f) {
    depth[u] = depth[f] + 1;
    fa[u][0] = f;
    nx[u][0] = u;
    nx[u][1] = f;
    for (int j = 1; j <= 15; ++j) {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
    // The s-th ancestor of u is the (s-1)-th ancestor of its parent.
    for (int s = 2; s <= m; ++s) {
        nx[u][s] = nx[f][s - 1];
    }
    for (int e = fir[u]; e != 0; e = nxt[e]) {
        const int child = to[e];
        if (child == f) continue;  // do not walk back up the tree
        dfs(child, u);
    }
}
// Returns the ancestor k edges above x. Small k (<= m) is a direct lookup in
// nx[][]; larger k is decomposed into powers of two over fa[][].
int up(int x, int k) {
    if (k <= m) {
        return nx[x][k];
    }
    for (int bit = 15; bit >= 0; --bit) {
        const int span = 1 << bit;
        if (k >= span) {
            x = fa[x][bit];
            k -= span;
        }
    }
    return x;
}
// Lowest common ancestor of u and v via binary lifting over fa[][].
int lca(int u, int v) {
    // Make u the deeper endpoint, then lift it to v's depth.
    if (depth[u] < depth[v]) swap(u, v);
    for (int j = 15; j >= 0; --j) {
        const int cand = fa[u][j];
        if (depth[cand] >= depth[v]) u = cand;
    }
    if (u == v) return u;
    // Lift both endpoints while their ancestors still differ.
    for (int j = 15; j >= 0; --j) {
        if (fa[u][j] == fa[v][j]) continue;
        u = fa[u][j];
        v = fa[v][j];
    }
    return fa[u][0];
}
// Returns the representative of u in the f[] disjoint-set structure and
// points every node visited on the way directly at that representative.
int find(int u) {
    int root = u;
    while (f[root] != root) {
        root = f[root];
    }
    // Second pass: path compression.
    while (u != root) {
        const int next = f[u];
        f[u] = root;
        u = next;
    }
    return root;
}
void upd(int x) {
if (a[x] == 1) return;
a[x] = sqrt(a[x]);
if (a[x] == 1) f[x] = find(fa(x));
}
// Returns the node the walk moves to after standing on x when climbing
// towards the root in strides of k.
//  - k > m  : the stride is long, so simply lift x by k.
//  - k <= m : start from find(fa(x)), the representative of x's parent in
//             the f[] structure, then lift further so that the total
//             distance climbed from x is a multiple of k.
int solve(int x, int k) {
    if (k > m) {
        return up(x, k);
    }
    const int anchor = find(fa(x));
    const int climbed = depth[x] - depth[anchor];
    // Extra steps needed to round `climbed` up to the next multiple of k.
    const int extra = (k - climbed % k) % k;
    return up(anchor, extra);
}
// Returns the node on the x..y path that lies k edges away from y, where f
// is lca(x, y). If y's own branch is at least k edges long the node is an
// ancestor of y; otherwise it is reached by lifting x.
int solve2(int x, int y, int f, int k) {
    const int yBranch = depth[y] - depth[f];
    if (yBranch >= k) {
        return up(y, k);
    }
    const int pathLen = depth[x] + depth[y] - 2 * depth[f];
    return up(x, pathLen - k);
}
// Applies upd() to the stops of a walk from x to y with stride k.
// If the path length is not a multiple of k, y itself is updated first and
// then replaced by the node (path length mod k) edges before it, so that the
// remaining path length is a multiple of k.
void update(int x, int y, int k) {
    int top = lca(x, y);
    const int pathLen = depth[x] + depth[y] - 2 * depth[top];
    const int tail = pathLen % k;
    if (tail != 0) {
        upd(y);
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }
    // x's branch handles the lca itself (>=); y's branch stops above it (>).
    for (; depth[x] >= depth[top]; x = solve(x, k)) {
        upd(x);
    }
    for (; depth[y] > depth[top]; y = solve(y, k)) {
        upd(y);
    }
}
// Computes the sum of a[] over the stops of a walk from x to y with stride k.
// Every stop is first counted as 1 through the "length / k + 1" term; the
// stops the loops actually land on then contribute their remaining a[] - 1.
// If the path length is not a multiple of k, a[y] is added separately and y
// is replaced by the node (path length mod k) edges before it.
ll query(int x, int y, int k) {
    int top = lca(x, y);
    const int pathLen = depth[x] + depth[y] - 2 * depth[top];
    const int tail = pathLen % k;
    ll total = 0;
    if (tail != 0) {
        total += a[y];
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }
    total += (depth[x] + depth[y] - 2 * depth[top]) / k + 1;
    // x's branch handles the lca itself (>=); y's branch stops above it (>).
    for (; depth[x] >= depth[top]; x = solve(x, k)) {
        total += a[x] - 1;
    }
    for (; depth[y] > depth[top]; y = solve(y, k)) {
        total += a[y] - 1;
    }
    return total;
}
int main() {
int i; n = rd(); m = sqrt(n);
for (i = 1; i <= n; ++i) a[i] = rd();
for (i = 1; i < n; ++i) {
int u = rd(), v = rd();
ae(u, v); ae(v, u);
}
dfs(1, 0);
for (i = 1; i <= n; ++i) {
f[i] = i;
if (a[i] == 1) f[i] = fa(i);
}
q = rd();
while (q--) {
int opt = rd(), x = rd(), y = rd(), k = rd();
if (opt) printf("%lld\n", query(x, y, k));
else update(x, y, k);
}
return 0;
}