题解:P9776 [HUSTFC 2023] 狭义线段树

· · 题解

没看懂官方题解,所以来篇题解。

考虑维护每个号的叶子节点的 f(i),其中 1\leq i\leq n。直接用线段树维护。

考虑操作二是什么东西。dfs 序有很强的性质,维护出 [L_i,R_i] 表示 i 号树节点能控制的叶子节点的区间,那么考虑所有区间的并一定是连续的,这一点可以用正常询问区间 [ql,qr] 的定位过程来理解。最终的 \mathcal{S}={\textstyle \bigcup_{i\in [s,t]}}S_i 一定是由若干节点的 [L_i,R_i] 并出来的,直接找 L_i 的区间最小值以及 R_i 的区间最大值即可。ST 表找就行。

操作一的单点的意思是往子树中所有叶子的 f(i) 都加上 x,因此可以把一个子树统一处理掉。

考虑从 u 开始扫,如果 u 的整个子树的树标号点都在 [s,t] 的操作区间,那么整体处理,并且跳过 u 的整个子树;否则对 u 进行单点修改处理,然后 u\to u+1

整体处理,考虑令 b_i=\text{dep}_i,然后 a_i 为变量,需要维护区间加 a_i 以及区间 \sum a_ib_i,容易线段树维护。

跳跃过程的复杂度,考虑一个深度最多走 O(1) 次,那么总的跳跃次数就是 O(\log n)

总时间复杂度 O((n+q)\log^2n)

// Problem: P9776 [HUSTFC 2023] 狭义线段树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P9776
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// Author: nullptr_qwq
// 
// Powered by CP Editor (https://cpeditor.org)

// 私は猫です

