关于单点修改,查询树的带权重心的研究

· · 题解

提供一个 O((n+q)\log n) 的在线做法。

题目详见 P3345。

一、全局平衡二叉树

全局平衡二叉树可以解决链修改、链查询的问题,复杂度为单次 O(\log n)

以下是全局平衡二叉树的构建。对于一棵树,我们首先给它做轻重链剖分。将每条重链拎出来单独建一棵二叉树,这棵二叉树的每个点的权值是它的轻子树大小和加上它自己。然后我们每一层都取带权中心,这个中心的在重链上的祖先变成左子树,在重链上的后代变成右子树。这样,我们将每一条重链都建成了一棵平衡二叉树,其中序遍历恰好是这条重链从浅到深的节点序列。对于重链与重链之间的轻边,我们由较深的重链对应的二叉树的根节点连向它的重链顶端在原树上的父亲节点,但是子节点记录父亲,父亲节点不记录儿子。

从任意一个点开始,我们向上遍历,如前文所说,最多经过 O(\log n) 条轻边,并且由于平衡二叉树节点权重由轻子树大小决定,从而每经过一条重边,其子树大小至少变为原来的两倍。因此,全局平衡二叉树的深度是 O(\log n) 的。

对于修改和查询操作,我们直接从初始节点往上遍历,在每个节点维护一些值和标记即可。

由于全局平衡二叉树的修改和查询从下往上,所以标记永久化的常数显著小于懒标记下传的。

二、本题做法

全局平衡二叉树只能解决链上的问题,所以考虑把答案中的补给站固定在一条链上。

结论 1:使得从树上的任意一点出发到树上所有点的距离乘上目标点的点权和最小的起点一定是树的带权重心。

证明:假设使得答案最小的起点是 u,对于一条连接 uv 的长度为 c 的边,我们假设 u 那一侧的点权和为 s_uv 的那一侧点权和为 s_v,那么如果有 s_u<s_v,那么我们将起点从 u 移动到 v,可以减少 (s_v-s_u)c 的带权距离,与 u 使得答案最小的假设矛盾。 从而 u 一定能保证它的每一个子树的权值和不超过总权值的一半,符合带权重心的定义。□

结论 2:令任意一个节点为根,带权重心是带权 DFS 序的重心在树上所代表的节点的祖先。

证明:我们假设所有节点权值和是 s,令任意一个节点为根,在带权 DFS 序中:

那么,带权 DFS 序重心在树的重心的子树内。□

结论 3:令任意一个节点为根,带权重心子树权值和大于等于所有节点权值和的一半并且其任意一个子节点节点的子树权值和小等于所有节点权值和的一半。

利用上面这 3 个性质,我们可以将 DFS 序预处理出来,动态维护其中位数。处理的时候,我们可以维护每个节点的子树权值和,修改就是将修改的节点到根的链上所有节点加上增加的点权,查询就是在中位数到根节点的链上二分总权值和的一半到了哪个节点。最后就是计算重心到所有节点的带权距离之和。

寻找重心的部分,动态维护中位数,可以用线段树二分;在链上加,可以用全局平衡二叉树做;在链上二分,也用全局平衡二叉树。

对于链上的二分,由于我们倾向于使用标记永久化的全局平衡二叉树,所以没有办法知道向上遍历的每个结点的真实值。于是,采用如下方法:每遍历到一棵平衡二叉树的根节点,我们就一直走向左儿子,累加懒标记。如果这棵平衡二叉树的最左边的节点的子树权值和大于等于所有节点权值和的一半,那就说明重心在这棵平衡二叉树内。接下来就是平衡树上的二分了,依然是 O(\log n) 的。

计算距离的部分,我们可以将 \sum\text{dis}(i,x)\times\text{val}(i) 拆成 \text{dep}(x)\times\sum\text{val}(i)+\sum\text{dep}(i)\times\text{val}(i)-2\sum\text{dep}(\text{lca}(x,i))\times\text{val}(i),第一、二个部分可以 O(1) 维护,第三个部分可以对于每个点,将其到根的链上所有边加上点权乘以边权,最后计算出 x 到根的链上的所有数的和即可,同样用全局平衡二叉树维护。

总时间复杂度 O((n+q)\log n),总空间复杂度 O(n)。即便这题没有卡树剖的数据,也拿到了目前的最优解。

