题解 SP2916 【GSS5 - Can you answer these queries V】

· · 题解

大概可以算是线段树维护区间最大子段和的模板题了吧...

题意

对于长度为n的序列,回答m个询问,每个询问查询左端点在[l_1,r_1]之中,右端点在[l_2,r_2]之中的所有区间和的最大值,即:

\max_{l_1\leq l\leq l_2,r_1\leq r\leq r_2} \sum_{i=l}^{j} a[i]

( 其中保证l_1\leq l_2;r_1\leq r_2;n,m,|a[i]|\leq 1E4 ).

solution:

0 如何建立线段树

0.0 节点信息

建立线段树时,对于每个节点,额外存储4个信息pre,mid,suf,sum

详细信息如图:

可以看出,父节点的pre可由其左儿子的pre或其左儿子的sum加右儿子的pre更新,在此情况中,父节点的pre由前者更新;

0.1 更新

类比父节点pre的更新方式,可以写出其midsuf的更新方式:

(为了方便,我把这4个数据封装到一个struct中了...)

struct trans{
    int pre,mid,suf,sum;
};
trans merge(trans s1,trans s2){
    trans ans;
    ans.pre=max(s1.pre,s2.pre+s1.sum);
    ans.mid=max(max(s1.mid,s2.mid),s1.suf+s2.pre);
    //父节点的mid要由3个数据更新
    ans.suf=max(s1.suf+s2.sum,s2.suf);
    ans.sum=s1.sum+s2.sum;
    //别忘了更新sum
    return ans;
}

而对于叶节点,在建树时顺便更改其数据即可,这样就可以很容易地写出建树的代码了:

void build(int rt,int l,int r){
    p[rt].l=l;p[rt].r=r;
    if(l==r){
        p[rt].dat.pre=p[rt].dat.mid=p[rt].dat.suf=p[rt].dat.sum=dat[l];
        //我这里规定每个子序列不可以为空,如果这里规定序列可以为空,那相应的,查询时也应更改。
        return;
    }
    int mid=(l+r)/2;
    build(rt*2,l,mid);
    build(rt*2+1,mid+1,r);
    update(rt);
    return;
}

Tips:如果还没有写过这类问题,建议先做GSS1

1 区间最大字段和查询

与普通线段树差不多,此线段树的查询结果也由多个节点的值合并而成:

如图,比如查询的区间是[3,8],那么,查询的答案将由存储范围为[3,4][5,8]的两个区间合并而成;

我们不难发现,在建立线段树和查询时信息的合并方式是相同的,那么,我们也就可以很容易地写出查询的代码了:

trans query_sub_max(int rt,int l,int r){
    if(l>r){
    //这个是为了防止之后的查询中出现的BUG用的
        return (trans){0,0,0,0};
    }
    if(p[rt].l==l&&p[rt].r==r){
        return p[rt].dat;
    }
    int mid=p[rt*2].r;
    if(r<=mid)return query_sub_max(rt*2,l,r);
    else if(l>mid)return query_sub_max(rt*2+1,l,r);
    //如果要查寻的区间没有跨越其中一个儿子的范围,直接返回即可
    return merge(query_sub_max(rt*2,l,mid),query_sub_max(rt*2+1,mid+1,r));
}

可以看出,查询一次的复杂度为O(log(n))

2 对左右端点有限制的区间的查询

本题对左右端点的范围有限制,不能一次查询出结果,于是可以考虑分情况多次查询:

2.1 两个区间没有交集的情况

如图,可以看出,蓝框中的数是无论如何要选上的,为了最大化结果,只能最大化区间[l_1,r_1][l_2,r_2]的选中部分的和,所以,不难写出这种情况的代码:

if(r1<l2){
    int tmp=query_sub_max(1,l1,r1).suf;
    //区间[l1,r1]的最大后缀和
    tmp+=query_sub_max(1,r1+1,l2-1).sum;
    //蓝框部分的和
    tmp+=query_sub_max(1,l2,r2).pre;
    //区间[l2,r2]的最大前缀和
    return tmp;
}

