P4751题解
前置芝士:P4719
众所周知,出题人为了卡离线算法和
看到现在的题解里面没有一个是树剖的
首先我把我的 P4719 代码贴上来。
#include<cstdio>
namespace something_useful
{
inline int max(int a,int b){ return a>b?a:b;}
inline int read()
{
int s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
struct line
{
int nx;
int to;
};
struct graph
{
int head[100001];
line a[200001];
int tot;
void ad(int u,int v)
{
tot++;
a[tot].to=v;
a[tot].nx=head[u];
head[u]=tot;
}
void add(int u,int v){ ad(u,v),ad(v,u);}
int st(int num){ return head[num];}
int to(int num){ return a[num].to;}
int nx(int num){ return a[num].nx;}
};
struct mat
{
int a[3][3];
int m,n;
mat (){ a[1][1]=a[1][2]=a[2][1]=a[2][2]=0;}
mat (int mm,int nn){ m=mm,n=nn,mat();}
mat operator * (mat b)
{
mat ans;
for(int i=1;i<=2;i++)
for(int k=1;k<=2;k++)
{
int s=a[i][k];
for(int j=1;j<=2;j++)
ans.a[i][j]=max(ans.a[i][j],s+b.a[k][j]);
}
return ans;
}
void operator = (mat b)
{
for(int i=1;i<=2;i++)
for(int j=1;j<=2;j++)
a[i][j]=b.a[i][j];
}
};
}
using namespace something_useful;
int f[100001][2];
int hson[100001];
int siz[100001];
int dfn[100001];
int top[100001];
int end[100001];
int rk[100001];
int fa[100001];
mat t[400001];
mat g[100001];
int a[100001];
graph in;
int tot;
int n,m;
void read_all()
{
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++) in.add(read(),read());
}
void dfs(int num)
{
siz[num]=1;f[num][1]=a[num];
for(int i=in.st(num);i;i=in.nx(i))
{
int p=in.to(i);
if(p==fa[num]) continue;
fa[p]=num;dfs(p);
siz[num]+=siz[p];
if(siz[p]>siz[hson[num]]) hson[num]=p;
f[num][1]+=f[p][0],f[num][0]+=max(f[p][1],f[p][0]);
}
}
void dfs(int num,int t)
{
top[num]=t;
dfn[num]=++tot;
rk[tot]=num;
end[t]=tot;
if(hson[num]) dfs(hson[num],t);
g[num].a[2][2]=-0x3f3f3f3f;g[num].a[2][1]=a[num];
for(int i=in.st(num);i;i=in.nx(i))
{
int p=in.to(i);
if(p==fa[num]||p==hson[num]) continue;
dfs(p,p);
g[num].a[1][1]+=max(f[p][0],f[p][1]);
g[num].a[2][1]+=f[p][0];
}
g[num].a[1][2]=g[num].a[1][1];
}
void push_up(int p){ t[p]=t[p*2]*t[p*2|1];}
void build(int p,int pl,int pr)
{
if(pl==pr)
{
t[p]=g[rk[pl]];
return ;
}
int mid=(pl+pr)>>1;
build(p*2,pl,mid);
build(p*2|1,mid+1,pr);
push_up(p);
}
void init()
{
dfs(1);dfs(1,1);
build(1,1,n);
}
mat get(int p,int pl,int pr,int l,int r)
{
if(l<=pl&&pr<=r) return t[p];
int mid=(pl+pr)>>1;
if(mid<l) return get(p*2|1,mid+1,pr,l,r);
if(r<=mid) return get(p*2,pl,mid,l,r);
return get(p*2,pl,mid,l,r)*get(p*2|1,mid+1,pr,l,r);
}
void change(int p,int pl,int pr,int x)
{
if(pl==pr){ t[p]=g[rk[x]];return ;}
int mid=(pl+pr)>>1;
if(x<=mid) change(p*2,pl,mid,x);
else change(p*2|1,mid+1,pr,x);
push_up(p);
}
void update(int pos,int val)
{
g[pos].a[2][1]+=val-a[pos];
a[pos]=val;
while(pos)
{
mat last=get(1,1,n,dfn[top[pos]],end[top[pos]]);
change(1,1,n,dfn[pos]);
mat now=get(1,1,n,dfn[top[pos]],end[top[pos]]);
pos=fa[top[pos]];
g[pos].a[1][1]+=max(now.a[1][1],now.a[2][1])-max(last.a[1][1],last.a[2][1]);
g[pos].a[1][2]=g[pos].a[1][1];
g[pos].a[2][1]+=now.a[1][1]-last.a[1][1];
}
}
void run()
{
int pos=read(),val=read();
update(pos,val);
mat ans=get(1,1,n,1,end[1]);
printf("%d\n",max(ans.a[1][1],ans.a[2][1]));
}
int main()
{
read_all();
init();
for(int i=1;i<=m;i++) run();
return 0;
}
前面的那个 namespace 是我的链式前向星模板和快读,大家可以不用管它。
下面我们来看哪里需要更改。
第一步:数组大小
原题的数据范围是
注意:不要忘了线段树的大小要开到
第二步:强制在线
注意到这一题是强制在线的,所以我们在读入节点编号的时候要异或上一次的答案。
注意:
-
只有节点编号要异或,修改的值是不用的;
-
别忘了给
lastans 赋初值以及每次算完答案要赋值。
至此,我们的修改部分完成了。
提交之后,你有可能会发现 TLE 了(我挂了一个点)或者是 MLE。
当然,到这里大家可以选择第十个点暴力直接 A 题,但是本着学习新技巧而不是 A 题的态度,我们开始卡常。
因为比较简单,所以先讲一下 MLE 的解决方案。
第一步:矩阵数组
本题的矩阵是 a[2][2] 就可以了,这样子与原来的 a[3][3] 对比,矩阵的空间会小到原来的一半。
第二步:存图
不知道有多少小伙伴用的是 vector,如果还是 MLE,可以考虑改成链式前向星,节省空间。
注意:如果你算过了你的空间确实不会炸,而且基本没有点 AC,那么大概率是写挂了,建议看一眼有没有无限递归。
先来我们来讲 TLE 解决方案。
第一步:读入
使用快读可以减少时间复杂度,大家应该都会,这里不细说。
第二步:矩阵乘法
本题的矩阵是
mat operator * (mat b)
{
mat ans;
ans.a[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
ans.a[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
ans.a[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
ans.a[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
return ans;
}
第三步:线段树
这是卡常卡的最多的一部分。
使用树剖的时候,时间复杂度的瓶颈在于对某个点权值修改的部分,现在我们对这部分进行卡常。
可以发现,我们在查询线段树的时候每次只会查找一个重链的矩阵之积,考虑对每一个重链单独开一个线段树,以减少常数。
有人可能会问:你这不就是一个水到极致,基本没用的常数优化吗,给你一条链不就相当于啥都没有。
真的是这样吗?我们来看一眼变化。
原来的代码:
void update(int pos,int val)
{
g[pos].a[2][1]+=val-a[pos];
a[pos]=val;
while(pos)
{
mat last=get(1,1,n,dfn[top[pos]],end[top[pos]]);
change(1,1,n,dfn[pos]);
mat now=get(1,1,n,dfn[top[pos]],end[top[pos]]);
pos=fa[top[pos]];
g[pos].a[1][1]+=max(now.a[1][1],now.a[2][1])-max(last.a[1][1],last.a[2][1]);
g[pos].a[1][2]=g[pos].a[1][1];
g[pos].a[2][1]+=now.a[1][1]-last.a[1][1];
}
}
更改过后的代码:
void update(int pos,int val)
{
g[pos].a[1][0]+=val-a[pos];
a[pos]=val;
while(pos)
{
mat last=t[root[top[pos]]];
change(root[top[pos]],dfn[top[pos]],end[top[pos]],dfn[pos]);
mat now=t[root[top[pos]]];
pos=fa[top[pos]];
g[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(last.a[0][0],last.a[1][0]);
g[pos].a[0][1]=g[pos].a[0][0];
g[pos].a[1][0]+=now.a[0][0]-last.a[0][0];
}
}
注意一下:你可能会发现修改的时候矩阵的位置不一样,这是因为我原来 MLE 过,然后把矩阵从
有什么变化呢?
没错!循环中少了两次函数的调用!
至此,我们的树剖已经可以通过这道题。
最后,给大家放上完整代码:
#include<cstdio>
namespace something_useful
{
inline int max(int a,int b){ return a>b?a:b;}
inline int read()
{
int s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+(ch^48),ch=getchar();
return s*w;
}
struct line
{
int nx;
int to;
};
struct graph
{
int head[1000001];
line a[2000000];
int tot;
void ad(int u,int v)
{
tot++;
a[tot].to=v;
a[tot].nx=head[u];
head[u]=tot;
}
void add(int u,int v){ ad(u,v),ad(v,u);}
int st(int num){ return head[num];}
int to(int num){ return a[num].to;}
int nx(int num){ return a[num].nx;}
};
struct mat
{
int a[2][2];
mat operator * (mat b)
{
mat ans;
ans.a[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
ans.a[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
ans.a[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
ans.a[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
return ans;
}
};
}
using namespace something_useful;
int f[1000001][2];
int hson[1000001];
int root[1000001];
int siz[1000001];
int dfn[1000001];
int top[1000001];
int end[1000001];
int rk[1000001];
int fa[1000001];
int ls[4000001];
int rs[4000001];
mat t[4000001];
mat g[1000001];
int a[1000001];
graph in;
int last;
int tot;
int n,m;
int nod;
void read_all()
{
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++) in.add(read(),read());
}
void dfs(int num)
{
siz[num]=1;f[num][1]=a[num];
for(int i=in.st(num);i;i=in.nx(i))
{
int p=in.to(i);
if(p==fa[num]) continue;
fa[p]=num;dfs(p);
siz[num]+=siz[p];
if(siz[p]>siz[hson[num]]) hson[num]=p;
f[num][1]+=f[p][0],f[num][0]+=max(f[p][1],f[p][0]);
}
}
void dfs(int num,int t)
{
top[num]=t;
dfn[num]=++tot;
rk[tot]=num;
end[t]=tot;
if(hson[num]) dfs(hson[num],t);
g[num].a[1][1]=-0x3f3f3f3f;g[num].a[1][0]=a[num];
for(int i=in.st(num);i;i=in.nx(i))
{
int p=in.to(i);
if(p==fa[num]||p==hson[num]) continue;
dfs(p,p);
g[num].a[0][0]+=max(f[p][0],f[p][1]);
g[num].a[1][0]+=f[p][0];
}
g[num].a[0][1]=g[num].a[0][0];
}
void push_up(int p){ t[p]=t[ls[p]]*t[rs[p]];}
void build(int &p,int pl,int pr)
{
p=++nod;
if(pl==pr)
{
t[p]=g[rk[pl]];
return ;
}
int mid=(pl+pr)>>1;
build(ls[p],pl,mid);
build(rs[p],mid+1,pr);
push_up(p);
}
void init()
{
dfs(1);dfs(1,1);
for(int i=1;i<=n;i++) if(top[i]==i) build(root[i],dfn[i],end[i]);
}
mat get(int p,int pl,int pr,int l,int r)
{
if(l<=pl&&pr<=r) return t[p];
int mid=(pl+pr)>>1;
if(mid<l) return get(rs[p],mid+1,pr,l,r);
if(r<=mid) return get(ls[p],pl,mid,l,r);
return get(ls[p],pl,mid,l,r)*get(rs[p],mid+1,pr,l,r);
}
void change(int p,int pl,int pr,int x)
{
if(pl==pr){ t[p]=g[rk[x]];return ;}
int mid=(pl+pr)>>1;
if(x<=mid) change(ls[p],pl,mid,x);
else change(rs[p],mid+1,pr,x);
push_up(p);
}
void update(int pos,int val)
{
g[pos].a[1][0]+=val-a[pos];
a[pos]=val;
while(pos)
{
mat last=t[root[top[pos]]];
change(root[top[pos]],dfn[top[pos]],end[top[pos]],dfn[pos]);
mat now=t[root[top[pos]]];
pos=fa[top[pos]];
g[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(last.a[0][0],last.a[1][0]);
g[pos].a[0][1]=g[pos].a[0][0];
g[pos].a[1][0]+=now.a[0][0]-last.a[0][0];
}
}
void run()
{
int pos=read()^last,val=read();
update(pos,val);
mat ans=t[root[1]];
printf("%d\n",last=max(ans.a[0][0],ans.a[1][0]));
}
int main()
{
read_all();
init();
for(int i=1;i<=m;i++) run();
return 0;
}
感谢观看!