对数据结构的爱

· · 题解

为了信仰来做这个题。

然后自己被题解和自己绕晕了好久……所以我会尽量讲得清楚一点。

由于 1e6 的数据范围,所以猜测是一个 polylog 的做法,来考虑线段树。

然后考虑需要在线段树节点上维护什么:

每一次的模拟都需要一个数与当前区间进行运算——所以考虑以某种方式在线段树上维护一个分段函数

如果这样,我们从左到右依次将得到的上一个函数值代入这个区间的函数,就可以在 \log 次函数求值中得到答案。

接下来就考虑如何对每一个线段树来保存,计算这个函数。

首先考虑存储。我们发现,这一个函数的图像,就是阶梯状地不断减 p,同时减 p 的断点个数不超过区间长度

所以我们大可对于每一个区间维护可能出现的每一个断点——对于区间 k,定义 c_x 为最小的 v,满足以 v 作为输入,得到的值是 v+sum_k-xp,(sum 表示线段树上区间的和)。简单地说,就是最小的值使得这个值经过这个区间后被减去了 x 次 p。对每一个区间如此保存空间复杂度显然为 O(n\log n)

如果我们对每一个区间维护出来了这个数组,我们的查询就已经可以在单次 O(\log^2 n) 的时间内完成。(每一次在这个区间上的数组二分得到应该减多少个 p,然后将得到的函数值给下一个区间)。

然后就可以开始考虑怎么求这个数组了。直接求是不现实的,根据线段树上维护信息的一般定式,我们考虑如何来 pushup。

对于一组分属左右区间的 c_x,c_y,如果其满足 c_{x+1}-1+sum_{lson}-xp \geq c_y(需要减 x 次的最大的值经过左区间后可以在右区间减去 y 次),那么就可以使用 \max(c_x,c_y-sum_{lson}+xp) 来更新 c_{x+y}

直接合并是平方复杂度,考虑优化。没什么其他路可走,就只能考虑有没有单调性来加快计算的速度。

先说明一个引理:对于任意一个区间,c_{x+1} - c_x \geq p

这个的证明是显然的:将 c_x 代入,在减完 x 次 p 过后必然会在一段连续的加过后得到一个 0,而这个 0 后面是一个非正数。

而我们如果要使得函数值在这个区间上多减一次 p,必然要在这个 0 的基础上再多加 p 以抵消后面的那个非正数才行。

我们令 f(x,y) = \max(c_x,c_y-sum_{lson}+xp),并且 x,y 要满足上文所述的限制。

定理:f(x,y) < f(x+1,y-1)

证明:因为 x,y 合法,所以有 c_{y} \leq c_{x+1}-1+sum_{lson}-xp \rightarrow c_y-sum_{lson}+xp \leq c_{x+1}-1 \rightarrow \max(c_x,c_y-sum_{lson}+xp) < c_{x+1},于是就证完了。

有了这个简单的定理,我们只要在能动 y 的时候尽量动 y,那就可以保证答案正确。(因为每一次都用的是最优的去更新能更新的每一个)。复杂度线性。

综上,所有的预处理可以在 O(n\log n) 的时间内完成,所有的询问可以在 O(m\log^2 n) 的时间内回答。

代码:1.3K,不折不扣的小清新数据结构题。

#include <cstdio>
#include <vector>
#include <algorithm>
using std::min;using std::max;
typedef long long LL; 
const LL inf = 1e18;
const int maxn = 1e6+5;
int n,m,p,a[maxn];
LL sum[maxn<<2];
std :: vector <LL> c[maxn<<2];
void pushup(int k,int l,int r,int mid){
    sum[k] = sum[k<<1] + sum[k<<1|1];
    for(int x=0,y=0;x<=mid-l+1;++x){
        if(y==r-mid+1)--y;
        for(;y<=r-mid;++y){
            LL nd = c[k<<1][x+1]-1-1ll*x*p+sum[k<<1];
            if(c[k<<1|1][y] > nd){--y;break;}
            c[k][x+y] = min(c[k][x+y],max(c[k<<1][x],c[k<<1|1][y]+1ll*x*p-sum[k<<1]));
        }
    }
}
void build(int k,int l,int r){
    c[k].resize(r-l+3);for(int i=0;i<=r-l+2;++i)c[k][i] = (i?inf:-inf);
    if(l == r)return sum[k]=a[l],c[k][1]=p-a[l],void();
    int mid = l+r>>1;
    build(k<<1,l,mid),build(k<<1|1,mid+1,r),pushup(k,l,r,mid);
}
LL getans(int k,int l,int r,int x,int y,LL in){
    if(l == x && r == y){
        int pos = std :: upper_bound(c[k].begin(),c[k].end(),in)-c[k].begin()-1;
        return in+sum[k]-1ll*p*pos;
    }
    int mid = l+r>>1;
    if(mid < x)return getans(k<<1|1,mid+1,r,x,y,in);
    if(y <= mid)return getans(k<<1,l,mid,x,y,in);
    return getans(k<<1|1,mid+1,r,mid+1,y,getans(k<<1,l,mid,x,mid,in));
}
int main(){
    scanf("%d %d %d",&n,&m,&p);
    for(int i=1;i<=n;++i)scanf("%d",&a[i]);
    build(1,1,n);while(m--){int l,r;scanf("%d %d",&l,&r),printf("%lld\n",getans(1,1,n,l,r,0));}
    return 0;
}