#include<bits/stdc++.h>
using namespace std;
int plen,ptop,pstk[40];
char rdc[1<<14],wtc[1<<23],*rS,*rT;
#define gc() (rS==rT?rT=(rS=rdc)+fread(rdc,1,1<<14,stdin),(rS==rT?EOF:*rS++):*rS++)
#define pc(x) wtc[plen++]=(x)
#define flush() fwrite(wtc,1,plen,stdout),plen=0
template<class T=int>inline T read(){
    T x=0;char ch;bool f=0;
    while(!isdigit(ch=gc()))if(ch=='-')f=!f;
    do x=(x<<1)+(x<<3)+(ch^48);while(isdigit(ch=gc()));
    return f?-x:x;
}
inline int read(char*const s){
    char ch,*t=s;
    while(!isgraph(ch=gc()));
    do *t++=ch;while(isgraph(ch=gc()));
    return (*t)=0,t-s;
}
template<class T>inline void write(T x){
    if(plen>=8000000)flush();
    if(!x)return pc('0'),void();
    if(x<0)pc('-'),x=-x;
    for(;x;x/=10)pstk[ptop++]=x%10;
    while(ptop)pc(pstk[--ptop]^48);
}
inline void write(const char*s){
    if(plen>=8000000)flush();
    for(int i=0;*(s+i);pc(*(s+(i++))));
}
inline void write(char*const s){
    if(plen>=8000000)flush();
    for(int i=0;*(s+i);pc(*(s+(i++))));
}
const int _=11e4;
vector<pair<int,int> >e[_];
long long tot,s[_];
int n,q,m,c,val,h[_],sz[_],f[_],eu[_],in[_],si[_],de[_];
int v[_],d[_],p[_],ls[_],rs[_],o[_],w[_*2],t[_],sv[_],se[_];
void dfs(int x,int fa){
    sz[x]=1;eu[++c]=x;in[x]=c;f[x]=fa;
    for(auto y:e[x])
        if(y.first!=fa){
            de[y.first]=de[x]+(v[y.first]=y.second);
            dfs(y.first,x);
            sz[x]+=sz[y.first];
            if(sz[y.first]>sz[h[x]])h[x]=y.first;
        }
}
int build(int l,int r){
    int x=l,y=r;
    while(y-x>1){
        int mid=x+y>>1;
        if(2*(si[mid]-si[l])<=si[r]-si[l])x=mid;
        else y=mid;
    }
    y=d[x];se[y]=sv[r]-sv[l-1];
    if(l<x)p[ls[y]=build(l,x-1)]=y;
    if(x<r)p[rs[y]=build(x+1,r)]=y;
    return y;
}
int cat(int tp){
    int c=0;
    for(int x=tp;x;x=h[x])
        for(auto y:e[x])
            if(y.first!=f[x]&&y.first!=h[x])p[cat(y.first)]=x;
    for(int x=tp;x;x=h[x])d[++c]=x,si[c]=si[c-1]+sz[x]-sz[h[x]],sv[c]=sv[c-1]+v[x];
    return build(1,c);
}
int main(){
    n=read();q=read();
    for(int i=1,x,y,z;i<n;i++){
        x=read();y=read();z=read();
        e[x].push_back({y,z});
        e[y].push_back({x,z});
    }
    dfs(1,0);cat(1);
    for(m=1;m<n;m<<=1);
    for(int i=1,j,x,y,d,g,k;i<=q;i++){
        long long z=0;bool r=1;
        x=read();y=read();
        for(j=x;j;j=p[j]){
            s[j]+=z;
            if(r)t[j]+=y,t[rs[j]]-=y,z+=1ll*y*(v[j]+se[ls[j]]),s[j]-=1ll*y*se[rs[j]];
            if((r=j!=ls[p[j]])&&j!=rs[p[j]])z=0;
        }
        val+=y;k=val+1>>1;tot+=1ll*de[x]*y;
        for(j=in[x]+m-1;j;j>>=1)w[j]+=y;
        for(j=1;j<m;)w[j<<1]<k?(k-=w[j<<1],j=j<<1|1):j=j<<1;
        for(j=eu[j-m+1];j;j=p[j])
            if(j!=ls[p[j]]&&j!=rs[p[j]]){
                for(d=0,k=j;k;k=ls[k])d+=t[k];
                if(d*2>=val)break;
            }
        for(d=0;j;d*2>=val?(g=j,j=rs[j]):j=ls[j])d+=t[j];
        for(j=g,y=0,z=0,r=1;j;j=p[j]){
            if(r)z+=s[j]-s[rs[j]]-1ll*se[rs[j]]*t[rs[j]],y+=v[j]+se[ls[j]];
            z+=1ll*y*t[j];
            if((r=j!=ls[p[j]])&&j!=rs[p[j]])y=0;
        }
        write(1ll*de[g]*val+tot-2*z),pc('\n');
    }
    flush();
    return 0;
}