2.2 两个区间有交集的情况

而这种情况就比较复杂了,可分为3种情况,如图:

我们可以发现,其实情况1与情况2是可以用同一种方式求出的,即区间[l_1,l_2]suf+区间[l_2,r_2]pre(或区间[l_1,r_1]suf+区间[r_1,r_2]pre);

而情况3就很好写了,只需求出区间[l_2,r_1]mid即可;

Tips: 注意特判区间[l_1,r_1],[r_1,r_2]是否存在,即l_1l_2,r_1r_2是否相等。

代码如下:

int ans=query_sub_max(1,l2,r1).mid;
    if(l1<l2)ans=max(ans,query_sub_max(1,l1,l2).suf+query_sub_max(1,l2,r2).pre-dat[l2]);
    if(r2>r1)ans=max(ans,query_sub_max(1,l1,r1).suf+query_sub_max(1,r1,r2).pre-dat[r1]);
    return ans;
}

最后,附上AC代码:

#include<bits/stdc++.h>
#define debug(x) cerr<<#x<<" = "<<x<<endl;
#define db(x) debug(x)
using namespace std;
int dat[10005],n;
struct trans{
    int pre,mid,suf,sum;
    int fin(){
        return max(max(pre,suf),mid);
    }
    trans DB(){
        db(pre);db(mid);db(suf);db(sum);
        return *this;
    }
};
trans merge(trans s1,trans s2){
    trans ans;
    ans.pre=max(s1.pre,s2.pre+s1.sum);
    ans.mid=max(max(s1.mid,s2.mid),s1.suf+s2.pre);
    ans.suf=max(s1.suf+s2.sum,s2.suf);
    ans.sum=s1.sum+s2.sum;
    return ans;
}
struct node{
    int l,r;
    trans dat;
}p[40005];
void update(int rt){
    p[rt].dat=merge(p[rt*2].dat,p[rt*2+1].dat);
    return;
}
void build(int rt,int l,int r){
    p[rt].l=l;p[rt].r=r;
    if(l==r){
        p[rt].dat.pre=p[rt].dat.mid=p[rt].dat.suf=p[rt].dat.sum=dat[l];
        return;
    }
    int mid=(l+r)/2;
    build(rt*2,l,mid);
    build(rt*2+1,mid+1,r);
    update(rt);
    return;
}
trans query_sub_max(int rt,int l,int r){
    if(l>r){
        return (trans){0,0,0,0};
    }
    if(p[rt].l==l&&p[rt].r==r){
        return p[rt].dat;
    }
    int mid=p[rt*2].r;
    if(r<=mid)return query_sub_max(rt*2,l,r);
    else if(l>mid)return query_sub_max(rt*2+1,l,r);
    return merge(query_sub_max(rt*2,l,mid),query_sub_max(rt*2+1,mid+1,r));
}
int query(int l1,int r1,int l2,int r2){
    if(r1<l2){
        int tmp=query_sub_max(1,l1,r1).suf;
        tmp+=query_sub_max(1,r1+1,l2-1).sum;
        tmp+=query_sub_max(1,l2,r2).pre;
        return tmp;
    }
    int ans=query_sub_max(1,l2,r1).mid;
    if(l1<l2)ans=max(ans,query_sub_max(1,l1,l2).suf+query_sub_max(1,l2,r2).pre-dat[l2]);
    if(r2>r1)ans=max(ans,query_sub_max(1,l1,r1).suf+query_sub_max(1,r1,r2).pre-dat[r1]);
    return ans;
}
int main(){
    int t;
    scanf("%d",&t);
    while(t--){
        scanf("%d",&n);
        for(int i=1;i<=n;i++)scanf("%d",&dat[i]);
        build(1,1,n);
        int m;
        scanf("%d",&m);
        while(m--){
            int x1,x2,y1,y2;
            scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
            printf("%d\n",query(x1,y1,x2,y2));
        }
    }
    return 0;
}