NOIP2024 树上查询 题解

· · 题解

来一个考场做法和后续优化。

首先有一个 O(1) 求 lca 的前置知识:我们在 dfs 序上考虑,设求点 x,yx\neq y)的 lca,不妨设 dfn_x<dfn_y,则 lca 为 [dfn_x+1,dfn_y] 区间上深度最小的点的父亲。手玩一下就可以证明了。

可以发现多个点的 lca 是可合并可重信息,于是可以用 st 表维护区间 lca。

于是可以得到一个 O(nq) 的暴力,由于只要考虑长度恰好等于 k 的区间,我们枚举 O(n) 个区间,然后求出 lca,比较即可。可以获得 32 的好成绩。

看到剩下的除了正解就是链了,于是考虑链。我们不妨直接把链考虑成序列:a_i 表示点 i 的深度。则询问相当于最大化子区间内所有长度 \ge k 的区间的 \max。一个直接的想法是二分答案 x,则我们把值 \ge x 的称为可行点,其它称为断点。注意到除了包含区间边界的时候,只有极长的可行点区间是有用的。即区间 [l,r] 有用当且仅当 \min^r_{i=l}a_i\ge xa_{l-1}<xa_{r+1}<x。其余只需要考虑特判区间 [L,L+k-1][R-k+1,R] 即可(L,R,k 为题面中的询问参数)。

这种区间一看就非常符合直觉,看上去非常有优化前途啊!我们不妨设 a_{l-1}\ge a_{r+1},则可以发现 [l,r] 的所有数都 >a_{l-1},则 a_{r+1}l-1 右边的第一个 \le a_{l-1} 的;同理,当 a_{l-1}\le a_{r+1} 时,a_{l-1}a_{r+1} 左边第一个更小的。也就是说,最多只有这样 2n 个区间是有用的。

朴素做,对于一个有用区间 (l,r,x),其中 x 表示 [l,r] 区间的 lca 深度,则要求 l\ge Lr\le Rr-l+1\ge k。这就是三维数点板子了,我们里先后随便找一维排序,然后用树套树解决。时间复杂度 O((n+q)\log^2n),使用线段树套 treap 空间复杂度 O((n+q)\log n)。考场上我花了半小时实现了该做法(虽然还调了半小时),期望得分最好有 64

考虑是否可以扩展到树上。注意到链可做的原因是,lca 可以表示为区间 \min。我们脑洞大开,用前置知识中 O(1) lca 来考虑。在 dfn 序列中,区间 lca 相当于钦定了若干单点(区间内所有点的 dfn 序),然后求所有点形成的左开右闭区间的 \min 的最小值,可以发现就是 dfn 序列上最大最小值区间的最小值。同时考虑这个拆法,为了模仿链的做法,我们也类似地把这个大区间拆成若干个小区间的并,其中每个小区间是原序列中相邻两点在 dfn 序列上表示的左开右闭区间。也就是说,令 a_i 表示 \min^{\max(dfn_i,dfn_{i+1})}_{j=\min(dfn_i,dfn_{i+1})+1}dep_j-1,则区间 lca 的深度即为 \min^{r-1}_{i=l}a_i,因为这些小区间一定能覆盖整个大区间。也就是说我们在令 a_idep_{lca(i,i+1)} 时可以用和链一样的做法做,即找到 2n 个有用区间尝试。注意这个做法需要特判 k=1。时间复杂度 O((n+q)\log^2n),不知道能不能过,感觉常数不大。

zhuzhu2891 告诉我每个区间只需要保留前后缀即可。具体地,在三维偏序的 l,r,len 三个限制中,我们分别只考虑 l,lenr,len 的限制,令新区间被原区间包含即可。从刚才二分答案的角度思考可以感性理解只保留前后缀的正确性。注意到虽然长度缩小后区间 \min 可能会改变,但是按原来的值算只会更劣,一定会被更优的区间前后缀覆盖,所以这个做法可以仅用二维偏序实现,按 len 排序之后实现单点修改区间求 \max 即可。时间复杂度 O((n+q)\log n)

