题解:P9706 「TFOI R1」Ride the Wind and Waves
lzyqwq
·
·
题解
这个题 800 吧。
快进到我们需要求出每个点 $u$ 为根的子树内,到 $u$ 边数 $\ge k$(换根还要用到 $k-1$ 的信息)的点的个数以及路径的边权和。
以边权和为例。朴素的 dp 是转化成求 $\le k$ 的信息,由于 $k\le 10$,记录每个值的答案,从子树合并上来,就是 $\mathcal{O}(nk)$。
这样难以优化。考虑每个点 $u$ 能对它哪些祖先 $v$ 的信息产生贡献,显然要求 $u$ 到 $v$ 简单路径的边数 $\ge k$,那么就是 $k$ 级祖先根链上的所有点 $v$,受到 $u$ 的贡献为 $d_u-d_v$。其中 $d_u$ 表示 $u$ 到这棵子树环上的根的路径的边权和。
把贡献写成 $p_vd_v+q_v$ 的形式,相当于根链 $p_v,q_v$ 加。树上差分维护即可。
现在压力给到 $k$ 级祖先,考虑 dfs 的时候维护根链序列 $a$。假设 dfs 到 $u$ 的时候序列的长度为 $m$,那么 $k$ 级祖先就是 $a_{m-k}$。这个点 dfs 结束的时候就从末尾删除。可以发现这个过程就是实现一个栈。
这样时间复杂度就是 $\mathcal{O}(n)$。
因为写的比较省流所以放个代码。
```cpp
#include <bits/stdc++.h>
#define eb emplace_back
using namespace std; typedef long long ll;
const int N = 1000005, K = 12;
int n, k, to[N], rd[N], cir[N << 1], cnt, dep[N], cn;
vector<pair<int, ll>> g[N];
bool vis[N], onc[N];
ll dw[N], ans[N], F[N], tov[N], dis[N << 1], s1[N], s2[N], s3[N], G[N];
ll kf[N], bf[N], bs[N], kg[N], bg[N], bt[N];
int arr[N];
void dfs1(int u) {
arr[++cn] = u;
if (cn >= k) {
--kf[arr[cn - k]]; bf[arr[cn - k]] += dw[u];
--kg[arr[cn - k + 1]]; bg[arr[cn - k + 1]] += dw[u];
++bt[arr[cn - k]]; ++bs[arr[cn - k + 1]];
}
for (auto [v, w] : g[u]) {
if (!onc[v]) {
dep[v] = dep[u] + 1; dw[v] = dw[u] + w; dfs1(v);
}
}
--cn;
}
void dfs3(int u) {
for (auto [v, w] : g[u])
if (!onc[v]) {
dfs3(v);
kf[u] += kf[v]; bf[u] += bf[v];
kg[u] += kg[v]; bg[u] += bg[v];
bs[u] += bs[v]; bt[u] += bt[v];
}
F[u] = kf[u] * dw[u] + bf[u];
G[u] = kg[u] * dw[u] + bg[u];
//cout << u << ' ' << bs[u] << ' ' << bt[u] << '\n';
}
void dp(int u, ll cur, ll sum, ll val) {
ans[u] = cur + val * sum;
for (auto [v, w] : g[u]) {
if (!onc[v]) {
ll cur_ = ans[u];
ll sum_ = sum + F[u] - G[v] - w * bs[v];
dp(v, cur_, sum_, w);
}
}
}
void dfs2(int u, ll x, ll y) {
ans[u] += x + dw[u] * y;
for (auto [v, w] : g[u]) if (!onc[v]) dfs2(v, x, y);
}
int main() {
cin.tie(0); cout.tie(0); ios::sync_with_stdio(0); cin >> n >> k;
for (int i = 1, u, v, w; i <= n; ++i)
cin >> u >> v >> w, ++rd[v], g[v].eb(u, w), to[u] = v, tov[u] = w;
bool flg = 0;
for (int i = 1; i <= n; ++i) {
if (!rd[i]) {
vector<int> tmp;
int u = i;
while (1) {
tmp.eb(u);
if (vis[to[u]]) {
bool fg = 0;
for (int j : tmp) {
if (j == to[u]) fg = 1;
if (fg) cir[++cnt] = j, onc[j] = 1, dis[cnt] = tov[j];
}
break;
} else u = to[u], vis[u] = 1;
}
flg = 1; break;
}
}
// for (int i = 1; i <= cnt; ++i) cout << cir[i] << '\n';
if (!flg) {
for (int i = 1; i <= n; ++i) cout << "0\n"; return 0;
}
// for (int i = 1; i <= cnt; ++i) cout << cir[i] << '\n';
for (int i = 1; i <= cnt; ++i) {
dfs1(cir[i]); dfs3(cir[i]); dp(cir[i], 0, 0, 0);
}
//for (int i = 1; i <= n; ++i) cout << bs[i] << ' ' << bt[i] << '\n';
for (int i = 1; i <= cnt; ++i) cir[i + cnt] = cir[i], dis[i + cnt] = dis[i];
for (int i = cnt * 2; i >= 1; --i)
dis[i] += dis[i + 1],
s1[i] = s1[i + 1] + F[cir[i]],
s2[i] = s2[i + 1] + dis[i] * F[cir[i]];
for (int i = 1; i <= cnt; ++i)
dfs2(cir[i], dis[i] * (s1[i + 1] - s1[i + cnt]) - (s2[i + 1] - s2[i + cnt]), s1[i + 1] - s1[i + cnt]);
for (int i = 1; i <= n; ++i) cout << ans[i] << '\n';
return 0;
}
```