P11364 [NOIP2024] 树上查询 题解

· · 题解

考虑对序列直接分治,然后计算贡献。

先预处理出线段树上每个结点表示的区间的全部答案,一共有 O(n\log n) 个值。

然后对于每个询问,把它放在线段树上。现在我们只需要对于每个线段树上的结点,一些跨越线段树中点的询问,并且我们只需要考虑子区间也跨越中点的贡献。

设左儿子区间 [l,m] 右儿子 [m+1,r],可以发现有贡献的区间一定跨越 mm+1。换句话说,它是 \text{LCA}(m,m+1) 的祖先。所以可以将每个点替换为它和 \text{LCA}(m,m+1)\text{LCA},于是 [l,r] 的所有点都构成祖先后代关系,求 \text{LCA} 深度我们只需要对两个点的深度取较小值就可以了。

那这个问题就是简单的了:我们建立一个平面,x,y 轴分别是子区间左端点到 m 的距离,以及右端点到 m+1 的距离,且每个点有个权值 a_i 表示 [i,m+1] (这里指左侧,右侧为 [m,i])范围内的点的 \text{LCA} 的深度。

则区间 [l',r']\text{LCA} 深度为 \min(a_{l'},a_{r'})[l,r] 关于某一个 k 的答案就是 0\le x\le m-l, 0\le y\le r-m+1 的矩形内,斜线 x+y=k 上的最大权值。

考察权值 \ge k 的点集的形状:因为 a 在两维上都单调,所以这样的点是 x,y 轴上分别取一个前缀合法,即一个包含原点的子矩形,且这些矩形互相包含。我们称这样的矩形为关键的。每个关键的矩形 (x',y'),都能对 k=0\sim x'+y' 产生贡献。依此,就可以线性地求出每个斜线的答案了。预处理的复杂度就是 O(n\log n)

现在来考虑询问。一个询问在线段树上递归时,与上面的问题形式相同,只是在合并答案时,并非能取到所有的 x+y=k 的点,它还对 x,y 有一个上界限制。具体地,设询问为 (L,R,k) 则,m-x\ge L,m+1+y\le R。也就是说,询问的是一个斜的线段的最小值。

考虑关键矩形对询问的贡献。显然我们要求出最小的与该线段有交的关键矩形。由于关键矩形互相包含,所以如果一个矩形和线段有交,则比它大的都和线段有交。所以我们预处理出所有关键矩形,然后二分寻找。每个询问会在线段树上发生 O(\log n) 次合并,如果每次都二分,复杂度为 O(\log^2n),不可接受。

不过注意到一个区间在线段树上递归时,仅会发生一次左右端点都不取到线段树结点的左右端点的情况(l<L\le R<r),剩下的每次合并,都满足 L=lr=R。考虑对于前者,使用二分计算,对于后者优化复杂度。

不妨设 L=l,则这条线段应该长这样:

设线段的“断头”处为 (x_0,y_0),考察这两个矩形:

这两个有如下两种位置关系(前者为红色,后者为蓝色):

可以发现,无论哪种情况,这两个矩形中较大的一定是最优解。所以我们也预处理一下这两种矩形即可。

另外一种线段 R=r 显然做法相同。

此处合并的复杂度变为 O(1),所以整道题的复杂度为 O(n\log n)

实现中,无需比较矩形大小,显然权值 \min(a_x,a_y) 更小的矩形更大,记录这个权值即可。

因为洛谷评测机较慢,我的考场代码在洛谷上会 TLE,仅供参考。我在官方成绩中已通过。

