题解 P1295 【[TJOI2011]书架】

· · 题解

线段树+二分

(数据范围只有1e5,为何不直接用线段树愉快地莽过去呢qwq

首先分析题意,得到DP方程

f_i=min(f_{j-1}+max(h_j\sim h_i))\ \ (sum_{j\sim i}\le m)

用线段树存zn_j=f_{j-1}+max(h_j\sim h_i)),求f_i直接区间最小值就行了

考虑怎么在线段树上维护zn的值

线段树上每个节点存这个区间zn的最小值,h的最大值以及f的最大值

每当i向右移动一位,先考虑合法区间的左端点会不会因为sum_{j\sim i}>m向右移动。然后因为右端加入了个h_i导致一些jmax(h_j\sim h_i)变大,可以发现从 最靠右的 满足h_j\ge h_i的位置往右,这个值都会变为h_i,可以二分找出那个位置然后在线段树上区间修改h的值。最后别忘了把zn_{i+1}赋值为f_i+h_{i+1}

代码写起来也非常简单

#include<map>
#include<cmath>
#include<ctime>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define qmin(x,y) (x=min(x,y))
#define qmax(x,y) (x=max(x,y))
#define mp(x,y) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
inline int read(){
    int ans=0,fh=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') fh=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
        ans=ans*10+ch-'0',ch=getchar();
    return ans*fh;
}
const int maxn=1e5+100;
const int inf=0x7fffffff;
int n,m,h[maxn],zn[maxn<<2],fn[maxn<<2],hx[maxn<<2],lz[maxn<<2];
#define lc (o<<1)
#define rc (o<<1|1)
inline void paint(int o,int z){
    hx[o]=z,zn[o]=fn[o]+z,lz[o]=z;
}
inline void pushdown(int o){
    if(!lz[o]) return;
    paint(lc,lz[o]),paint(rc,lz[o]),lz[o]=0;
}
inline void updata(int o){
    fn[o]=min(fn[lc],fn[rc]);
    hx[o]=max(hx[lc],hx[rc]);
    zn[o]=min(zn[lc],zn[rc]);
}
inline void revise(int o,int l,int r,int ql,int qr,int qz){
    if(ql==l&&qr==r){paint(o,qz);return;}
    int mid=l+r>>1;pushdown(o);
    if(ql<=mid) revise(lc,l,mid,ql,min(qr,mid),qz);
    if(qr> mid) revise(rc,mid+1,r,max(ql,mid+1),qr,qz);
    updata(o);
}
inline int gethx(int o,int l,int r,int ql,int qr){
    if(ql==l&&qr==r) return hx[o];
    int mid=l+r>>1,Ans=0;pushdown(o);
    if(ql<=mid) qmax(Ans,gethx(lc,l,mid,ql,min(qr,mid)));
    if(qr> mid) qmax(Ans,gethx(rc,mid+1,r,max(ql,mid+1),qr));
    return Ans;
}
inline int getzn(int o,int l,int r,int ql,int qr){
    if(ql==l&&qr==r) return zn[o];
    int mid=l+r>>1,Ans=inf;pushdown(o);
    if(ql<=mid) qmin(Ans,getzn(lc,l,mid,ql,min(qr,mid)));
    if(qr> mid) qmin(Ans,getzn(rc,mid+1,r,max(ql,mid+1),qr));
    return Ans;
}
inline void insert(int o,int l,int r,int qd,int qf,int qh){
    if(l==r){fn[o]=qf,hx[o]=qh,zn[o]=qf+qh;return;}
    int mid=l+r>>1;pushdown(o);
    if(qd<=mid) insert(lc,l,mid,qd,qf,qh);
    else insert(rc,mid+1,r,qd,qf,qh);
    updata(o);
}
inline int getlx(int l,int r,int z){
    int Ans=r+1,R=r;r++;
    while(l<r){
        int mid=l+r>>1;
        if(gethx(1,1,n,mid,R)<z)//如果区间h最大值小于h_i的话 
            Ans=r=mid;
        else l=mid+1;
    }
    return Ans;
}
int main(){
//  freopen("nh.in","r",stdin);
//  freopen("zhy.out","w",stdout);
    n=read(),m=read();
    for(int i=1;i<=n;i++) h[i]=read();
    int sum=0;insert(1,1,n,1,0,h[1]);
    for(int l=1,r=1;r<=n;r++){
        sum+=h[r];while(sum>m) sum-=h[l++];//判断左端点是否应该右移 
        int x=getlx(l,r-1,h[r]);//二分找到满足h_j>h_i的最右的j 
        if(x<r) revise(1,1,n,x,r-1,h[r]);//将它们的h统一修改为h_i 
        int fr=getzn(1,1,n,l,r);//求出f_i 
        if(r==n) printf("%d\n",fr);//输出 
        else insert(1,1,n,r+1,fr,h[r+1]);//zn_{i+1}=f_i+h_{i+1} 
    }
    return 0;
}