题解:P10574 [JRKSJ R8] 暴风雪

· · 题解

这我哪会啊。。。

感觉比大多数根号题要优美太多了。

Solution

考虑到 w\ge 1,直接从下往上跳给每个点的权值 ckmax 即可。一个观察是 ckmax 的权值只有 O(\sqrt n) 段(这是因为 x 的祖先中,只有 O(\sqrt n) 个点存在不包含 x 的子树中有深度与 x 相等的点),考虑对着这个做。

先考虑如何求分段。这相当于每次找到最浅的祖先,使得其存在一个不包含 x 的子树中有深度为 x 的点。考虑长剖,显然跳轻边时这一定是答案,否则这相当于找一条长链向上的第一个 \ge x 的数。考虑对每个 x 开一个 vector 维护所有轻子树深度大于等于 d_x 的点,显然 vector 内不会加入超过轻子树 siz 个点,于是复杂度是 O(n) 的,查询直接在对应的 vector 上从上往下不断跳直到不是 x 的祖先即可。这样即可 O(n) 预处理 O(\sqrt n) 求分段。维护每个节点每一层权值的事情在长剖过程中顺便做即可。

然后考虑怎么做 ckmax。考虑这样一个事情:记 a_{1\sim k} 为从上往下 ckmax 的区间长度,由 a_i 的意义有 \sum i\times a_i\le n,我们有 \sum_i\log(a_i+1)=O(\sqrt n)。贺一个 hos_lyric 的 simple proof:

\sum_{i}\log(a_i+1)&\sim \sum_{i}\sum_{j}[a_i\ge 2^j]\\ &\le\sum_{j}\sqrt{\frac{2n}{2^j}}\\ &=O(\sqrt n) \end{aligned}

于是我们在重剖上维护链 ckmax,这样一共会产生 O(\sqrt n+\log n) 个区间,根据上面的分析拆到线段树上就是 O(\sqrt n+\log^2 n) 个区间。容易发现这些区间到根的路径并上有 O(\sqrt n+\log^2 n) 个节点左右两侧都有被选的节点,O(\log^2 n) 个节点只有一侧有被选的节点(这也是为什么这部分不能长剖,如果用了这里就是 O(\sqrt n\log n) 了)。精细实现即可做到 O\big(n+q(\sqrt n+\log^2 n)\big)

Code

