题解:P7968 [COCI 2021/2022 #2] Osumnjičeni

· · 题解

题意

给定 n 个区间,若若干个编号连续的区间两两相离则可以合并处理,问对于 l,r 的区间要处理所有区间至少需要处理几次。

贪心

得到一个询问区间,我们发现:

  1. 必有一组从 l 开始;
  2. 编号靠后的元素能扩展的极大组中最靠右的元素编号更大。

这样的话我们不难得出从最靠左的元素开始尽量合并处理能得到最优解。

计算极大区间

首先我们要得出从 l 开始的最大能合并处理的 r,发现:

  1. 空集合法;

我们从右至左遍历,用线段树(维护标签 t1,t2 分别指该点能确定的在与这个点所在区间相交或包含这个点时的最小标号)得出与 l 相交的编号大于 l 的最小标号。接下来结合 l + 1 的结果即可。

倍增

显然,接下来要做的就是用倍增优化贪心。

写这一章是为了推销来自大佬博客的优化版倍增。

我们在 nxt_i,即走一步的结果的基础上,建立 jmp_i,即跳跃指针。

建立指针时,如果 nxt_i 通过指针接下来跳的两步的跨越步数相等,jmp_i 赋值为 jmp_{jmp_{nxt_i}},否则 jmp_i = nxt_i。跳链时优先跳 jmp_i,无法跳到 jmp_i 时再试图跳 nxt_i

相比于普通倍增:

  1. 功能性基本相同;
  2. 空间开销仅为 \operatorname{O}(n),空间更小;
  3. 保证跳跃长度为奇数,这在一些神题中很有用,如【NOIP2012提高】开车旅行。

这样我们就解决了这道题。

代码

#include<bits/stdc++.h>
using namespace std;
int t1[1600009],t2[1600009],lsh[400009],dep[200009],cnt,q,fl,fr,n,l[200009],r[200009];
void add(int k,int l,int r,int lq,int rq,int v){
    //printf(" %d %d %d %d %d %d\n",k,l,r,lq,rq,v);
    if(l > rq || r < lq)
        return;
    t1[k] = min(t1[k],v);
    if(l >= lq && r <= rq)
        t2[k] = min(t2[k],v);
    else{
        int t = (l + r) >> 1;
        //printf("%d\n",t);
        add(k << 1,l,t,lq,rq,v);
        add((k << 1) | 1,t + 1,r,lq,rq,v);
    }
}
int nxt[200009],jmp[200009];
int query(int k,int l,int r,int lq,int rq){
    //printf("%d %d %d %d %d\n",k,l,r,lq,rq);
    if(l > rq || r < lq)
        return 0x3f3f3f3f;
    if(l >= lq && r <= rq)
        return t1[k];
    int m = (l + r) >> 1;
    //printf("%d\n",m);
    return min(t2[k],min(query(k << 1,l,m,lq,rq),query((k << 1) | 1,m + 1,r,lq,rq)));
}
int main(){
    scanf("%d",&n);
    for(int i = 1; i <= n; i ++){
        scanf("%d %d",&l[i],&r[i]);
        lsh[++cnt] = l[i];
        lsh[++cnt] = r[i];
    } 
    memset(t1,0x3f,sizeof(t1));
    memset(t2,0x3f,sizeof(t2));
    sort(lsh + 1,lsh + cnt + 1);
    cnt = unique(lsh + 1,lsh + cnt + 1) - (lsh + 1);
    for(int i = 1; i <= n; i ++)
        l[i] = lower_bound(lsh + 1,lsh + cnt + 1,l[i]) - (lsh),r[i] = lower_bound(lsh + 1,lsh + cnt + 1,r[i]) - (lsh);
    nxt[n] = jmp[n] = n + 1;
    dep[n] = 1;
    nxt[n + 1] = jmp[n + 1] = n + 1;
    add(1,1,cnt,1,cnt,n + 1);
    add(1,1,cnt,l[n],r[n],n);
    //puts("OOO");
    for(int i = n - 1; i > 0; i --){
        nxt[i] = query(1,1,cnt,l[i],r[i]);
        nxt[i] = min(nxt[i],nxt[i + 1]);
        add(1,1,cnt,l[i],r[i],i);
        dep[i] = dep[nxt[i]] + 1;
    //  printf("%d\n",nxt[i]);
        if(dep[nxt[i]] - dep[jmp[nxt[i]]] == dep[jmp[nxt[i]]] - dep[jmp[jmp[nxt[i]]]])
            jmp[i] = jmp[jmp[nxt[i]]];
        else
            jmp[i] = nxt[i];
        //printf("%d\n",nxt[i]);
    }
    scanf("%d",&q);
    while(q--){
        scanf("%d %d",&fl,&fr);
        int o = fl;
        while(o <= fr){
            if(jmp[o] <= fr)
                o = jmp[o];
            else
                o = nxt[o];
        }
        printf("%d\n",dep[fl] - dep[o]);
    }
}