题解 CF1458F 【Range Diameter Sum】
command_block
2021-01-14 08:33:25
**题意** : 给出一棵 $n$ 个点的树,边权均为 $1$。
设 $len(S)$ 为点集 $S$ 的直径长度。
求 $\sum\limits_{i=1}^n\sum\limits_{j=i}^nlen([i,j])$
$n\leq 10^5$ ,时限$\texttt{7s}$。
------------
很有意思的一个题。
树上邻域(圆)理论。
定义 $f(u,r)=\{v|dis(u,v)\leq r\}=\{\text{距离u不超过r的节点}\}$ ,即邻域。其中 $r$ 称为半径。
- **性质①** 对于一个点集 $S$ ,其所有直径的中点是重合的。(可能在一条边的中间)
若中点不重合则显然可以构造新的直径。
设点集 $S$ 的直径中点为 $mid(S)$ ,直径长度为 $len(S)$。
根据这条性质,定义点集 $S$ 的覆盖邻域 $c(S)=f(mid(S),len(S)/2)$。
显然,这是能覆盖 $S$ 的半径最小的邻域,类似于最小覆盖圆。
- **性质②** $f(u,r)\supseteq S\Rightarrow f(u,r)\supseteq c(S)$。
圆的形式 : $⊙C\supseteq S\Rightarrow ⊙C\supseteq S$ 的最小覆盖圆。
**证明** : 显然有 $dis(mid(S),u)\leq r-len(S)/2$ ,否则 $f(u,r)$ 不可能包含 $S$ 的直径。
也就是说,从 $mid(S)$ 开始还能走 $len(S)/2$ 步,自然就覆盖了 $c(S)$。
- **性质③** 设 $td(S,v)$ 为 $v$ 到 $S$ 内最远点的距离。
有 $td(S,v)=dis(mid(S),v)+len(S)/2$。
**证明** : 显然最远点一定是直径的端点之一。
无论从哪个方向来到中点,一定能继续走 $len(S)/2$ 来到某个端点。显然这是上界。
这个结论应用比较广泛,但是在本题中并不是重点,只用来证明。
计算 $c(S∪T)$ 并不困难,因为点集直径具有封闭性。
但是,合并直径的讨论太复杂,不便分析,考虑仅使用 $c(S)$ 来推导。
- 情况A1 : 若 $c(S)\subseteq c(T)$ 则 $c(S∪T)=c(T)$。
- 情况A2 : 若 $c(T)\subseteq c(S)$ 则 $c(S∪T)=c(S)$。
上述两条是显然的。
判定 $c(T)\subseteq c(S)$ 的方法 : 计算 $d=dis\big(mid(S),mid(T)\big)$ ,查看是否 $len(T)/2+d\leq len(S)/2$。
圆的形式 :即两圆包含的情况。
![](https://cdn.luogu.com.cn/upload/image_hosting/y9h7mztj.png)
**证明** :
若 $len(T)/2+d\leq len(S)/2$ ,则说明从 $mid(S)$ 走到 $mid(T)$ 之后还能走 $len(T)/2$ 步,显然覆盖了 $c(T)$。
否则,无论从哪个方向到达 $mid(T)$ ,总有某个不是来向的方向存在长为 $len(T)/2$ 的半条 $T$ 的直径,则不能完整覆盖。
- 情况B : 若不满足情况A1,A2,有 $c(S∪T)=c(f(u_1,r_1)∪f(u_2,r_2))$
$=f\Big(t,\big(dis(u_1,u_2)+r_1+r_2\big)/2\Big)$
其中 $t$ 是 $u_1,u_2$ 路径上的一个点,距离 $u_1$ 为 $\big(dis(u_1,u_2)-r_1+r_2\big)/2$ ,距离 $u_2$ 为 $\big(dis(u_1,u_2)+r_1-r_2\big)/2$
圆的形式 :(两圆相离也类似)
![](https://cdn.luogu.com.cn/upload/image_hosting/rlwrld99.png)
**证明** :已经排除了前两种情况,所以,新的直径一定跨越 $S,T$。
考虑 $td(S,u_2)=dis(u_1,u_2)+r_1$。取这个最远点 $p_S$ ,进一步得到 $td(T,p_S)=dis(p_S,u_2)+r_2=dis(u_1,u_2)+r_1+r_2$。
我们已经证明了直径长度,现在来找中点。
不难发现,前面构造的直径形如 $p_S\xleftrightarrow{r_1}u_1\xleftrightarrow{dis(u_1,u_2)}u_2\xleftrightarrow{r_2}p_T$。
现在中点位置是显然的。完全等价于数轴上画圆的情况。
好了,现在我们已经有一套 $c(S)$ 合并的理论了,来分析一下本题吧。
考虑分治,每次计算跨越区间 $[L,R]$ 中点 $t$ 的区间的贡献。
枚举每个 $[l,t]$ ,批量计算并上各个 $(t,r]$ 的贡献。
即 $\sum\limits_{r=t+1}^Rlen([l,r])$。
$=\sum\limits_{r=t+1}^Rlen\Big(c\big([l,t]∪(t,r]\big)\Big)$
现在要分类讨论了。
不难发现,随着 $r$ 的增大,点集 $(t,r]$ 不断扩大。
一开始有 $c\big([l,t]\big)\supseteq c\big((t,r]\big)$ ,然后是情况B,最终是 $c\big([l,t]\big)\subseteq c\big((t,r]\big)$
我们对这三个区间分别计算贡献。
- 第一部分 : $c\big([l,t]\big)\supseteq c\big((t,r]\big)\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=len\big(c([l,t])\big)$
- 第二部分 : 情况B $\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=dis\big(mid([1,t]),mid((t,r])\big)+\Big(len\big(c([l,t])\big)+len\big(c((t,r])\big)\Big)\Big/2$
- 第三部分 : $c\big([l,t]\big)\subseteq c\big((t,r]\big)\Rightarrow len\Big(c\big([l,t]∪(t,r]\big)\Big)=len\big(c((t,r])\big)$
( 注意,可能有 $c\big([l,t]\big)=c\big((t,r]\big)$ )
首先预处理出各个 $len,mid$。
第一部分和第三部分的贡献对 $len$ 求和即可计算。
第二部分的贡献中,$\Big(len\big(c([l,t])\big)+len\big(c((t,r])\big)\Big)\Big/2$ 也容易计算。
剩下的 $dis\big(mid([1,t]),mid((t,r])\big)$ 等价于形如 $\sum\limits_{v\in[l,r]}dis(u,v)$ 的问题。
这是个经典问题,做法见 :[P4211 [LNOI2014]LCA](https://www.luogu.com.cn/problem/P4211)。若使用树剖BIT,复杂度为 $O(\log^2)$ ,也可以使用全局平衡二叉树或者点分树做到 $O(\log)$。
接下来考虑如何求出这三部分的分界线。
当 $l$ 逐渐减小时,点集 $[l,t]$ 逐渐扩大,所以这三条分界线会单调右移,容易维护。
代码实现采用了全局平衡二叉树,复杂度 $O(n\log^2n)$。
```cpp
#include<algorithm>
#include<cstdio>
#include<vector>
#define ll long long
#define pb push_back
#define MaxN 200500
using namespace std;
vector<int> g[MaxN];
struct totode
{int tf,f,son,c;}b[MaxN];
int dep[MaxN];
void pfs1(int u)
{
b[u].c=1;
dep[u]=dep[b[u].f]+1;
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=b[u].f){
b[v].f=u;pfs1(v);
b[u].c+=b[v].c;
if (b[v].c>b[b[u].son].c)
b[u].son=v;
}
}
struct Node{
int f,l,r,tag,c;ll s;
inline void ladd(int t)
{tag+=t;s+=1ll*t*c;}
}a[MaxN<<1];int tn;
int st[MaxN],tc[MaxN],tot;
int build(int l,int r)
{
if (l==r)return st[l];
int c=0,mid;
for (int i=l;i<=r;i++)c+=tc[i];
for (int i=l,c2=0;i<=r;i++){
c2+=tc[i];
if (c2+c2>c){mid=i-1;break;}
}if (mid<l)mid=l;
a[c=++tn].c=r-l+1;
return
a[a[c].l=build(l,mid)].f=
a[a[c].r=build(mid+1,r)].f=c;
}
int dfn[MaxN],tp[MaxN],tim;
void pfs2(int u,int top)
{
tp[dfn[u]=++tim]=u;
b[u].tf=top;
if (!b[u].son){
for (tot=0;b[u].tf==top;u=b[u].f)tc[++tot]=u;
for (int i=1;i<=tot;i++)st[tot-i+1]=tc[i];
for (int i=1;i<=tot;i++)
tc[i]=b[st[i]].c-b[b[st[i]].son].c;
build(1,tot);
return ;
}pfs2(b[u].son,top);
for (int i=0,v;i<g[u].size();i++)
if ((v=g[u][i])!=b[u].f&&v!=b[u].son)
pfs2(v,v);
}
inline void up(int u)
{a[u].s=a[a[u].l].s+a[a[u].r].s;}
inline void ladd(int u){
if (a[u].tag){
a[a[u].l].ladd(a[u].tag);
a[a[u].r].ladd(a[u].tag);
a[u].tag=0;
}
}
void grt(int u){
for (tot=0;u;u=a[u].f)st[++tot]=u;
for (int i=tot;i>1;i--)ladd(st[i]);
}
void add(int u,int w){
grt(u);
a[st[1]].ladd(w);
for (int i=2,v;i<=tot;i++){
if ((v=a[st[i]].l)!=st[i-1])
a[v].ladd(w);
up(st[i]);
}
}
ll sum;
void qry(int u){
grt(u);ll las=sum;
sum+=a[st[1]].s;
for (int i=2,v;i<=tot;i++)
if ((v=a[st[i]].l)!=st[i-1])
sum+=a[v].s;
}
void padd(int x,int w)
{while(x){add(x,w);x=b[b[x].tf].f;}}
ll pqry(int x){
sum=0;
while(x){qry(x);x=b[b[x].tf].f;}
return sum;
}
int lca(int u,int v)
{
while(b[u].tf!=b[v].tf){
if (dep[b[v].tf]>dep[b[u].tf])swap(u,v);
u=b[b[u].tf].f;
}return dep[u]<dep[v] ? u:v;
}
int dis(int u,int v)
{return dep[u]+dep[v]-2*dep[lca(u,v)];}
int up(int u,int d){
while(dep[b[u].tf]>d)u=b[b[u].tf].f;
return tp[dfn[b[u].tf]+d-dep[b[u].tf]];
}
int lin(int u,int v,int l)
{
int t=lca(u,v);
if (dep[u]-dep[t]>=l)return up(u,dep[u]-l);
return up(v,l-dep[u]+2*dep[t]);
}
struct Data{int u,r;}s[MaxN];
Data merge(const Data &A,int u)
{
int d=dis(A.u,u);
if (d<=A.r)return A;
return (Data){lin(A.u,u,(d-A.r)/2),(d+A.r)/2};
}
bool in(const Data &A,const Data &B)
{return A.r+dis(A.u,B.u)<=B.r;}
ll ans,o[MaxN],o2[MaxN];
Data q[MaxN];int tq;
bool cmpQ(const Data &A,const Data &B)
{return A.r<B.r;}
void solve(int L,int R)
{
if (L==R)return ;
int mid=(L+R)>>1;
solve(L,mid);solve(mid+1,R);
s[mid]=(Data){mid,0};
s[mid+1]=(Data){mid+1,0};
for (int l=mid-1;l>=L;l--)
s[l]=merge(s[l+1],l);
for (int r=mid+2;r<=R;r++)
s[r]=merge(s[r-1],r);
o[mid]=0;
for (int r=mid+1;r<=R;r++){
o[r]=o[r-1]+s[r].r;
o2[r]=o2[r-1]+dep[s[r].u];
}
int p1=mid,p2=mid;
tq=0;
for (int l=mid;l>=L;l--){
while(p1<R&&in(s[p1+1],s[l]))p1++;
while(p2<R&&!in(s[l],s[p2+1]))p2++;
p2=max(p2,p1);
ans+=2ll*(p1-mid)*s[l].r+2*(o[R]-o[p2])
+1ll*(p2-p1)*s[l].r+(o[p2]-o[p1])
+1ll*(p2-p1)*dep[s[l].u]+(o2[p2]-o2[p1]);
if (p1<p2){
if (p1>mid)q[++tq]=(Data){-s[l].u,p1};
q[++tq]=(Data){s[l].u,p2};
}
}sort(q+1,q+tq+1,cmpQ);
for (int i=mid+1,p=1;i<=R;i++){
padd(s[i].u,1);
while(p<=tq&&q[p].r==i){
if (q[p].u<0)ans+=2*pqry(-q[p].u);
else ans-=2*pqry(q[p].u);
p++;
}
}for (int i=mid+1;i<=R;i++)padd(s[i].u,-1);
}
int n;
int main()
{
scanf("%d",&n);tn=n*2-1;
for (int i=1;i<=tn;i++)a[i].c=1;
for (int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
g[u].pb(n+i);g[n+i].pb(u);
g[v].pb(n+i);g[n+i].pb(v);
}pfs1(1);pfs2(1,1);
solve(1,n);
printf("%lld\n",ans>>1);
return 0;
}
```