// this is my noip.
#include <bits/stdc++.h>
using namespace std;
constexpr int Spp{1<<20};
char buf[Spp],*p1,*p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,Spp,stdin),p1==p2)?EOF:*p1++)
template<typename T>
void read(T &x) {
    char c;int f{1};
    do x=(c=getchar())-'0';
    while (!isdigit(c)&&c!='-');
    if (c=='-') x=0,f=-1;
    while (isdigit(c=getchar()))
        x=x*10+(c-'0');
    x*=f;
}
template <typename T,typename ...Args>
void read(T &x,Args &...args) { read(x);read(args...);}
constexpr int N(5e5),LG{18};
vector<int> e[N+5];
int fz[N+5],fa[N+5],dep[N+5],lda,dfn[N+5];
int st[LG+1][N+5];
void init(int u,int fa) {
    ::fa[u]=fa;
    dfn[u]=++lda;
    fz[lda]=u;
    dep[u]=dep[fa]+1;
    for (auto v:e[u])
        if (v!=fa)
            init(v,u);
}
int Max(int u,int v) {
    return dep[fz[u]]<dep[fz[v]]?u:v;
}
int LCA(int u,int v) {
    if (u==v) return u;
    u=dfn[u];v=dfn[v];
    if (u>v) swap(u,v);
    int z{__lg(v-u)};
    return fa[fz[Max(st[z][v],st[z][u+(1<<z)])]];
}
vector<int> ans[N*4],ans2[N*4];
vector<pair<int,int>> lo[N*4],ro[N*4],oo[N*4],lr[N*4],nl[N*4],nr[N*4];
int LC[N*4];
void build(int p,int L,int R) {
    ans[p].resize(R-L+2);
    ans2[p].resize(R-L+2);
    if (L==R) {
        ans[p][1]=dep[L];
        lo[p].emplace_back(L,L);
        ro[p].emplace_back(R,R);
        LC[p]=L;
        return;
    }
    int mid{L+R>>1};
    build(p<<1,L,mid);
    build(p<<1|1,mid+1,R);
    LC[p]=LCA(LC[p<<1],LC[p<<1|1]);
    for (int i{1};i<=mid-L+1;++i) ans[p][i]=ans[p<<1][i];
    for (int i{1};i<=R-mid;++i) ans[p][i]=max(ans[p][i],ans[p<<1|1][i]);
    lo[p]=lo[p<<1];
    ro[p]=ro[p<<1|1];
    for (auto [x,y]:lo[p<<1|1]) nr[p].emplace_back(LCA(x,mid),y),lo[p].emplace_back(LCA(x,LC[p<<1]),y);
    for (auto [x,y]:ro[p<<1]) nl[p].emplace_back(LCA(x,mid+1),y),ro[p].emplace_back(LCA(x,LC[p<<1|1]),y);
    merge(nl[p].begin(),nl[p].end(),nr[p].begin(),nr[p].end(),back_inserter(oo[p]),[](auto x,auto y){return dep[x.first]>dep[y.first];});
    int l{mid},r{mid+1},an{N};
    for (auto [x,y]:oo[p]) {
        l=min(l,y);
        r=max(r,y);
        lr[p].emplace_back(l,r);
        an=min(an,dep[x]);
        ans2[p][r-l+1]=max(ans2[p][r-l+1],an);
    }
    ans[p][R-L+1]=max(ans[p][R-L+1],ans2[p][R-L+1]);
    for (int i{R-L};i>=1;--i) ans[p][i]=max({ans[p][i],ans2[p][i],ans[p][i+1]});
    lo[p<<1].clear();lo[p<<1].shrink_to_fit();
    ro[p<<1|1].clear();ro[p<<1|1].shrink_to_fit();
}
int n;
int qry(int p,int l,int r,int k,int L,int R) {
    if (l<=L&&R<=r)
        return k>(R-L+1)?0:ans[p][k];
    int mid{L+R>>1};
    int res{0};
    if (l<=mid) res=max(res,qry(p<<1,l,r,k,L,mid));
    if (r>mid) res=max(res,qry(p<<1|1,l,r,k,mid+1,R));
    if (k<=R-L+1&&l<=mid&&r>mid) {
        if (l<=L) {
            if (r-k+1>=L)
                res=max(res,min(ans2[p][k],dep[nl[p][max(0,mid-(r-k+1))].first]));
        } else if (r>=R) {
            if (l+k-1<=R)
                res=max(res,min(ans2[p][k],dep[nr[p][max(0,(l+k-1)-mid-1)].first]));
        } else {
            int ll{0},rr{R-L};
            int l1{l},l2{r-k+1};
            int r2{r},r1{l+k-1};
            while (ll<=rr) {
                int md{ll+rr>>1};
                auto [LL,RR]{lr[p][md]};
                if (RR-LL+1<k) {
                    ll=md+1;
                    continue;
                }
                if (LL<=l1&&RR>=r1||LL<=l2&&RR>=r2||LL>=l&&LL+k-1<=r||RR<=r&&RR-k+1>=l) {
                    res=max(res,dep[oo[p][md].first]);
                    rr=md-1;
                } else ll=md+1;
            }
        }
    }
    return res;
}
int main() {
    // freopen("query.in","r",stdin);
    // freopen("query.out","w",stdout);
    read(n);
    for (int i{1};i<n;++i) {
        int u,v;read(u,v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    init(1,0);
    iota(st[0]+1,st[0]+1+n,1);
    for (int i{1};i<=LG;++i)
        for (int j{1<<i};j<=n;++j)
            st[i][j]=Max(st[i-1][j],st[i-1][j-(1<<i-1)]);
    build(1,1,n);
    int q;read(q);
    while (q--) {
        int l,r,k;read(l,r,k);
        cout<<qry(1,l,r,k,1,n)<<"\n";
    }
    return 0;
}