题解:P11408 [RMI 2020] 树咖 / Arboras

· · 题解

维护每个子树内的最长链和与最长链无交的次长链。考虑每次修改的影响,由于边权加的都是正数,当前这个点到祖先上的某个点的最长链都可能改变,可以树剖和线段树维护,用倍增定位。对于次长链,倍增找祖先上某些合法结点,用原来最长链的大小替换掉,合法结点的数量总和是均摊线性的,因为每次操作相当于对于长链剖分的推平。

时间复杂度 O(n\log^2n)

#include<bits/stdc++.h>
#include<ext/pb_ds/priority_queue.hpp>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define MP make_pair
#define pii pair<int,int>
const double PI=acos(-1.0);
template <class Miaowu>
inline void in(Miaowu &x){
    char c;x=0;bool f=0;
    for(c=getchar();c<'0'||c>'9';c=getchar())f|=c=='-';
    for(;c>='0'&&c<='9';c=getchar())x=(x<<1)+(x<<3)+(c^48);
    x=f?-x:x;
}
const int N=1e5+5;
const int mod=1e9+7;
vector<int>g[N];
ll a[N],dp[N],dpp[N];
int n,m,F[N][17],son[N],top[N],de[N],siz[N],fn,dfn[N],ans,dy[N],ans2;
inline void dfs1(int u){
    for(int i=1;i<17;i++)F[u][i]=F[F[u][i-1]][i-1];
    siz[u]=1;
    for(int v:g[u]){
        de[v]=de[u]+1,dfs1(v);
        siz[u]+=siz[v];
        if(dp[v]+a[v]>dp[u])dpp[u]=dp[u],dp[u]=dp[v]+a[v];
        else dpp[u]=max(dpp[u],dp[v]+a[v]);
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
inline void dfs2(int u){
    dfn[u]=++fn,dy[fn]=u;
    if(son[u])top[son[u]]=top[u],dfs2(son[u]);
    for(int v:g[u])if(v!=son[u])top[v]=v,dfs2(v);
}
struct BIT{
    ll t[N];
    inline void upd(int u,ll x){
        while(u<N)t[u]+=x,u+=u&-u;
    }
    inline ll qry(int u){
        ll res=0;
        while(u)res+=t[u],u-=u&-u;
        return res;
    }
}bit;
struct SGT{
    inline int ls(int u){return u<<1;}
    inline int rs(int u){return u<<1|1;}
    ll tag[N<<2],sa[N<<2],tag2[N<<2],tt[N<<2];
    int sum[N<<2],cs[N<<2],sum2[N<<2];
    inline void pd(int u,int l,int r){
        if(tt[u]==-1)return;
        tt[rs(u)]=tag[rs(u)]=tt[u],sum[rs(u)]=(tt[u]%mod*r+cs[rs(u)])%mod;
        tt[ls(u)]=tag[ls(u)]=tt[u]+sa[rs(u)],sum[ls(u)]=((tt[u]+sa[rs(u)])%mod*l+cs[ls(u)])%mod;
        tt[u]=-1;
    }
    inline void pu(int u,int l,int mid){
        sum[u]=(sum[ls(u)]+sum[rs(u)])%mod,tag[u]=tag[rs(u)];
        sum2[u]=(sum2[ls(u)]+sum2[rs(u)])%mod,tag2[u]=tag2[rs(u)];
        sa[u]=sa[ls(u)]+sa[rs(u)],cs[u]=((cs[ls(u)]+cs[rs(u)])%mod+sa[rs(u)]%mod*(mid-l+1)%mod)%mod;
    }
    inline void build(int u,int l,int r){
        tt[u]=-1;
        if(l==r)return tag[u]=dp[dy[l]],sum[u]=dp[dy[l]]%mod,sa[u]=a[dy[l]],sum2[u]=dpp[dy[l]]%mod,tag2[u]=dpp[dy[l]],void();
        int mid=l+r>>1;
        build(ls(u),l,mid),build(rs(u),mid+1,r);
        pu(u,l,mid);
    }
    inline void upda(int u,int l,int r,int p,ll x){
        if(l==r)return sa[u]+=x,void();
        int mid=l+r>>1;pd(u,mid-l+1,r-mid);
        mid>=p?upda(ls(u),l,mid,p,x):upda(rs(u),mid+1,r,p,x);
        pu(u,l,mid);
    }
    inline void upd2(int u,int l,int r,int p,ll x){
        if(l==r){
            if(tag2[u]<x)tag2[u]=x,sum2[u]=x%mod;
            return;
        }
        int mid=l+r>>1;pd(u,mid-l+1,r-mid);
        mid>=p?upd2(ls(u),l,mid,p,x):upd2(rs(u),mid+1,r,p,x);
        pu(u,l,mid);
    }
    inline ll upd(int u,int l,int r,int L,int R,ll x){
        if(l>=L&&r<=R)return sum[u]=(x*(r-l+1)%mod+cs[u])%mod,tag[u]=tt[u]=x,x+sa[u];
        int mid=l+r>>1;pd(u,mid-l+1,r-mid);
        if(mid<R)x=upd(rs(u),mid+1,r,L,R,x);
        if(mid>=L)x=upd(ls(u),l,mid,L,R,x);
        return pu(u,l,mid),x;
    }
    inline ll qry1(int u,int l,int r,int p){
        if(l==r)return tag[u];
        int mid=l+r>>1;pd(u,mid-l+1,r-mid);
        return mid>=p?qry1(ls(u),l,mid,p):qry1(rs(u),mid+1,r,p);
    }
}sgt;
inline void upd(int u,int v,ll x){
    while(de[top[u]]>de[v])x=sgt.upd(1,1,n,dfn[top[u]],dfn[u],x),u=F[top[u]][0];
    sgt.upd(1,1,n,dfn[v],dfn[u],x);
}
int main(){
    in(n);
    for(int i=2,u;i<=n;i++)in(u),g[++u].push_back(i),F[i][0]=u;
    for(int i=2;i<=n;i++)in(a[i]);
    de[1]=1,dfs1(1),top[1]=1,dfs2(1);
    sgt.build(1,1,n);
    for(int i=1;i<=n;i++)bit.upd(dfn[i],a[i]),bit.upd(dfn[i]+siz[i],-a[i]);
    printf("%d\n",(sgt.sum[1]+sgt.sum2[1])%mod),in(m);
    while(m--){
        int u,x,uu;in(u),in(x),a[++u]+=x,uu=u;
        bit.upd(dfn[u],x),bit.upd(dfn[u]+siz[u],-x);
        ll v=sgt.qry1(1,1,n,dfn[u]),qwq=bit.qry(dfn[u]);
        for(int i=16;i>=0;i--)if(F[u][i]&&v-bit.qry(dfn[F[u][i]])+qwq>sgt.qry1(1,1,n,dfn[F[u][i]]))u=F[u][i];
        int vv=uu;
        for(int i=16;i>=0;i--)if(F[vv][i]&&v-bit.qry(dfn[F[vv][i]])+qwq-x==sgt.qry1(1,1,n,dfn[F[vv][i]]))vv=F[vv][i];
        while(vv){
            vv=F[vv][0];int tv=vv;
            if(de[tv]<de[u])break;
            ll qaq=sgt.qry1(1,1,n,dfn[vv]),tuu=bit.qry(dfn[vv]);
            for(int i=16;i>=0;i--)if(F[vv][i]&&qaq-bit.qry(dfn[F[vv][i]])+tuu==sgt.qry1(1,1,n,dfn[F[vv][i]]))vv=F[vv][i];
            if(qaq-bit.qry(dfn[vv])+tuu!=sgt.qry1(1,1,n,dfn[vv]))break;
            sgt.upd2(1,1,n,dfn[tv],qaq);
        }
        sgt.upda(1,1,n,dfn[uu],x),upd(uu,u,v);
        if(F[u][0])sgt.upd2(1,1,n,dfn[F[u][0]],sgt.qry1(1,1,n,dfn[u])+a[u]);
        printf("%d\n",(sgt.sum[1]+sgt.sum2[1])%mod);
    }
    return 0;
}