题解 P4004 【Hello world!】

· · 题解

这题初一看没有什么思路,后来仔细思考了一下,发现可以乱搞!

就是在lca上乱搞!考虑一种分块!我们记k为跳的步数,按k的余数分情况考虑。

然后均摊一下即可!我也不造数据为什么这么强

// Returns the next node to visit above x when climbing in strides of k.
// For k > m this is simply the k-th ancestor of x. For k <= m it first jumps
// to y = find(fa(x)) (skipping ancestors whose value is already 1) and then
// climbs from y to the first node at or above y whose distance from x is a
// multiple of k.
int solve(int x, int k) {
    if (k > m) return up(x, k);
    int y = find(fa(x));
    // Extra steps above y needed so that the total distance from x is a multiple of k.
    return up(y, (k - (depth[x] - depth[y]) % k) % k);
}
// Returns the node that lies k edges away from y on the path x -> f -> y,
// where f is the lca of x and y.
int solve2(int x, int y, int f, int k) {
    // The target is still on y's side of the lca: climb straight up from y.
    if (depth[y] - depth[f] >= k) return up(y, k);
    // Otherwise it is on x's side: climb (path length - k) edges up from x.
    return up(x, depth[x] + depth[y] - (depth[f] << 1) - k);
}

这里是重点!如果大于k剪枝!

然后分块处理询问!

// Applies upd() along the path from x to y, visiting nodes in strides of k
// starting at x; y itself is always visited. Nodes whose value is already 1
// may be skipped by solve(), which is harmless because upd() does nothing
// for them.
void update(int x, int y, int k) {
    int f = lca(x, y), len = depth[x] + depth[y] - (depth[f] << 1);
    // Path length is not a multiple of k: handle y separately, then move y to
    // the last node reached by a full stride and recompute the lca.
    if (len % k) upd(y), y = solve2(x, y, f, len % k), f = lca(x, y);
    // Climb from x while it is not above the lca ...
    while (depth[x] >= depth[f]) upd(x), x = solve(x, k);
    // ... and from y while it is strictly below the lca.
    while (depth[y] > depth[f]) upd(y), y = solve(y, k);
}
// Returns the sum of a[] over the nodes visited when walking from x to y in
// strides of k; y itself is always counted.
ll query(int x, int y, int k) {
    int f = lca(x, y), len = depth[x] + depth[y] - (depth[f] << 1);
    ll res = 0;
    // Path length is not a multiple of k: count y separately, then move y to
    // the last node reached by a full stride and recompute the lca.
    if (len % k) res += a[y], y = solve2(x, y, f, len % k), f = lca(x, y);
    // Count 1 for every stride node up front; the loops below add a[v] - 1 for
    // the nodes actually visited, so nodes skipped by solve() (value 1) need
    // no further work.
    res += (depth[x] + depth[y] - (depth[f] << 1)) / k + 1;
    while (depth[x] >= depth[f]) res += a[x] - 1, x = solve(x, k);
    while (depth[y] > depth[f]) res += a[y] - 1, y = solve(y, k);
    return res; 
}

这样还是可以均摊!

只是一个按k均摊,一个按sqrt(n)均摊!

然后我们只需求下lca

同时要会用二分(倍增)的思想求up(x,k)