#include<bits/stdc++.h>
#define ull unsigned long long
#define ll long long
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
#define inf 1000000000
#define infll 1000000000000000000ll
#define pii pair<int,int>
#define rep(i,a,b,c) for(int i=(a);i<=(b);i+=(c))
#define per(i,a,b,c) for(int i=(a);i>=(b);i-=(c))
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define dF(i,a,b) for(int i=(a);i>=(b);i--)
#define cmh(sjy) while(sjy--)
#define lowbit(x) (x&(-x))
#define HH printf("\n")
#define eb emplace_back
#define poly vector<int>
using namespace std;
ll read(){
    ll x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
const int mod=998244353,maxn=500005;
inline int qpow(int x,ll y){
    int rt=1;
    for(;y;y>>=1,x=1ll*x*x%mod) if(y&1) rt=1ll*rt*x%mod;
    return rt;
}
inline void inc(int &x,int y){ x=(x+y>=mod)?(x+y-mod):(x+y); }
inline void dec(int &x,int y){ x=(x>=y)?(x-y):(x+mod-y); }
inline void mul(int &x,int y){ x=1ll*x*y%mod; }
inline int add(int x,int y){ return (x+y>=mod)?(x+y-mod):(x+y); }
inline int sub(int x,int y){ return (x>=y)?(x-y):(x+mod-y); }
inline int prod(int x,int y){ return 1ll*x*y%mod; }
inline void chkmax(int &x,int y){ x=max(x,y); }
inline void chkmin(int &x,int y){ x=min(x,y); }
int n,dep[maxn],fa[maxn],mnl[maxn],mxr[maxn],mxd,siz[maxn],rev[maxn];
int st1[maxn][25],st2[maxn][25];
vector<int>g[maxn];
namespace seg{
    #define ls (o<<1)
    #define rs (o<<1|1)
    int lz[maxn<<2],t[maxn<<2];
    void update(int o,int l,int r,int ql,int qr,int x){
        inc(t[o],1ll*(min(r,qr)-max(l,ql)+1)*x%mod);
        if(ql<=l&&qr>=r) return inc(lz[o],x),void();
        int mid=(l+r)>>1;
        if(ql<=mid)update(ls,l,mid,ql,qr,x);
        if(qr>mid) update(rs,mid+1,r,ql,qr,x);
    }
    int query(int o,int l,int r,int ql,int qr,int R=0){
        if(ql<=l&&qr>=r) return add(t[o],1ll*R*(r-l+1)%mod);
        int mid=(l+r)>>1,res=0;
        if(ql<=mid)res=query(ls,l,mid,ql,qr,add(R,lz[o]));
        if(qr>mid)inc(res,query(rs,mid+1,r,ql,qr,add(R,lz[o])));
        return res;
    }
}
int findl(int l,int r){
    int t=__lg(r-l+1);
    return min(st1[l][t],st1[r-(1<<t)+1][t]);
}
int findr(int l,int r){
    int t=__lg(r-l+1);
    return max(st2[l][t],st2[r-(1<<t)+1][t]);
}
namespace seg2{
    #define ls (o<<1)
    #define rs (o<<1|1)
    int t[maxn<<2],sum[maxn<<2],lz[maxn<<2];
    void build(int o,int l,int r){
        if(l==r)return t[o]=dep[rev[l]],void();
        int mid=(l+r)>>1;
        build(ls,l,mid),build(rs,mid+1,r),t[o]=add(t[ls],t[rs]);
    }
    void maketag(int o,int x){ inc(lz[o],x),inc(sum[o],1ll*x*t[o]%mod); }
    void pushdown(int o){ if(lz[o]) maketag(ls,lz[o]),maketag(rs,lz[o]),lz[o]=0; }
    void update(int o,int l,int r,int ql,int qr,int x){
        if(ql<=l&&qr>=r)return maketag(o,x),void();
        int mid=(l+r)>>1; pushdown(o);
        if(ql<=mid)update(ls,l,mid,ql,qr,x);
        if(qr>mid) update(rs,mid+1,r,ql,qr,x);
        sum[o]=add(sum[ls],sum[rs]);
    }
    int query(int o,int l,int r,int ql,int qr){
        if(ql<=l&&qr>=r)return sum[o];
        int mid=(l+r)>>1,res=0; pushdown(o);
        if(ql<=mid)res=query(ls,l,mid,ql,qr);
        if(qr>mid)inc(res,query(rs,mid+1,r,ql,qr));
        return res;
    }
}
void solve(){
    n=read(),dep[1]=1;
    F(i,2,(n<<1)-1) g[fa[i]=read()].push_back(i),dep[i]=dep[fa[i]]+1,chkmax(mxd,dep[i]);
    int cnt=0;
    function<void(int)>dfs=[&](int u){
        siz[u]=1;
        if(g[u].empty())return mnl[u]=mxr[u]=++cnt,rev[cnt]=u,void();
        mnl[u]=inf;
        for(int v:g[u]) dfs(v),siz[u]+=siz[v],chkmax(mxr[u],mxr[v]),chkmin(mnl[u],mnl[v]);
    }; dfs(1);
    const int m=(n<<1)-1;
    F(i,1,m)st1[i][0]=mnl[i],st2[i][0]=mxr[i];
    F(j,1,23)F(i,1,m-(1<<j)+1)st1[i][j]=min(st1[i][j-1],st1[i+(1<<(j-1))][j-1]);
    F(j,1,23)F(i,1,m-(1<<j)+1)st2[i][j]=max(st2[i][j-1],st2[i+(1<<(j-1))][j-1]);
    seg2::build(1,1,n);
    int cmh=read();
    cmh(cmh){
        int op=read();
        if(op==1){
            int s=read(),t=read(),x=read(),u=s;
            while(u<=t){
                const int v=u+siz[u]-1;
                if(v<=t){
                    seg::update(1,1,n,mnl[u],mxr[u],sub(0,1ll*x*(dep[u]-1)%mod));
                    seg2::update(1,1,n,mnl[u],mxr[u],x);
                    u+=siz[u];
                }
                else seg::update(1,1,n,mnl[u],mxr[u],x),++u;
            }
        }
        if(op==2){
            int s=read(),t=read(),x=read();
            seg::update(1,1,n,findl(s,t),findr(s,t),x);
        }
        if(op==3){
            int l=read(),r=read(),ans=seg::query(1,1,n,l,r);
            inc(ans,seg2::query(1,1,n,l,r));
            printf("%d\n",ans);
        }
    }
}
signed main(){
    int sjy=1;
    cmh(sjy) solve();
}