题解 CF1458F 【Range Diameter Sum】

command_block

2021-01-14 08:33:25

Solution

**题意** : 给出一棵 $n$ 个点的树,边权均为 $1$。 设 $len(S)$ 为点集 $S$ 的直径长度。 求 $\sum\limits_{i=1}^n\sum\limits_{j=i}^nlen([i,j])$ $n\leq 10^5$ ,时限$\texttt{7s}$。 ------------ 很有意思的一个题。 树上邻域(圆)理论。 定义 $f(u,r)=\{v|dis(u,v)\leq r\}=\{\text{距离u不超过r的节点}\}$ ,即邻域。其中 $r$ 称为半径。 - **性质①** 对于一个点集 $S$ ,其所有直径的中点是重合的。(可能在一条边的中间) 若中点不重合则显然可以构造新的直径。 设点集 $S$ 的直径中点为 $mid(S)$ ,直径长度为 $len(S)$。 根据这条性质,定义点集 $S$ 的覆盖邻域 $c(S)=f(mid(S),len(S)/2)$。 显然,这是能覆盖 $S$ 的半径最小的邻域,类似于最小覆盖圆。 - **性质②** $f(u,r)\supseteq S\Rightarrow f(u,r)\supseteq c(S)$。 圆的形式 : $⊙C\supseteq S\Rightarrow ⊙C\supseteq S$ 的最小覆盖圆。 **证明** : 显然有 $dis(mid(S),u)\leq r-len(S)/2$ ,否则 $f(u,r)$ 不可能包含 $S$ 的直径。 也就是说,从 $mid(S)$ 开始还能走 $len(S)/2$ 步,自然就覆盖了 $c(S)$。 - **性质③** 设 $td(S,v)$ 为 $v$ 到 $S$ 内最远点的距离。 有 $td(S,v)=dis(mid(S),v)+len(S)/2$。 **证明** : 显然最远点一定是直径的端点之一。 无论从哪个方向来到中点,一定能继续走 $len(S)/2$ 来到某个端点。显然这是上界。 这个结论应用比较广泛,但是在本题中并不是重点,只用来证明。 计算 $c(S∪T)$ 并不困难,因为点集直径具有封闭性。 但是,合并直径的讨论太复杂,不便分析,考虑仅使用 $c(S)$ 来推导。 - 情况A1 : 若 $c(S)\subseteq c(T)$ 则 $c(S∪T)=c(T)$。 - 情况A2 : 若 $c(T)\subseteq c(S)$ 则 $c(S∪T)=c(S)$。 上述两条是显然的。 判定 $c(T)\subseteq c(S)$ 的方法 : 计算 $d=dis\big(mid(S),mid(T)\big)$ ,查看是否 $len(T)/2+d\leq len(S)/2$。 圆的形式 :即两圆包含的情况。 ![](https://cdn.luogu.com.cn/upload/image_hosting/y9h7mztj.png) **证明** : 若 $len(T)/2+d\leq len(S)/2$ ,则说明从 $mid(S)$ 走到 $mid(T)$ 之后还能走 $len(T)/2$ 步,显然覆盖了 $c(T)$。 否则,无论从哪个方向到达 $mid(T)$ ,总有某个不是来向的方向存在长为 $len(T)/2$ 的半条 $T$ 的直径,则不能完整覆盖。 - 情况B : 若不满足情况A1,A2,有 $c(S∪T)=c(f(u_1,r_1)∪f(u_2,r_2))$ $=f\Big(t,\big(dis(u_1,u_2)+r_1+r_2\big)/2\Big)$ 其中 $t$ 是 $u_1,u_2$ 路径上的一个点,距离 $u_1$ 为 $\big(dis(u_1,u_2)-r_1+r_2\big)/2$ ,距离 $u_2$ 为 $\big(dis(u_1,u_2)+r_1-r_2\big)/2$ 圆的形式 :(两圆相离也类似) ![](https://cdn.luogu.com.cn/upload/image_hosting/rlwrld99.png) **证明** :已经排除了前两种情况,所以,新的直径一定跨越 $S,T$。 考虑 $td(S,u_2)=dis(u_1,u_2)+r_1$。取这个最远点 $p_S$ ,进一步得到 $td(T,p_S)=dis(p_S,u_2)+r_2=dis(u_1,u_2)+r_1+r_2$。 我们已经证明了直径长度,现在来找中点。 不难发现,前面构造的直径形如 $p_S\xleftrightarrow{r_1}u_1\xleftrightarrow{dis(u_1,u_2)}u_2\xleftrightarrow{r_2}p_T$。 现在中点位置是显然的。完全等价于数轴上画圆的情况。 好了,现在我们已经有一套 $c(S)$ 合并的理论了,来分析一下本题吧。 考虑分治,每次计算跨越区间 $[L,R]$ 中点 $t$ 的区间的贡献。 枚举每个 $[l,t]$ ,批量计算并上各个 $(t,r]$ 的贡献。 即 $\sum\limits_{r=t+1}^Rlen([l,r])$。 $=\sum\limits_{r=t+1}^Rlen\Big(c\big([l,t]∪(t,r]\big)\Big)$ 现在要分类讨论了。 不难发现,随着 $r$ 的增大,点集 $(t,r]$ 不断扩大。 一开始有 $c\big([l,t]\big)\supseteq c\big((t,r]\big)$ ,然后是情况B,最终是 $c\big([l,t]\big)\subseteq c\big((t,r]\big)$ 我们对这三个区间分别计算贡献。 - 第一部分 : $c\big([l,t]\big)\supseteq c\big((t,r]\big)\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=len\big(c([l,t])\big)$ - 第二部分 : 情况B $\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=dis\big(mid([1,t]),mid((t,r])\big)+\Big(len\big(c([l,t])\big)+len\big(c((t,r])\big)\Big)\Big/2$ - 第三部分 : $c\big([l,t]\big)\subseteq c\big((t,r]\big)\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=len\big(c((t,r])\big)$ ( 注意,可能有 $c\big([l,t]\big)=c\big((t,r]\big)$ ) 首先预处理出各个 $len,mid$。 第一部分和第三部分的贡献对 $len$ 求和即可计算。 第二部分的贡献中,$\Big(len\big(c([l,t])\big)+len\big(c((t,r])\big)\Big)\Big/2$ 也容易计算。 剩下的 $dis\big(mid([1,t]),mid((t,r])\big)$ 等价于形如 $\sum\limits_{v\in[l,r]}dis(u,v)$ 的问题。 这是个经典问题,做法见 :[P4211 [LNOI2014]LCA](https://www.luogu.com.cn/problem/P4211)。若使用树剖BIT,复杂度为 $O(\log^2)$ ,也可以使用全局平衡二叉树或者点分树做到 $O(\log)$。 接下来考虑如何求出这三部分的分界线。 当 $l$ 逐渐减小时,点集 $[l,t]$ 逐渐扩大,所以这三条分界线会单调右移,容易维护。 代码实现采用了全局平衡二叉树,复杂度 $O(n\log^2n)$。 ```cpp #include<algorithm> #include<cstdio> #include<vector> #define ll long long #define pb push_back #define MaxN 200500 using namespace std; vector<int> g[MaxN]; struct totode {int tf,f,son,c;}b[MaxN]; int dep[MaxN]; void pfs1(int u) { b[u].c=1; dep[u]=dep[b[u].f]+1; for (int i=0,v;i<g[u].size();i++) if ((v=g[u][i])!=b[u].f){ b[v].f=u;pfs1(v); b[u].c+=b[v].c; if (b[v].c>b[b[u].son].c) b[u].son=v; } } struct Node{ int f,l,r,tag,c;ll s; inline void ladd(int t) {tag+=t;s+=1ll*t*c;} }a[MaxN<<1];int tn; int st[MaxN],tc[MaxN],tot; int build(int l,int r) { if (l==r)return st[l]; int c=0,mid; for (int i=l;i<=r;i++)c+=tc[i]; for (int i=l,c2=0;i<=r;i++){ c2+=tc[i]; if (c2+c2>c){mid=i-1;break;} }if (mid<l)mid=l; a[c=++tn].c=r-l+1; return a[a[c].l=build(l,mid)].f= a[a[c].r=build(mid+1,r)].f=c; } int dfn[MaxN],tp[MaxN],tim; void pfs2(int u,int top) { tp[dfn[u]=++tim]=u; b[u].tf=top; if (!b[u].son){ for (tot=0;b[u].tf==top;u=b[u].f)tc[++tot]=u; for (int i=1;i<=tot;i++)st[tot-i+1]=tc[i]; for (int i=1;i<=tot;i++) tc[i]=b[st[i]].c-b[b[st[i]].son].c; build(1,tot); return ; }pfs2(b[u].son,top); for (int i=0,v;i<g[u].size();i++) if ((v=g[u][i])!=b[u].f&&v!=b[u].son) pfs2(v,v); } inline void up(int u) {a[u].s=a[a[u].l].s+a[a[u].r].s;} inline void ladd(int u){ if (a[u].tag){ a[a[u].l].ladd(a[u].tag); a[a[u].r].ladd(a[u].tag); a[u].tag=0; } } void grt(int u){ for (tot=0;u;u=a[u].f)st[++tot]=u; for (int i=tot;i>1;i--)ladd(st[i]); } void add(int u,int w){ grt(u); a[st[1]].ladd(w); for (int i=2,v;i<=tot;i++){ if ((v=a[st[i]].l)!=st[i-1]) a[v].ladd(w); up(st[i]); } } ll sum; void qry(int u){ grt(u);ll las=sum; sum+=a[st[1]].s; for (int i=2,v;i<=tot;i++) if ((v=a[st[i]].l)!=st[i-1]) sum+=a[v].s; } void padd(int x,int w) {while(x){add(x,w);x=b[b[x].tf].f;}} ll pqry(int x){ sum=0; while(x){qry(x);x=b[b[x].tf].f;} return sum; } int lca(int u,int v) { while(b[u].tf!=b[v].tf){ if (dep[b[v].tf]>dep[b[u].tf])swap(u,v); u=b[b[u].tf].f; }return dep[u]<dep[v] ? u:v; } int dis(int u,int v) {return dep[u]+dep[v]-2*dep[lca(u,v)];} int up(int u,int d){ while(dep[b[u].tf]>d)u=b[b[u].tf].f; return tp[dfn[b[u].tf]+d-dep[b[u].tf]]; } int lin(int u,int v,int l) { int t=lca(u,v); if (dep[u]-dep[t]>=l)return up(u,dep[u]-l); return up(v,l-dep[u]+2*dep[t]); } struct Data{int u,r;}s[MaxN]; Data merge(const Data &A,int u) { int d=dis(A.u,u); if (d<=A.r)return A; return (Data){lin(A.u,u,(d-A.r)/2),(d+A.r)/2}; } bool in(const Data &A,const Data &B) {return A.r+dis(A.u,B.u)<=B.r;} ll ans,o[MaxN],o2[MaxN]; Data q[MaxN];int tq; bool cmpQ(const Data &A,const Data &B) {return A.r<B.r;} void solve(int L,int R) { if (L==R)return ; int mid=(L+R)>>1; solve(L,mid);solve(mid+1,R); s[mid]=(Data){mid,0}; s[mid+1]=(Data){mid+1,0}; for (int l=mid-1;l>=L;l--) s[l]=merge(s[l+1],l); for (int r=mid+2;r<=R;r++) s[r]=merge(s[r-1],r); o[mid]=0; for (int r=mid+1;r<=R;r++){ o[r]=o[r-1]+s[r].r; o2[r]=o2[r-1]+dep[s[r].u]; } int p1=mid,p2=mid; tq=0; for (int l=mid;l>=L;l--){ while(p1<R&&in(s[p1+1],s[l]))p1++; while(p2<R&&!in(s[l],s[p2+1]))p2++; p2=max(p2,p1); ans+=2ll*(p1-mid)*s[l].r+2*(o[R]-o[p2]) +1ll*(p2-p1)*s[l].r+(o[p2]-o[p1]) +1ll*(p2-p1)*dep[s[l].u]+(o2[p2]-o2[p1]); if (p1<p2){ if (p1>mid)q[++tq]=(Data){-s[l].u,p1}; q[++tq]=(Data){s[l].u,p2}; } }sort(q+1,q+tq+1,cmpQ); for (int i=mid+1,p=1;i<=R;i++){ padd(s[i].u,1); while(p<=tq&&q[p].r==i){ if (q[p].u<0)ans+=2*pqry(-q[p].u); else ans-=2*pqry(q[p].u); p++; } }for (int i=mid+1;i<=R;i++)padd(s[i].u,-1); } int n; int main() { scanf("%d",&n);tn=n*2-1; for (int i=1;i<=tn;i++)a[i].c=1; for (int i=1,u,v;i<n;i++){ scanf("%d%d",&u,&v); g[u].pb(n+i);g[n+i].pb(u); g[v].pb(n+i);g[n+i].pb(v); }pfs1(1);pfs2(1,1); solve(1,n); printf("%lld\n",ans>>1); return 0; } ```