题解:P12480 [集训队互测 2024] Classical Counting Problem

· · 题解

[Luogu P12480]/[QOJ9533] Classical Counting Problem

pro.

- 选择当前树上编号最大或最小的点 $u$,删去 $u$ 及其连边,保留任意一个连通块作为操作后的树。 令 $min$ 为树上所有节点编号的最小值,$max$ 为树上所有节点编号的最大值,$size$ 为树上的节点个数,则一棵树的权值为 $min\cdot max\cdot size$。 求所有能通过操作得到的非空树的权值和。 $n\le1e5$。$\mathrm{3s,1024MB}$。 ### sol. 读完题后发现没有明显的多项式做法,考虑寻找性质。 ~~不很~~容易发现:一对合法的 $(min,max)$ 可以确定唯一的一棵树。 证明:一对点 $(u,v)$ 能作为一棵合法树的 $(min,max)$ 当且仅当 $u$ 到 $v$ 的路径上的所有点都在 $[u,v]$ 区间内,然后在这条路径上不断加入在 $[u,v]$ 区间内且与当前联通块联通的点,即可得到 $(u,v)$ 对应的树 $T$,显然树 $T$ 的形态与加点顺序无关,故每个合法 $(u,v)$ 对应唯一的一棵树。而对于一棵树 $T$ ,设其最小值和最大值分别为 $(u,v)$,由于只对最大值或最小值操作,故再进行操作显然不对应 $(u,v)$ 。故合法 $(min,max)$ 与合法的树一一对应。 于是可以枚举 $min,max$ ,然后加入 $[min,max]$ 区间内的点,判断是否联通并统计答案。暴力实现 $\mathcal{O}(n^3)$ ,固定 $min$ 或 $max$ 再依次加入点,用并查集维护连通性,即可实现 $\mathcal{O}(n^2)$ 。 考虑优化。注意到一对合法 $(min, max)$ 的判断用到了路径信息,于是想到点分治。对于一个分治中心,记录每个点到其路径上的最大值 $max_u$ 和最小值 $min_u$,则对于 $u=min_u\land v=max_v\land min_v\ge u\land max_u\le v$ 的 $(u,v)$ 即为合法对。如果没有 $size$ 可以二维数点维护,难点在于怎么处理 $size$ 这一项。 **Trick**:对于难处理的有限制联通块大小,可以考虑拆贡献,即维护有多少个 $x$ 能满足这样的限制。 于是可以转化为在每个分治中心下统计有多少个 $(l,r,x)$ 可以在同一个联通块里,对答案的贡献即为 $l\cdot r$。 显然合法 $l,r$ 的限制依然成立,而 $x$ 需要满足的限制应该为 $min_x\ge l\land max_x\le r$。 即一对合法 $(l,r,x)$ 应满足: $$ \begin{cases} min_l=l \\ max_r=r \\ min_l\le min_x \\ min_l\le min_r \\ max_x\le max_r \\ max_l\le max_r \end{cases} $$ 观察到关键限制都是 $min$ 或 $max$ 之间的偏序关系,于是不妨设每个点的坐标为 $(min_x,max_x)$ ,发现合法三元组在二维平面上应满足:![](https://cdn.luogu.com.cn/upload/image_hosting/xd8xhfv7.png) 考虑对 $max$ 一维扫描线,线段树维护另一维。 于是当扫到某个 $r$ 时,能与 $r$ 匹配的 $(l,x)$ 一定已经加入了线段树。 而对于某个 $r$ ,能与它产生贡献的 $l$ 一定在 $min_r$ 左侧。设 $l$ 能与 $cnt_l$ 个 $x$ 匹配,则答案为 $l\cdot cnt_l\cdot r$ 。 所以线段树需要维护区间 $\sum l\cdot cnt_l$ 。 我们设一段区间内可以作为 $l$ 的点的编号和为标准和 $std$,要求的 $\sum l\cdot cnt_l$ 为结果和 $sum$ 。 则插入一个 $l$ 时,$min_l=l$ 处的标准和要增加 $l$ ,能与该点匹配的 $cnt$ 不变,结果和要增加 $l\cdot cnt_l$;插入一个 $x$ 时,对于 $min_x$ 左侧的区间,标准和不变,能与区间中每处匹配的 $cnt$ 增加 $1$,结果和要增加一个区间对应的标准和。 于是我们线段树中维护 $std,cnt,sum,add$,分别为区间标准和,能与这段区间匹配的 $x$ 数量,区间结果和,(懒标记)区间加了多少次标准和,支持 $std$ 的单点加,$cnt$ 的区间加,$sum$ 的区间修改($cnt$ 及所谓区间加的定义其实并不严格,因为一段区间每个位置能匹配的 $x$ 的数量不尽相同,但由于 $cnt$ 只在单点修改 $l$ 时由于需要补上之前加入的 $x$ 的贡献才使用,而一个单点能匹配的 $x$ 的数量是一定的,即 $cnt$ 只需要下传,所以 $cnt$ 才可以简单当作区间加处理~~定义是不知道该怎么描述~~)。 复杂度分析可以考虑点分治最坏情况下的形式就是一条链,形态类似线段树,相当于对线段树上每个点开一棵线段树,即树套树,故复杂度 $\mathcal{O}(n\log^2n)$ 。 一些实现细节: - 由于 $(l,r,x)$ 可以任意相等,所以可以作为 $l,r$ 的也可以作为 $x$ 。因此在同一高度的点的操作顺序应该是先插入 $l$ ,再插入 $x$ ,最后查询 $r$ 。 - 子树去重时不能更新 $min$ 和 $max$ ,仍然要保留原分治中心下的 $min$ 和 $max$ 数值,否则起到的不是去重效果。 - 对于每个分治中心需要离散化,不能直接扫描线 $[1,~n]$ ,否则复杂度会退化至 $\mathcal{O}(n^2\log^2n)$ 。 ### cod. ```cpp #include <bits/stdc++.h> #define file(name, suf) ""#name"."#suf"" #define input(name) freopen(file(name, in), "r", stdin) #define output(name) freopen(file(name, out), "w", stdout) #define map(type, x) static_cast<type>(x) typedef unsigned int uint; constexpr int N = 1e5 + 10; int n, siz[N], son_siz[N], id[N], min[N], max[N]; std::vector<int> e[N], node; bool arr[N]; uint ans; struct Seg_Tree { struct Node { uint sum, std, cnt, add, clr; } t[N << 2]; #define ls (u << 1) #define rs (u << 1 | 1) #define mid ((l + r) >> 1) void up(int u) { t[u].sum = t[ls].sum + t[rs].sum; t[u].std = t[ls].std + t[rs].std; } void build(int u, int l, int r) { t[u] = {0, 0, 0, 0, false}; if (l == r) return; build(ls, l, mid), build(rs, mid + 1, r); } void add(int u, uint x) { t[u].sum += t[u].std * x, t[u].cnt += x, t[u].add += x; } void down(int u) { if (t[u].add) add(ls, t[u].add), add(rs, t[u].add), t[u].add = 0; } void insert(int u, int l, int r, int k, uint x) { if (l == r) return t[u].std += x, t[u].sum += t[u].cnt * x, void(); down(u); if (k <= mid) insert(ls, l, mid, k, x); else insert(rs, mid + 1, r, k, x); up(u); } void add(int u, int l, int r, int ql, int qr) { if (l > qr || r < ql) return; if (l >= ql && r <= qr) return add(u, 1); down(u), add(ls, l, mid, ql, qr), add(rs, mid + 1, r, ql, qr), up(u); } uint query(int u, int l, int r, int ql, int qr) { if (l > qr || r < ql) return 0; if (l >= ql && r <= qr) return t[u].sum; down(u); return query(ls, l, mid, ql, qr) + query(rs, mid + 1, r, ql, qr); } } T; int get_core(int u, int f, int all) { int core = son_siz[u] = (siz[u] = 1) - 1; for (const int& v : e[u]) if (v != f && !arr[v]) { int res = get_core(v, u, all); siz[u] += siz[v], son_siz[u] = std::max(son_siz[u], siz[v]); core = !core || son_siz[res] < son_siz[core] ? res : core; } son_siz[u] = std::max(son_siz[u], all - siz[u]); return !core || son_siz[u] < son_siz[core] ? u : core; } void dfs(int u, int f) { node.push_back(u), siz[u] = 1, min[u] = std::min(min[f], u), max[u] = std::max(max[f], u); for (const int& v : e[u]) if (v != f && !arr[v]) dfs(v, u), siz[u] += siz[v]; } void reput(int u, int f) { node.push_back(u); for (const int& v : e[u]) if (v != f && !arr[v]) dfs(v, u); } void erase(int u) { reput(u, 0), std::sort(node.begin(), node.end()); int all = node.size(); T.build(1, 1, all); for (const int& x : node) id[x] = std::lower_bound(node.begin(), node.end(), x) - node.begin() + 1; std::sort(node.begin(), node.end(), [](const int& a, const int& b) { return max[a] < max[b];}); for (int i = 0, j; i < node.size(); i = j) { for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) if (min[node[j]] == node[j]) T.insert(1, 1, all, id[node[j]], node[j]); for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) T.add(1, 1, all, 1, id[min[node[j]]]); for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) if (max[node[j]] == node[j]) ans -= T.query(1, 1, all, 1, id[min[node[j]]]) * node[j]; } } void sol(int u) { arr[u] = true, dfs(u, 0); std::sort(node.begin(), node.end()); int all = node.size(); T.build(1, 1, all); for (const int& x : node) id[x] = std::lower_bound(node.begin(), node.end(), x) - node.begin() + 1; std::sort(node.begin(), node.end(), [](const int& a, const int& b) { return max[a] < max[b];}); for (int i = 0, j; i < node.size(); i = j) { for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) if (min[node[j]] == node[j]) T.insert(1, 1, all, id[node[j]], node[j]); for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) T.add(1, 1, all, 1, id[min[node[j]]]); for (j = i; j < node.size() && max[node[j]] == max[node[i]]; j++) if (max[node[j]] == node[j]) ans += T.query(1, 1, all, 1, id[min[node[j]]]) * node[j]; } for (const int& v : e[u]) if (!arr[v]) node.clear(), erase(v); node.clear(); for (const int& v : e[u]) if (!arr[v]) sol(get_core(v, 0, siz[v])); } void solve() { std::cin >> n, max[0] = 0, min[0] = INT_MAX, ans = 0; for (int i = 1; i <= n; i++) e[i].clear(), arr[i] = false; for (int i = 1, u, v; i < n; i++) std::cin >> u >> v, e[u].push_back(v), e[v].push_back(u); sol(get_core(1, 0, n)); std::cout << ans << "\n"; } int main() { // input(main), output(main); int _ = 1; std::cin >> _; while (_--) solve(); return 0; } ```