题解 P4242 【树上的毒瘤】
Salamander · · 题解
本来是想着好像可以在虚树上点分治来做一些不是很好计算的东西?
然后本来想出一道在虚树上点分治+FFT来求一下特定长度的路径?结果发现这东西套在虚树上复杂度不对。
然后就出点树形dp不好做的东西……然后由于套了虚树所以可以带修改……然后……
然后就变得非常码农
然后就有了这个题:树上多次询问,每次给定一些关键点,需要统计出每个关键点到其他关键点的所有路径上的颜色段数之和。
首先是询问的总点数和
然后就是在虚树上统计答案,上点分治。
考虑怎么统计经过当前分治中心的路径对答案的贡献。
对于分治中心的一棵子树,我们现在只统计经过中心的路径的贡献,那么我们可以求出从中心出发进入其他子树的路径的总贡献,然后我们
在虚树上还要注意,只能统计询问点的答案,有一些询问点的lca也在虚树上但不能给其他询问点算贡献。然后由于虚树上的边带表原树上的一条链,所以虚树上的边权是原树上那条链上的颜色段数。
然后我们再写个树链剖分来求lca、链修改、链查询颜色段数。
然后就没有然后了,码吧。
复杂度
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define lc x<<1
#define rc x<<1|1
#define For(i,_beg,_end) for(int i=(_beg),i##end=(_end);i<=i##end;++i)
#define Rep(i,_beg,_end) for(int i=(_beg),i##end=(_end);i>=i##end;--i)
template<typename T>T Max(const T &x,const T &y){return x<y?y:x;}
template<typename T>T Min(const T &x,const T &y){return x<y?x:y;}
template<typename T>int chkmax(T &x,const T &y){return x<y?(x=y,1):0;}
template<typename T>int chkmin(T &x,const T &y){return x>y?(x=y,1):0;}
template<typename T>void read(T &x){
T f=1;char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
for(x=0;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
x*=f;
}
typedef long long LL;
const int maxn=100010,inf=0x7fffffff;
struct edge{
int to,nxt,len;
}e[maxn<<1];
struct Data{
int cl,cr,sum;
Data(){cl=cr=sum=0;}
Data(int l,int r,int s){cl=l;cr=r;sum=s;}
friend Data operator+(const Data &x,const Data &y){
if(!x.cl)return y;
if(!y.cl)return x;
return Data(x.cl,y.cr,x.sum+y.sum-(x.cr==y.cl));
}
}T[maxn<<2];
int n,m,q,num,head[maxn],sum,root,f[maxn],size[maxn];
int c[maxn],dfn[maxn],cnt,tp[maxn];
int anc[maxn],son[maxn],dep[maxn],rnk,pos[maxn],id[maxn];
int sta[maxn],top,h[maxn],sq[maxn],tag[maxn<<2];
LL ans[maxn],tot,all,path;
bool vis[maxn],is[maxn];
void addedge(int,int,int);
void Dfs1(int,int);
void Dfs2(int,int,int);
int Lca(int,int);
void Build(int,int,int);
void pushsame(int,int);
void pushdown(int);
Data Query(int,int,int,int,int);
void Modify(int,int,int,int,int,int);
int Query(int,int);
void Change(int,int,int);
bool comp(int,int);
void Build(void);
void get(int,int);
void get_sum(int,int);
void get_root(int,int);
void Dfs(int,int,int);
void Modify(int,int,int);
void Calc(int);
void Solve(int);
int main(){
read(n);read(q);
For(i,1,n) read(c[i]);
For(i,1,n-1){
int u,v;
read(u);read(v);
addedge(u,v,0);
}
Dfs1(1,0);Dfs2(1,0,1);
Build(1,1,n);
int op,u,v,col;
while(q--){
read(op);
if(op==1){
read(u);read(v);read(col);
Change(u,v,col);
}
else{
read(m);
For(i,1,m){
read(h[i]);
is[h[i]]=true;
sq[i]=h[i];
}
Build();Solve(1);
For(i,1,m) is[sq[i]]=false;
For(i,1,m) printf("%lld ",ans[sq[i]]),ans[sq[i]]=0;
putchar(10);
}
}
return 0;
}
void Calc(int x){
get(x,0);path=size[x];
Dfs(x,tot=0,1);all=tot;
if(is[x]) ans[x]+=tot;
for(int i=head[x];i;i=e[i].nxt) if(!vis[e[i].to]){
tot=0;
Dfs(e[i].to,x,e[i].len+1);
all-=tot;path-=size[e[i].to];
Modify(e[i].to,x,e[i].len);
all+=tot;path+=size[e[i].to];
}
vis[x]=true;
}
void Solve(int x){
Calc(x);
for(int i=head[x];i;i=e[i].nxt) if(!vis[e[i].to]){
f[sum=root=0]=inf;
get_sum(e[i].to,0);
get_root(e[i].to,0);
Solve(root);
}
vis[x]=false;
}
void Modify(int x,int fa,int lst){
if(is[x]) ans[x]+=all+1LL*lst*path;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa||vis[e[i].to]) continue;
Modify(e[i].to,x,lst+e[i].len);
}
}
void Dfs(int x,int fa,int lst){
if(is[x]) tot+=lst;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa||vis[e[i].to]) continue;
Dfs(e[i].to,x,lst+e[i].len);
}
}
void get(int x,int fa){
size[x]=is[x];
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa||vis[e[i].to]) continue;
get(e[i].to,x);
size[x]+=size[e[i].to];
}
}
void get_sum(int x,int fa){
sum++;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa||vis[e[i].to]) continue;
get_sum(e[i].to,x);
}
}
void get_root(int x,int fa){
f[x]=0;size[x]=1;
for(int i=head[x];i;i=e[i].nxt){
if(e[i].to==fa||vis[e[i].to]) continue;
get_root(e[i].to,x);
size[x]+=size[e[i].to];
chkmax(f[x],size[e[i].to]);
}
chkmax(f[x],sum-size[x]);
if(f[x]<f[root]) root=x;
}
void addedge(int u,int v,int w){
e[++num].to=v;e[num].len=w;e[num].nxt=head[u];head[u]=num;
e[++num].to=u;e[num].len=w;e[num].nxt=head[v];head[v]=num;
}
void Dfs1(int x,int fa){
size[x]=1;dep[x]=dep[fa]+1;
dfn[x]=++cnt;anc[x]=fa;
for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=fa){
Dfs1(e[i].to,x);
size[x]+=size[e[i].to];
if(size[e[i].to]>size[son[x]]) son[x]=e[i].to;
}
}
void Dfs2(int x,int fa,int up){
tp[x]=up;id[++rnk]=x;pos[x]=rnk;
if(son[x]) Dfs2(son[x],x,up);
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa&&e[i].to!=son[x])
Dfs2(e[i].to,x,e[i].to);
}
int Lca(int u,int v){
int fx=tp[u],fy=tp[v];
while(fx!=fy){
if(dep[fx]>dep[fy]) fx=tp[u=anc[fx]];
else fy=tp[v=anc[fy]];
}
return dep[u]<dep[v]?u:v;
}
void Build(int x,int l,int r){
if(l==r)return T[x]=Data(c[id[l]],c[id[l]],1),void();
int mid=(l+r)>>1;
Build(lc,l,mid);Build(rc,mid+1,r);
T[x]=T[lc]+T[rc];
}
void pushsame(int x,int col){
tag[x]=col;
T[x]=Data(col,col,1);
}
void pushdown(int x){
if(tag[x]){
pushsame(lc,tag[x]);
pushsame(rc,tag[x]);
tag[x]=0;
}
}
Data Query(int x,int l,int r,int ql,int qr){
if(l>=ql&&r<=qr)return T[x];
int mid=(l+r)>>1;pushdown(x);
if(ql<=mid&&qr>mid)return Query(lc,l,mid,ql,qr)+Query(rc,mid+1,r,ql,qr);
else if(ql<=mid)return Query(lc,l,mid,ql,qr);
else return Query(rc,mid+1,r,ql,qr);
}
void Modify(int x,int l,int r,int ql,int qr,int col){
if(l>=ql&&r<=qr)return pushsame(x,col),void();
int mid=(l+r)>>1;pushdown(x);
if(ql<=mid) Modify(lc,l,mid,ql,qr,col);
if(qr>mid) Modify(rc,mid+1,r,ql,qr,col);
T[x]=T[lc]+T[rc];
}
int Query(int u,int v){
Data res;
int fx=tp[u],fy=tp[v];
while(fx!=fy){
res=Query(1,1,n,pos[fx],pos[u])+res;
fx=tp[u=anc[fx]];
}
res=Query(1,1,n,pos[v],pos[u])+res;
return res.sum-1;
}
void Change(int u,int v,int col){
int fx=tp[u],fy=tp[v];
while(fx!=fy){
if(dep[fx]>dep[fy]){
Modify(1,1,n,pos[fx],pos[u],col);
fx=tp[u=anc[fx]];
}
else{
Modify(1,1,n,pos[fy],pos[v],col);
fy=tp[v=anc[fy]];
}
}
if(dep[u]>dep[v]) swap(u,v);
Modify(1,1,n,pos[u],pos[v],col);
}
bool comp(int x,int y){return dfn[x]<dfn[y];}
void Build(){
sort(h+1,h+m+1,comp);
num=head[sta[top=1]=1]=0;
For(i,h[1]==1?2:1,m){
int lca=Lca(h[i],sta[top]);
if(lca!=sta[top]){
while(dfn[sta[top-1]]>dfn[lca]){
addedge(sta[top],sta[top-1],Query(sta[top],sta[top-1]));
top--;
}
if(sta[top-1]!=lca){
head[lca]=0;
addedge(sta[top],lca,Query(sta[top],lca));
sta[top]=lca;
}
else addedge(sta[top],lca,Query(sta[top],lca)),top--;
}
head[h[i]]=0;
sta[++top]=h[i];
}
For(i,1,top-1) addedge(sta[i+1],sta[i],Query(sta[i+1],sta[i]));
}