题解 P5618 【[SDOI2015]道路修建】
ImmortalWatcher
·
·
题解
用线段树维护最小生成树。
我们要维护以下几个东西。
$l_{val},r_{val}$:区间最左(右)的竖边的权值。
$l_{max},r_{max}$:区间最左(右)的竖边及其左(右)边的所有边的最大值。
$heng_{max}$:区间横边的最大值。
$sum$:答案。
然后我们在合并区间的时候,左区间的最右竖边和右区间的最左竖边和中间的横边会形成一个环,然后根据上面的维护信息找到环中的最大值,去掉即可。
然后有一个细节:
假如我们要去掉环上的最大边正好是左区间的唯一一条竖边,那么合并后的区间的最左竖边就不能从左区间取了,要取右区间的最左竖边,同时合并后的区间的l_max也要改。所以我们在记录一下区间里有多少条竖边,然后判断是否是这个特殊情况即可。
右区间出现这种情况同理。
```cpp
#include<cstdio>
#include<algorithm>
#define max(x,y) (x>y?x:y)
using namespace std;
int n,m,tot,x0,y0,x1,y1,l,r;
char ch[101];
long long w,heng[2][60001],shu[60001];
struct node{
int l,r,tot,l_val,r_val;
long long l_max,r_max,sum,heng_max;
void read(int x) {l=r=x;l_max=r_max=l_val=r_val=sum=shu[x];tot=1;heng_max=0;}
}tree[250001];
node pushup(node x,node y)
{
node now;
now.l=x.l;now.r=y.r;
now.heng_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.heng_max,y.heng_max));
int del=max(max(heng[0][x.r],heng[1][x.r]),max(x.r_max,y.l_max));
now.sum=x.sum+y.sum+heng[0][x.r]+heng[1][x.r]-del;now.tot=x.tot+y.tot;
now.l_val=x.l_val;now.r_val=y.r_val;
now.l_max=x.l_max;now.r_max=y.r_max;
if (del==x.r_val)
{
now.tot--;
if (x.tot==1)
{
now.l_val=y.l_val;
now.l_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.heng_max,y.l_max));
}
}
else if (del==y.l_val)
{
now.tot--;
if (y.tot==1)
{
now.r_val=x.r_val;
now.r_max=max(max(heng[0][x.r],heng[1][x.r]),max(x.r_max,y.heng_max));
}
}
return now;
}
void build(int k,int l,int r)
{
if (l==r)
{
tree[k].read(l);
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);build(k*2+1,mid+1,r);
tree[k]=pushup(tree[k*2],tree[k*2+1]);
}
void update(int k,int l,int r,int x)
{
if (l==r&&l==x)
{
tree[k].read(x);
return;
}
int mid=(l+r)>>1;
if (x<=mid) update(k*2,l,mid,x);
else update(k*2+1,mid+1,r,x);
tree[k]=pushup(tree[k*2],tree[k*2+1]);
}
void change(int k,int l,int r,int x)
{
if (l==r&&l==x) return;
int mid=(l+r)>>1;
if (x<=mid) change(k*2,l,mid,x);
else change(k*2+1,mid+1,r,x);
tree[k]=pushup(tree[k*2],tree[k*2+1]);
}
node query(int k,int l,int r,int x,int y)
{
if (x<=l&&r<=y) return tree[k];
int mid=(l+r)>>1;
if (y<=mid) return query(k*2,l,mid,x,y);
if (x>mid) return query(k*2+1,mid+1,r,x,y);
return pushup(query(k*2,l,mid,x,y),query(k*2+1,mid+1,r,x,y));
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++) scanf("%lld",&heng[0][i]);
for (int i=1;i<n;i++) scanf("%lld",&heng[1][i]);
for (int i=1;i<=n;i++) scanf("%lld",&shu[i]);
getchar();getchar();
build(1,1,n);
for (int i=1;i<=m;i++)
{
scanf("%s",ch);
if (ch[0]=='C')
{
scanf("%d%d%d%d%lld",&x0,&y0,&x1,&y1,&w);
getchar();
if (y0==y1) shu[y0]=w,update(1,1,n,y0);
else
{
if (y0>y1) swap(y0,y1);
heng[x0-1][y0]=w;
change(1,1,n,y0);
}
}
else
{
scanf("%d%d",&l,&r);getchar();
printf("%lld\n",query(1,1,n,l,r).sum);
}
}
return 0;
}
```