题解:P10100 [ROIR 2023 Day 2] 石头

· · 题解

题解

模拟赛场上想到了做法,可惜最后没有写完。题目非常好,代码实现细节很多。我的做法复杂度是 O(n{\log^2\hspace{-0.1cm}n}),可能不是最优,但是思路可以说是非常顺畅易懂的。

乍一看不知道怎么做?写暴力!拿一个数组 num_{i,j} 记录第一步时将哪些石头涂成白色,可以使得第 i 块石头在第 j 步变成白色。

暴力代码如下:

    read(n,q);
    for(int i=1;i<=n;i++){
        read(a[i]);
    }
    a[0]=inf;
    a[n+1]=inf;
    for(int i=1;i<=n;i++){
        num[i][1].push_back(i);
        int fr=i-1,re=i+1;
        for(int j=2;j<=n;j++){
            if(a[fr]<a[re]){
                num[fr][j].push_back(i);
                fr--;
            }
            else{
                num[re][j].push_back(i);
                re++;
            }
        }
    }
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            cout<<i<<" "<<j<<":";
            for(int l=0;l<num[i][j].size();l++){
                cout<<num[i][j][l]<<" ";
            }
            cout<<endl;
        }
    }

通过暴力,发现正好 j 次取到点 i 的点一定分布在 i 的两侧,且均为连续的区间。距离 i 越远的点,它取到 i 的次数一定是单调不降的。于是我们可以二分。对于每次询问的 p,k ,(以右侧为例,左侧也是一样的),先二分找到取到 p 的次数大于 k 的第一个点,再二分找到取到 p 的次数大于等于 k 的第一个点,就能确定取到 p 的次数等于 k 的点的个数了。

于是问题变成了知道点 x,y,求第一个给 y 染色时,需要多少次能取到 x?还是以 yx 右侧为例,我们来看一个例子:

1 4 6 2 3 9 7 8 5 10

考虑第一个给第 7 个点染色,求几次能取到第 1个点。显然 1,4,6,2,3,9,7 这些左边的数一定会被取完,关键是右边会取哪些?这取决于前 6 个点中的最大值(第 6 个点,值为 9),因为向左取到这个最大值后,才能向右边取最多的点。只需要找到第 7 个点右侧(不包括它本身)第一个比 9 大的数(第 10 个点,值为 10),前面的就都能取到。

于是又可以二分。先用 ST 表预处理出区间最大值,这样每次查询都是 O(1) 的。二分找到右边第一个比左边最大值大的点即可。

对于左边也是完全一样的,总共要用到 6 个二分。可以看代码再理解一下。

#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<list>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<string>
#include<vector>
#define ll long long
#define DBG(x) cout << #x << "=" << x << endl
#define inf 0x3f3f3f3f3f3f3f3f
#define mod 998244353
#define N 200005
using namespace std;
template <typename T>
void read(T& x) {
    x = 0;
    ll t = 1;
    char ch;
    ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') {
            t = -1;
        }
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= t;
}
template <typename T, typename... Args>
void read(T& first, Args&... args) {
    read(first);
    read(args...);
}
template <typename T>
void write(T y) {
    T x = y;
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
template <typename T, typename... Ts>
void write(T arg, Ts... args) {
    write(arg);
    if (sizeof...(args) != 0) {
        putchar(' ');
        write(args...);
    }
}
int n,q,p,k,a[N],st[N][30],lg2[N];
int mx(int x,int y){
    if(x>y){
        return x;
    }
    return y;
}
inline int query(int l,int r){
    int len=lg2[r-l+1];
    return mx(st[l][len],st[r-(1<<len)+1][len]);
}
inline int got1(int x,int L){//L在x左边时,第一个染x,取到L的次数
    int sum=0;
    sum=sum+(x-L+1);
    int maxx=query(L,x-1);
    int l=x+1,r=n,ans=0;
    while(l<=r){
        int mid=(l+r)/2;
        if(query(x+1,mid)<maxx){
            ans=mid;
            l=mid+1;
        }
        else{
            r=mid-1;
        }
    }
    if(ans==0){
        return sum;
    }
    sum=sum+ans-x;
    return sum;
}
inline int got2(int x,int R){//R在x右边时,第一个染x,取到R的次数
    int sum=0;
    sum=sum+(R-x+1);
    int maxx=query(x+1,R);
    int l=0,r=x-1,ans=0;
    while(l<=r){
        int mid=(l+r)/2;
        if(query(mid,x-1)<maxx){
            ans=mid;
            r=mid-1;
        }
        else{
            l=mid+1;
        }
    }
    if(ans==0){
        return sum;
    }
    sum=sum+x-ans;
    return sum;
}
signed main(){
    read(n,q);
    lg2[1]=0;
    for(register int i=2;i<=n+1;i++){
        lg2[i]=lg2[i>>1]+1;
    }
    for(register int i=1;i<=n;i++){
        read(a[i]);
        st[i][0]=a[i];
    }
    a[0]=inf;
    a[n+1]=inf;
    st[0][0]=inf;
    st[n+1][0]=inf;
    n++;
    for(register int j=1;j<=lg2[n]+1;j++){
        for(register int i=0;i+(1<<j)-1<=n;i++){
            st[i][j]=mx(st[i][j-1],st[i+(1<<(j-1))][j-1]);
        }
    }
    while(q--){
        read(p,k);
        if(k==1){
            putchar('1');
            putchar('\n');
            continue;
        }
        int l=p+1,r=p+k,ans1=0,ans2=0,ans=0;
        if(p+k<=n){
            while(l<=r){
                int mid=(l+r)/2;
                if(got1(mid,p)>k){
                    ans1=mid;
                    r=mid-1;
                }
                else{
                    l=mid+1;
                }
            }
            l=p+1,r=p+k-1;
            while(l<=r){
                int mid=(l+r)/2;
                if(got1(mid,p)>=k){
                    ans2=mid;
                    r=mid-1;
                }
                else{
                    l=mid+1;
                }
            }
        }
        if(ans1>=ans2){
            ans+=(ans1-ans2);
        }
        l=p-k,r=p-1,ans1=0,ans2=0;
        if(p-k>=0){
            while(l<=r){
                int mid=(l+r)/2;
                if(got2(mid,p)>k){
                    ans1=mid;
                    l=mid+1;
                }
                else{
                    r=mid-1;
                }
            }
            l=p-k+1,r=p-1;
            while(l<=r){
                int mid=(l+r)/2;
                if(got2(mid,p)>=k){
                    ans2=mid;
                    l=mid+1;
                }
                else{
                    r=mid-1;
                }
            }
        }
        if(ans1<=ans2){
            ans+=(ans2-ans1);
        }
        write(ans);
        putchar('\n');
    }
    return 0;
}