P8304 题解

· · 题解

要求一个字符串 s 的任意前后缀中 0 的个数不超过 1 的个数。不妨把 0 视为 -1,则变为该字符串的任意一个前后缀和 pre_i,suf_i\ge 0。现在要在 [l,r] 内选出一个子序列,可以贪心,从左往右选,如果当前是 0 且前缀和为 0 则不能选,否则选上,然后再对第一次产生的子序列从右往左选一遍,得到的就是最长的合法子序列。

显然,在第一次中删除了 (\max\limits_{i=1}^n -pre_i) 个,但是第一次会对第二次产生影响。设第一次后的后缀和为 suf',则有 suf'_i=suf_i+(\max\limits_{j=1}^n -pre_j) - (\max\limits_{j=1}^{i-1}-pre_j),后面部分相当于去掉在 i 后面被删除的 0 的影响。于是总删除个数为

(\max_{i=1}^n -pre_i)+(\max_{i=1}^n -suf'_i)&=(\max_{i=1}^n -pre_i)+\max_{i=1}^n [-suf_i+(\max\limits_{j=1}^{i-1}-pre_j)]-(\max\limits_{j=1}^n -pre_j)\\ &=\max_{i=1}^n [-suf_i+(\max\limits_{j=1}^{i-1}-pre_j)]\\ &=\max_{1\le j<i\le n} (-suf_i-pre_j)\\ &=-\min_{1\le i<j \le n}(pre_i+suf_j) \end{align*}

现在相当于要选一个无交集的前缀后缀和,使它们的和最小,即要求剩下中间的子段和最大。于是问题转化为单点修改,查询区间最大子段和,使用线段树维护即可。时间复杂度 O((n+q)\log n)

#include<iostream>
#include<cstdio>
#define N 1000010
using namespace std;
struct T{
    int mx,lmx,rmx,sum;
}t[N*4];
int n,q,a[N];
T operator +(T A,T B){
    T C;
    C.sum=A.sum+B.sum;
    C.mx=max(max(A.mx,B.mx),A.rmx+B.lmx);
    C.lmx=max(A.lmx,A.sum+B.lmx);
    C.rmx=max(B.rmx,B.sum+A.rmx);
    return C;
}
void build(int u,int l,int r){
    if(l==r){
        if(a[l])t[u].sum=t[u].mx=t[u].lmx=t[u].rmx=a[l];
        else t[u].sum=t[u].mx=t[u].lmx=t[u].rmx=-1;
        return;
    }
    int mid=(l+r)>>1;
    build(u<<1,l,mid);
    build(u<<1|1,mid+1,r);
    t[u]=t[u<<1]+t[u<<1|1];
    return;
}
void update(int u,int l,int r,int p,int v){
    if(l==r){
        if(v)t[u].sum=t[u].mx=t[u].lmx=t[u].rmx=v;
        else t[u].sum=t[u].mx=t[u].lmx=t[u].rmx=-1;
        return;
    }
    int mid=(l+r)>>1;
    if(p<=mid)update(u<<1,l,mid,p,v);
    if(p>mid)update(u<<1|1,mid+1,r,p,v);
    t[u]=t[u<<1]+t[u<<1|1];
    return;
}
T query(int u,int l,int r,int L,int R){
    if(L<=l&&r<=R)return t[u];
    int mid=(l+r)>>1;
    T ss={0,0,0,0};
    if(L<=mid)ss=ss+query(u<<1,l,mid,L,R);
    if(R>mid)ss=ss+query(u<<1|1,mid+1,r,L,R);
    return ss;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int op,l,r;
    string s;
    cin>>n>>q;
    cin>>s;
    for(int i=1;i<=n;i++)
        a[i]=s[i-1]-'0';
    build(1,1,n);
    while(q--){
        cin>>l>>r;
        T ans=query(1,1,n,l,r);
        // cout<<ans.sum<<' '<<ans.mx<<endl;
        cout<<(r-l+1-(-(ans.sum-ans.mx))?r-l+1-(-(ans.sum-ans.mx)):-1)<<'\n';
    }
    return 0;
}