#include<bits/stdc++.h>
#define REP(i,a,n) for(int i=(a);i<(int)(n);++i)
#define pb push_back
using namespace std;
int read(){
    int res=0;char c=getchar();
    while(c<48||c>57)c=getchar();
    do res=(res<<1)+(res<<3)+(c^48),c=getchar();while(c>=48&&c<=57);
    return res;
}
struct queries{
    int l,r,id;
};
struct ds{
    int seg[2000005];
    void build(int l,int r,int p){
        seg[p]=0;
        if(l==r)return;
        int m=(l+r)>>1;
        build(l,m,p*2+1);build(m+1,r,p*2+2);
    }
    void update(int pos,int l,int r,int p,int val){
        seg[p]=max(seg[p],val);
        if(l==r)return;
        int m=(l+r)>>1;
        if(m>=pos)update(pos,l,m,p*2+1,val);
        else update(pos,m+1,r,p*2+2,val);
    }
    int query(int l,int r,int s,int t,int p){
        if(l<=s&&t<=r)return seg[p];
        int m=(s+t)>>1,res=0;
        if(m>=l)res=query(l,r,s,m,p*2+1);
        if(m<r)res=max(res,query(l,r,m+1,t,p*2+2));
        return res;
    }
}s1,s2;
int n,q;
vector<int>v[500005];
int an[22][500005],fa[500005],dep[500005],a[500005];
int st[22][500005],mx[22][500005],dfn[500005];
int tot;
int ans[500005];
vector<queries>qr[500005],add[500005];
int getmax(int x,int y){return dep[x]<dep[y]? x:y;}
void dfs(int x,int pre,int d){
    fa[x]=pre;dfn[x]=tot++;an[0][dfn[x]]=pre;dep[x]=d;
    for(auto i:v[x])if(i!=pre)dfs(i,x,d+1);
}
int getlca(int x,int y){
    if(x==y)return x;
    x=dfn[x];y=dfn[y];if(x>y)swap(x,y);
    int s=__lg(y-x);
    return getmax(an[s][x+1],an[s][y-(1<<s)+1]);
}
int query(int l,int r){
    int s=__lg(r-l+1);
    return dep[getlca(st[s][l],st[s][r-(1<<s)+1])]+1;
}
int qmax(int l,int r){
    int s=__lg(r-l+1);
    return max(mx[s][l],mx[s][r-(1<<s)+1]);
}
signed main(){
    freopen("query.in","r",stdin);
    freopen("query.out","w",stdout);
    n=read();
    REP(i,1,n){
        int x=read()-1,y=read()-1;
        v[x].pb(y);v[y].pb(x);
    }
    dfs(0,-1,0);
    REP(j,0,__lg(n-1)){
        REP(i,1,n-(1<<(j+1))+1)an[j+1][i]=getmax(an[j][i],an[j][i+(1<<j)]);
    }
    REP(i,0,n)st[0][i]=i,mx[0][i]=dep[i]+1;
    REP(j,0,__lg(n)){
        REP(i,0,n-(1<<(j+1))+1)st[j+1][i]=getlca(st[j][i],st[j][i+(1<<j)]);
        REP(i,0,n-(1<<(j+1))+1)mx[j+1][i]=max(mx[j][i],mx[j][i+(1<<j)]);
    }
    REP(i,0,n-1)a[i]=dep[getlca(i,i+1)];
    stack<int>st;
    REP(i,0,n-1){
        while(!st.empty()&&a[st.top()]>a[i])st.pop();
        if(!st.empty()){
            int x=st.top()+1;
            if(x<i){
                queries y={x,i-1,query(x,i)};
                add[i-x].pb(y);
            }
        }
        st.push(i);
    }
    while(!st.empty())st.pop();
    for(int i=n-2;i>=0;--i){
        while(!st.empty()&&a[st.top()]>a[i])st.pop();
        if(!st.empty()){
            int x=st.top()-1;
            if(x>i){
                queries y={i+1,x,query(i+1,x+1)};
                add[x-i].pb(y);
            }
        }
        st.push(i);
    }
    q=read();
    REP(i,0,q){
        int l=read()-1,r=read()-1,k=read();
        if(k==1)ans[i]=qmax(l,r);
        else{
            ans[i]=max(query(l,l+k-1),query(r-k+1,r));
            --r;--k;
            qr[k].pb({l,r,i});
        }
    }
    --n;
    s1.build(0,n-1,0);s2.build(0,n-1,0);
    for(int i=n;i>=1;--i){
        for(auto j:add[i]){
            s1.update(j.l,0,n-1,0,j.id);
            s2.update(j.r,0,n-1,0,j.id);
        }
        for(auto j:qr[i]){
            ans[j.id]=max(ans[j.id],s1.query(j.l,j.r-i+1,0,n-1,0));
            ans[j.id]=max(ans[j.id],s2.query(j.l+i-1,j.r,0,n-1,0));
        }
    }
    REP(i,0,q)cout<<ans[i]<<"\n";
    cerr<<clock()<<endl;
    return 0;
}