就切了,这么简单吧! 代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// N sizes the per-node arrays, M (= 2 * N) sizes the directed-edge arrays
// (each undirected edge is stored twice), and K is the second dimension of
// the nx jump table. main() sets m = sqrt(n) and nx is indexed up to m, so m
// must stay below K.
const int N = 50005, M = (N << 1), K = 255;
/// Reads the next decimal integer from stdin.
///
/// Every character before the first digit is skipped; if a '-' occurs among
/// the skipped characters the result is negated. Reading stops at the first
/// non-digit after the number.
///
/// @return the parsed value, or 0 if end of input is reached before any digit.
long long rd() {
    // Keep the character in an int so EOF stays distinguishable and the
    // argument passed to isdigit is always valid.
    int c = std::getchar();
    long long sign = 1;
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    long long value = 0;
    while (c != EOF && std::isdigit(c)) {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}
// n: number of nodes; m: sqrt(n) truncated to int, the threshold between
// "small" and "large" strides; q: number of operations still to process.
// ecnt / fir / to / nxt: adjacency lists (fir[u] is the first edge index of u,
// to[e] the endpoint of edge e, nxt[e] the next edge index, 0 terminates).
// fa[u][i]: 2^i-th ancestor of u; nx[u][i]: i-th ancestor of u for i <= m.
// Node 0 acts as a sentinel above the root with depth 0.
// depth[u]: depth of u, with the root at depth 1.
// f[u]: skip pointer followed by find(); it points to the parent once a[u] is 1.
int n, m, q, ecnt = 0, fir[N], to[M], nxt[M], fa[N][16], nx[N][K], depth[N], f[N];
// a[u]: current value stored at node u.
ll a[N];
// Shorthand for the direct parent of x.
#define fa(x) fa[x][0]
// Inserts the directed edge u -> v at the head of u's adjacency list.
void ae(int u, int v) {
    const int id = ++ecnt;
    to[id] = v;
    nxt[id] = fir[u];
    fir[u] = id;
}
// Depth-first traversal from u (whose parent is f) that fills depth[], the
// power-of-two ancestor table fa[][] and the step-by-step ancestor table
// nx[][0..m] for u and its whole subtree.
void dfs(int u, int f) {
    depth[u] = depth[f] + 1;

    // Power-of-two ancestors: the 2^j-th ancestor is the 2^(j-1)-th ancestor
    // of the 2^(j-1)-th ancestor.
    fa[u][0] = f;
    for (int j = 1; j < 16; ++j) {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }

    // Exact ancestors for every distance up to m, derived from the parent's row.
    nx[u][0] = u;
    nx[u][1] = f;
    for (int step = 2; step <= m; ++step) {
        nx[u][step] = nx[f][step - 1];
    }

    for (int e = fir[u]; e != 0; e = nxt[e]) {
        if (to[e] == f) continue;  // do not walk back to the parent
        dfs(to[e], u);
    }
}
// Returns the k-th ancestor of x (node 0 once the walk passes the root).
// Distances up to m are answered from the nx table; larger ones are split
// greedily into powers of two and followed through fa.
int up(int x, int k) {
    if (k <= m) {
        return nx[x][k];
    }
    for (int bit = 15; bit >= 0; --bit) {
        const int span = 1 << bit;
        if (k >= span) {
            x = fa[x][bit];
            k -= span;
        }
    }
    return x;
}
// Lowest common ancestor of u and v, computed with the fa jump table.
int lca(int u, int v) {
    // Make u the deeper endpoint.
    if (depth[u] < depth[v]) swap(u, v);

    // Raise u until it is level with v.
    for (int j = 15; j >= 0; --j) {
        const int cand = fa[u][j];
        if (depth[cand] >= depth[v]) u = cand;
    }
    if (u == v) return u;

    // Raise both in lock-step while their ancestors still differ.
    for (int j = 15; j >= 0; --j) {
        if (fa[u][j] == fa[v][j]) continue;
        u = fa[u][j];
        v = fa[v][j];
    }
    return fa[u][0];
}
// Follows the f[] skip pointers from u to the first node that points to
// itself, compressing the traversed chain on the way back.
int find(int u) {
    if (f[u] == u) {
        return u;
    }
    const int root = find(f[u]);
    f[u] = root;
    return root;
}
// Replaces a[x] by the floor of its square root. Once the value reaches 1 the
// node is linked to the skip chain (f[x] points past it) so later walks can
// jump over it.
//
// Values <= 1 are left untouched: 1 is already a fixed point, 0 maps to
// itself, and a negative value has no real square root.
void upd(int x) {
    if (a[x] <= 1) return;

    // The floating-point root is only an estimate; the two loops below correct
    // it so that r * r <= a[x] < (r + 1) * (r + 1) holds exactly. Divisions
    // are used instead of products to stay clear of overflow.
    long long r = static_cast<long long>(std::sqrt(static_cast<long double>(a[x])));
    while (r > 0 && r > a[x] / r) --r;
    while (r + 1 <= a[x] / (r + 1)) ++r;

    a[x] = r;
    if (r == 1) f[x] = find(fa(x));
}
// Next node to visit above x when climbing in strides of k.
// Large strides (k > m) use a plain k-step jump. Small strides first skip to
// the nearest ancestor that find() does not bypass and then climb to the
// first node at or above it whose distance from x is a multiple of k.
int solve(int x, int k) {
    if (k <= m) {
        const int anchor = find(fa(x));
        const int covered = (depth[x] - depth[anchor]) % k;
        return up(anchor, (k - covered) % k);
    }
    return up(x, k);
}
// Node located k edges away from y on the path x -> f -> y, where f is the
// lca of x and y.
int solve2(int x, int y, int f, int k) {
    const int ySide = depth[y] - depth[f];
    if (ySide < k) {
        // The target lies on x's branch: measure it from x instead.
        const int pathLen = depth[x] + depth[y] - 2 * depth[f];
        return up(x, pathLen - k);
    }
    return up(y, k);
}
// Applies upd() along the path from x to y in strides of k, starting at x;
// y is always processed even when the final stride is shorter than k.
void update(int x, int y, int k) {
    int top = lca(x, y);
    const int len = depth[x] + depth[y] - 2 * depth[top];
    const int tail = len % k;

    if (tail != 0) {
        // Handle y on its own and move it to the end of the last full stride.
        upd(y);
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }

    // x's branch, while x is not above the lca.
    for (; depth[x] >= depth[top]; x = solve(x, k)) {
        upd(x);
    }
    // y's branch, strictly below the lca.
    for (; depth[y] > depth[top]; y = solve(y, k)) {
        upd(y);
    }
}
// Sum of a[] over the nodes visited when walking from x to y in strides of k;
// y is always included even when the final stride is shorter than k.
ll query(int x, int y, int k) {
    int top = lca(x, y);
    const int tail = (depth[x] + depth[y] - 2 * depth[top]) % k;
    ll total = 0;

    if (tail != 0) {
        // Count y on its own and move it to the end of the last full stride.
        total += a[y];
        y = solve2(x, y, top, tail);
        top = lca(x, y);
    }

    // Every stride node contributes at least 1; the loops add the surplus
    // a[v] - 1 for the nodes that are actually visited.
    total += (depth[x] + depth[y] - 2 * depth[top]) / k + 1;

    while (depth[x] >= depth[top]) {
        total += a[x] - 1;
        x = solve(x, k);
    }
    while (depth[y] > depth[top]) {
        total += a[y] - 1;
        y = solve(y, k);
    }
    return total;
}
int main() {
    int i; n = rd(); m = sqrt(n);
    for (i = 1; i <= n; ++i) a[i] = rd();
    for (i = 1; i < n; ++i) {
        int u = rd(), v = rd();
        ae(u, v); ae(v, u);
    }
    dfs(1, 0);
    for (i = 1; i <= n; ++i) {
        f[i] = i;
        if (a[i] == 1) f[i] = fa(i);
    }
    q = rd();
    while (q--) {
        int opt = rd(), x = rd(), y = rd(), k = rd();
        if (opt) printf("%lld\n", query(x, y, k));
        else update(x, y, k);
    }
    return 0;
}