bool Mst;
#include<bits/stdc++.h>
using namespace std;
using ui=unsigned int;
using ll=long long;
using ull=unsigned long long;
using i128=__int128;
using u128=__uint128_t;
using pii=pair<int,int>;
#define fi first
#define se second
constexpr int N=3e5+5,mod=998244353;
inline ll add(ll x,ll y){return (x+=y)>=mod&&(x-=mod),x;}
inline ll Add(ll &x,ll y){return x=add(x,y);}
inline ll sub(ll x,ll y){return (x-=y)<0&&(x+=mod),x;}
inline ll Sub(ll &x,ll y){return x=sub(x,y);}
inline ll qpow(ll a,ll b){
    ll res=1;
    for(;b;b>>=1,a=a*a%mod)
        if(b&1)res=res*a%mod;
    return res;
}
int n,q;vector<int> Gr[N];
int fat[N],dep[N],siz[N],h[N],sh[N],son[N],sonh[N];
int dfc,dfn[N],rnk[N],top[N],toph[N];
inline void dfs1(int x){
    dep[x]=dep[fat[x]]+1,siz[x]=h[x]=sh[x]=1;
    for(const auto &y:Gr[x]){
        dfs1(y);
        siz[x]+=siz[y];
        h[x]=max(h[x],h[y]+1);
        if(h[son[x]]<h[y])
            son[x]=y;
        if(siz[sonh[x]]<siz[y])
            sonh[x]=y;
    }
    for(const auto &y:Gr[x])
        if(son[x]!=y)
            sh[x]=max(sh[x],h[y]+1);
}
inline void dfs2(int x){
    rnk[dfn[x]=++dfc]=x;
    top[x]=son[fat[x]]==x?top[fat[x]]:x;
    toph[x]=sonh[fat[x]]==x?toph[fat[x]]:x;
    if(!sonh[x])return;
    dfs2(sonh[x]);
    for(const auto &y:Gr[x])
        if(y!=sonh[x])
            dfs2(y);
}
vector<vector<int> > vec[N];
vector<ll> sum[N],lsum[N];
int L[N<<2],R[N<<2],M[N<<2];ll Min[N<<2],Max[N<<2],Tag[N<<2],Ans[N<<2];
inline void pushup(int p){
    Min[p]=min(Min[p<<1],Min[p<<1|1]);
    Max[p]=max(Max[p<<1],Max[p<<1|1]);
    Ans[p]=Ans[p<<1]+Ans[p<<1|1];
}
inline void pushTag(int p,ll v){
    Min[p]=Max[p]=Tag[p]=v,Ans[p]=(R[p]-L[p]+1)*v;
}
inline void pushdown(int p){
    if(~Tag[p])
        pushTag(p<<1,Tag[p]),pushTag(p<<1|1,Tag[p]),Tag[p]=-1;
}
inline void build(int l,int r,int p=1){
    L[p]=l,R[p]=r,M[p]=(l+r)>>1,Tag[p]=-1;
    if(l==r)return;
    build(L[p],M[p],p<<1);
    build(M[p]+1,R[p],p<<1|1);
}
inline void upd(ll v,int p){
    if(Min[p]>=v)return;
    if(Max[p]<v){
        pushTag(p,v);
        return;
    }
    pushdown(p);
    upd(v,p<<1),upd(v,p<<1|1);
    pushup(p);
}
struct node{
    int l,r;ll v;
    node(){l=r=v=0;}
    node(int _l,int _r,ll _v){l=_l,r=_r,v=_v;}
};
node sta[N];int tot;
inline void work(int p=1){
    while(tot&&sta[tot].r<L[p])tot--;
    if(!tot||R[p]<sta[tot].l)return;
    if(sta[tot].l<=L[p]&&R[p]<=sta[tot].r){
        upd(sta[tot].v,p);
        return;
    }
    pushdown(p);
    work(p<<1),work(p<<1|1);
    pushup(p);
}
inline ll qry(int l,int r,int p=1){
    if(l<=L[p]&&R[p]<=r)return Ans[p];
    pushdown(p);
    if(r<=M[p])return qry(l,r,p<<1);
    if(M[p]<l)return qry(l,r,p<<1|1);
    return qry(l,r,p<<1)+qry(l,r,p<<1|1);
}
int ver[N],len;
inline void Ins(int u,int v,ll w){
    for(;dep[u]>dep[v];u=fat[toph[u]])
        sta[++tot]=node(max(dfn[toph[u]],dfn[v]+1),dfn[u],w);
}
inline void Mdf(int x,int w){
    tot=0;
    for(int u=x;u;u=fat[top[u]]){
        int r=top[u],d=dep[x]-dep[r];
        lsum[u][dep[x]-dep[u]]+=w,sum[r][d]+=w;
        ll cur=sum[r][d];len=0;
        for(const auto &o:vec[r][d]){
            if(dep[o]>=dep[u])break;
            ver[++len]=o,cur-=lsum[o][dep[x]-dep[o]];
        }
        ver[++len]=u,cur-=lsum[u][dep[x]-dep[u]];
        reverse(ver+1,ver+len+1);
        for(int i=1,o;i<=len;i++){
            o=ver[i],cur+=lsum[o][dep[x]-dep[o]];
            Ins(o,i==len?fat[r]:ver[i+1],cur);
        }
    }
    work();
}
bool Med;
int main(){
    cerr<<abs(&Mst-&Med)/1048576.0<<endl;
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>q;
    for(int i=2;i<=n;i++){
        cin>>fat[i];
        Gr[fat[i]].emplace_back(i);
    }
    dfs1(1),dfs2(1),build(1,n);
    for(int i=1;i<=n;i++){
        if(top[i]==i){
            vec[i].resize(h[i]);
            for(int u=i;u;u=son[u])
                for(int d=0;d<sh[u];d++)
                    vec[i][d+dep[u]-dep[i]].emplace_back(u);
            sum[i].resize(h[i]);
        }
        lsum[i].resize(sh[i]);
    }
    while(q--){
        int x,w,y;cin>>x>>w>>y;
        Mdf(x,w);
        cout<<qry(dfn[y],dfn[y]+siz[y]-1)<<'\n';
    }
    return 0;
}