题解 P5618 【[SDOI2015]道路修建】

· · 题解

用线段树维护最小生成树。

我们要维护以下几个东西。

$l_{val},r_{val}$:区间最左(右)的竖边的权值。 $l_{max},r_{max}$:区间最左(右)的竖边及其左(右)边的所有边的最大值。 $heng_{max}$:区间横边的最大值。 $sum$:答案。 然后我们在合并区间的时候,左区间的最右竖边和右区间的最左竖边和中间的横边会形成一个环,然后根据上面的维护信息找到环中的最大值,去掉即可。 然后有一个细节: 假如我们要去掉环上的最大边正好是左区间的唯一一条竖边,那么合并后的区间的最左竖边就不能从左区间取了,要取右区间的最左竖边,同时合并后的区间的l_max也要改。所以我们在记录一下区间里有多少条竖边,然后判断是否是这个特殊情况即可。 右区间出现这种情况同理。 ```cpp #include<cstdio> #include<algorithm> #define max(x,y) (x>y?x:y) using namespace std; int n,m,tot,x0,y0,x1,y1,l,r; char ch[101]; long long w,heng[2][60001],shu[60001]; struct node{ int l,r,tot,l_val,r_val; long long l_max,r_max,sum,heng_max; void read(int x) {l=r=x;l_max=r_max=l_val=r_val=sum=shu[x];tot=1;heng_max=0;} }tree[250001]; node pushup(node x,node y) { node now; now.l=x.l;now.r=y.r; now.heng_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.heng_max,y.heng_max)); int del=max(max(heng[0][x.r],heng[1][x.r]),max(x.r_max,y.l_max)); now.sum=x.sum+y.sum+heng[0][x.r]+heng[1][x.r]-del;now.tot=x.tot+y.tot; now.l_val=x.l_val;now.r_val=y.r_val; now.l_max=x.l_max;now.r_max=y.r_max; if (del==x.r_val) { now.tot--; if (x.tot==1) { now.l_val=y.l_val; now.l_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.heng_max,y.l_max)); } } else if (del==y.l_val) { now.tot--; if (y.tot==1) { now.r_val=x.r_val; now.r_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.r_max,y.heng_max)); } } return now; } void build(int k,int l,int r) { if (l==r) { tree[k].read(l); return; } int mid=(l+r)>>1; build(k*2,l,mid);build(k*2+1,mid+1,r); tree[k]=pushup(tree[k*2],tree[k*2+1]); } void update(int k,int l,int r,int x) { if (l==r&&l==x) { tree[k].read(x); return; } int mid=(l+r)>>1; if (x<=mid) update(k*2,l,mid,x); else update(k*2+1,mid+1,r,x); tree[k]=pushup(tree[k*2],tree[k*2+1]); } void change(int k,int l,int r,int x) { if (l==r&&l==x) return; int mid=(l+r)>>1; if (x<=mid) change(k*2,l,mid,x); else change(k*2+1,mid+1,r,x); tree[k]=pushup(tree[k*2],tree[k*2+1]); } node query(int k,int l,int r,int x,int y) { if (x<=l&&r<=y) return tree[k]; int mid=(l+r)>>1; if (y<=mid) return query(k*2,l,mid,x,y); if (x>mid) return query(k*2+1,mid+1,r,x,y); return pushup(query(k*2,l,mid,x,y),query(k*2+1,mid+1,r,x,y)); } int main() { scanf("%d%d",&n,&m); for (int i=1;i<n;i++) scanf("%lld",&heng[0][i]); for (int i=1;i<n;i++) scanf("%lld",&heng[1][i]); for (int i=1;i<=n;i++) scanf("%lld",&shu[i]); getchar();getchar(); build(1,1,n); for (int i=1;i<=m;i++) { scanf("%s",ch); if (ch[0]=='C') { scanf("%d%d%d%d%lld",&x0,&y0,&x1,&y1,&w); getchar(); if (y0==y1) shu[y0]=w,update(1,1,n,y0); else { if (y0>y1) swap(y0,y1); heng[x0-1][y0]=w; change(1,1,n,y0); } } else { scanf("%d%d",&l,&r);getchar(); printf("%lld\n",query(1,1,n,l,r).sum); } } return 0; } ```