题解:CF1797F Li Hua and Path

· · 题解

CF1797F Li Hua and Path

前言

模拟赛赛时打了一半,拿了 q=0 的。2e5不就是给双 log 过的吗

正文

首先考虑解决 q=0 之后怎么做。添加一个点,新的贡献也一定是由这个点贡献的,并且是作为最大值(新点编号一定比原来的大)。

考虑求不合法方案数。如果 fa>n,那么新的不合法方案就是 fa 被添加时产生新的不合法方案 +1。设 f_u 表示点 u 被添加时产生新的不合法方案,则可以表示为 f_u=f_{fa}+1

如果 fa\le n,那么新的不合法方案就是在原树上满足 v=\min(\text{Road(fa,v)})v 的个数。

那为什么只用算原树上的点 v?因为新加的点编号一定比原树上的大,不可能在穿越原树后作为最小值。

而加入点 u 后产生的新贡献就是 n+u-1-f_u

前者可以 O(1) 计算,而后者可以每次 O(n) 计算,于是得到了 O(qn) 的做法。

接下来考虑如何优化以及处理 q=0

考虑点分治。设 w_1 表示其中一端为最小值的方案数,w_2 表示其中一端为最大值的方案数,w 表示一端为最小值一端为最大值的方案数。而两端至少有一个为最小值/最大值就等于 w_1+w_2-w(简单容斥),减去两端都合法的就是 w_1+w_2-2w

mn_u 表示从分治中心到 u 的最小点编号,mx_u 表示从分治中心到 u 的最大点编号。则 w_1 即为满足 mn_u=u,mn_u<mn_vu,v 在不同子树内的点对 (u,v) 的数量,就是一维偏序减整体减去子树内贡献,直接排序很好做。w_2 即为满足 mx_u=u,mx_u>mx_vu,v 在不同子树内的点对 (u,v) 的数量,与上面类似。重点即为如何求 w

然后就是求 $f_u$。$v$ 对 $u$ 有贡献当且仅当 $mn_u>mn_v,mn_v=v$。这依旧是一个一维偏序,和之前的排序一起做即可。 注意事项: - $mn_u,mx_u$ 在分治中心处理时可能出错,建议单独拿出来处理。 - 存点集的 vector 要记得清空。 ## AC Code ```cpp #include <bits/stdc++.h> using namespace std; const int N = 5e5 + 5 , inf = 1e6; int n; vector <int> g[N]; bool vis[N]; int CNT; namespace Find_Root { int sz , s[N] , w[N] , rt; inline void dfs (int u , int fa = 0) { if (vis[u]) { s[u] = w[u] = 0; return ; } s[u] = 1 , w[u] = 0; for (int v : g[u]) if (v != fa) { dfs (v , u); s[u] += s[v]; w[u] = max (w[u] , s[v]); } w[u] = max (w[u] , sz - s[u]); if (w[u] <= sz / 2) rt = u; } } struct BIT { int tr[N] = {0}; vector <int> S; inline void upd (int p , int x) { p = n - p + 1; S.push_back (p); while (p <= n) tr[p] += x , p += p & -p; } inline int qry (int p) { p = n - p + 1; int ans = 0; while (p) ans += tr[p] , p -= p & -p; return ans; } inline void clr (int p) { while (tr[p] && p <= n) tr[p] = 0 , p += p & -p; } inline void clr () { for (int p : S) clr (p); S.clear(); S.shrink_to_fit(); } } T; long long ans; int f[N] , F[N] , cnt[N] , top[N]; vector <int> S , S2[N]; inline void add (int u , int mn , int mx , int fa = 0) { if (vis[u]) return ; if (fa) top[u] = (top[fa] ? top[fa] : u); f[u] = mx , F[u] = mn; S.push_back (u); for (int v : g[u]) if (v != fa) add (v , min (mn , v) , max (mx , v) , u); } inline bool cmp1 (int x , int y) { if (F[x] == F[y]) return x > y; return F[x] > F[y]; } inline bool cmp2 (int x , int y) { if (f[x] == f[y]) return x < y; return f[x] < f[y]; } long long FMin[N << 1]; inline void divide (int u , int sz) { if (sz == 1) return ; Find_Root::sz = sz; S.reserve (sz); Find_Root::dfs (u); u = Find_Root::rt; top[u] = 0; add (u , u , u); sort (S.begin () , S.end () , cmp1); int l = 0 , cc = 0; for (int p : S) { if (p != u) S2[top[p]].push_back (p); while (F[S[l]] > F[p]) { if (f[S[l]] == S[l] && u != S[l]) T.upd (S[l] , 1); l ++; } if (F[p] == p && u != p) ans -= T.qry (f[p] + 1) << 1; } cc = S.size (); l = cc - 1; int Cnt = 0; for (int i = cc - 1;i >= 0;i --) if (S[i] != u) { while (F[S[l]] < F[S[i]]) Cnt += (F[S[l]] == S[l] && S[l] != u) , l --; FMin[S[i]] += Cnt; } for (int v : g[u]) if (!vis[v]) { T.clr (); l = 0; for (int p : S2[v]) { while (F[S2[v][l]] > F[p]) { if (f[S2[v][l]] == S2[v][l]) T.upd (S2[v][l] , 1); l ++; } if (F[p] == p && v != p) ans += T.qry (f[p] + 1) << 1; } cc = S2[v].size (); l = cc - 1; Cnt = 0; for (int i = cc - 1;i >= 0;i --) { while (F[S2[v][l]] < F[S2[v][i]]) Cnt += (F[S2[v][l]] == S2[v][l]) , l --; FMin[S2[v][i]] -= Cnt; } S2[v].clear (); S2[v].shrink_to_fit(); } for (int p : S) if (u != p) { if (F[p] == p) FMin[u] ++; if (F[p] == u) FMin[p] ++; if (f[p] == u && F[p] == p) ans -= 2; if (F[p] == u && f[p] == p) ans -= 2; } l = 0 , cc = 0; for (int p : S) { while (F[S[l]] > F[p]) cnt[top[S[l]]] ++ , l ++; if (p == u) ans += cc; else if (F[p] == p) ans += l - cnt[top[p]]; cc ++; } for (int p : S) cnt[top[p]] = 0; sort (S.begin () , S.end () , cmp2); l = 0 , cc = 0; for (int p : S) { while (f[S[l]] < f[p]) cnt[top[S[l]]] ++ , l ++; if (p == u) ans += cc; else if (f[p] == p) ans += l - cnt[top[p]]; cc ++; } for (int p : S) cnt[top[p]] = 0; S.clear (); S.shrink_to_fit(); T.clr (); vis[u] = 1; for (int v : g[u]) if (!vis[v]) divide (v , sz - Find_Root::w[v]); } signed main () { ios::sync_with_stdio (0); cin.tie (0) , cout.tie (0); cin >> n; for (int i = 1;i < n;i ++) { int u , v; cin >> u >> v; g[u].push_back (v); g[v].push_back (u); } divide (1 , n); for (int i = 1;i <= n;i ++) FMin[i] ++; cout << ans << '\n'; int q; cin >> q; for (int i = 1;i <= q;i ++) { int u; cin >> u; long long w = ans; if (u > n) ans -= FMin[u] + 1; else ans -= FMin[u]; FMin[i + n] = w - ans; ans += i + n - 1; cout << ans << '\n'; } return 0; } ```