题解:AT_abc383_g [ABC383G] Bar Cover

· · 题解

考虑 DP。

首先,不考虑重叠的限制。记 s_i 表示从 i 开始覆盖的贡献,记 f_{i,j} 表示前 i 个选了 j 个,那么转移显然为:

f_{i,j}=max(f_{i-1,j},f_{i-1,j-1}+s_i)

注意到,这个东西是凸的。所以我们可以 O(n) 的合并两个 f,于是考虑分治,枚举跨过分界线的那一段,记 g_{l,r,s_1,s_2} 表示区间 l,r,左边空了 s_1 个,右边空了 s_2 个时的 f 的差分数组,枚举中间的段合并即可。时间复杂度 O(nk^2\log n)

代码很丑。


#include<bits/stdc++.h>
#define int long long 
#define mid ((l+r)>>1)
#define ls (x<<1)
#define rs ((x<<1)|1) 
using namespace std;
namespace IO{
    char buff[1<<21],*p1=buff,*p2=buff;
    inline char getch(){
        return p1==p2&&(p2=((p1=buff)+fread(buff,1,1<<21,stdin)),p1==p2)?EOF:*p1++;
    }
    template<typename T>
    inline void read(T &x){
        char ch=getch();int fl=1;x=0;
        while(ch>'9'||ch<'0'){if(ch=='-')fl=-1;ch=getch();}
        while(ch<='9'&&ch>='0'){x=x*10+ch-48;ch=getch();}
        x*=fl;
    }
    template<typename T,typename ...Args>
    inline void read(T &x,Args& ...args){
        read(x);read(args...);
    }
    char obuf[1<<21],*p3=obuf;
    inline void putch(char ch){
        if(p3-obuf<(1<<21))*p3++=ch;
        else fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=ch;
    }
    char ch[100];
    template<typename T>
    void write(T x){
        if(!x)return putch('0');
        if(x<0)putch('-'),x*=-1;
        int top=0;
        while(x)ch[++top]=x%10+48,x/=10;
        while(top)putch(ch[top]),top--;
    }
    template<typename T,typename ...Args>
    void write(T x,Args ...args){
        write(x);write(args...);
    }
    void flush(){fwrite(obuf,p3-obuf,1,stdout);}
}
using namespace IO;
const int N=2e5+5,Inf=2e14+100;
int n,k;
int a[N],s[N];
struct node{
    vector<int>f[5][5]; 
}tree[N<<2];
vector<int>tmp;
void sol(int x,int l,int r){
    int len=r-l+1;
    if(r-l+1<2*k){
        for(int i=0;i<k;i++){
            for(int j=0;j<k&&i+j<=len;j++){
                int sum=-Inf;
                for(int w=l+i;w+k-1+j<=r;w++)
                    sum=max(sum,s[w]);
                tree[x].f[i][j].push_back(sum); 
            }
        }
        return;
    }
    sol(ls,l,mid),sol(rs,mid+1,r);
    for(int s1=0;s1<k;s1++){
        for(int s2=0;s2<k;s2++){
            int Siz=(len-s1-s2)/k;
            for(auto i=tree[ls].f[s1][0].begin(),j=tree[rs].f[0][s2].begin();i!=tree[ls].f[s1][0].end()||j!=tree[rs].f[0][s2].end();){
                if(i!=tree[ls].f[s1][0].end()&&(j==tree[rs].f[0][s2].end()||*j<=*i))tree[x].f[s1][s2].push_back(*i++);
                else tree[x].f[s1][s2].push_back(*j++);
            }
            while(tree[x].f[s1][s2].size()<Siz)tree[x].f[s1][s2].push_back(-Inf);
            for(int w=0;w+1<k;w++){
                if(k-w-1+s2>(r-mid)||s1+w+1>(mid-l+1))continue;
                tmp.clear();
                int c=s[mid-w],ops=1;
                for(auto i=tree[ls].f[s1][w+1].begin(),j=tree[rs].f[k-w-1][s2].begin();i!=tree[ls].f[s1][w+1].end()||j!=tree[rs].f[k-w-1][s2].end()||ops;){
                    if(ops&&(i==tree[ls].f[s1][w+1].end()||*i<=c)&&(j==tree[rs].f[k-w-1][s2].end()||*j<=c))tmp.push_back(c),ops=0;
                    else if(i!=tree[ls].f[s1][w+1].end()&&(ops==0||c<=*i)&&(j==tree[rs].f[k-w-1][s2].end()||*j<=*i))tmp.push_back(*i++);
                    else tmp.push_back(*j++);
                }
                while(tmp.size()<Siz)tmp.push_back(-Inf);
                int last=0;
                int now1=0,now2=0;
                for(int i=0;i<Siz;i++){
                    now1+=tree[x].f[s1][s2][i];
                    now2+=tmp[i];
                    tree[x].f[s1][s2][i]=max(now1,now2)-last;
                    last=max(now1,now2);
                }
            }
        }
    }
}
int dp[N];
signed main(){
    read(n,k);
    for(int i=1;i<=n;i++)read(a[i]);
    for(int i=1;i+k-1<=n;i++)
        for(int j=0;j<k;j++)
            s[i]+=a[i+j];
    sol(1,1,n);
    memset(dp,-0x3f,sizeof dp);
    for(int s1=0;s1<k;s1++){
        for(int s2=0;s2<k;s2++){
            int Siz=(n-s1-s2)/k;
            int now=0;
            for(int i=0;i<Siz;i++){
                now+=tree[1].f[s1][s2][i];
                dp[i+1]=max(dp[i+1],now);
            }
        }
    }
    for(int i=1;i<=(n/k);i++){
        write(dp[i]),putch(' ');
    }
    flush();
    return 0;
}