题解 P5559【失昼城的守星使】
题意:给定一棵树,边带权,每个点有黑白两种颜色,你需要支持翻转某个点的颜色,或者查询所有黑点到某一条链的距离之和。
分两种情况讨论:在 LCA 的子树内,在 LCA 的子树外。
为了方便,我们规定:
先看第一种情况:( 给定的链为
我们分别计算每个链上结点作为 LCA 时对答案的贡献:
其中加号表示正贡献,减号表示负贡献。整理可得答案为
再看第二种情况:
还是一样的,列表计算贡献:( 根据
整理可得:
可以发现,答案当中有着非常多形如
每次改变结点颜色的时候,只需要修改从它到根结点的链上这四个式子的值。修改和查询都可以用树链剖分解决。
注意
#include<bits/stdc++.h>
#define ll long long
#define For(i,a,b) for(int i=(a);i<=(b);++i)
#define Rof(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
typedef pair<int,int> Pair;
const int Maxn=2e5;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0' || ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0' && ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
int n,m,kazuha,all,typ[Maxn+5]; ll dep[Maxn+5],sd[Maxn+5];
int siz[Maxn+5],fa[Maxn+5],son[Maxn+5];
int dfn[Maxn+5],pre[Maxn+5],top[Maxn+5],cur;
vector<Pair> v[Maxn+5];
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
inline void dfs1(int x,int f)
{
fa[x]=f,siz[x]=1,sd[x]=dep[x];
for(auto i:v[x])
{
int y=i.first,z=i.second;
if(y==f) continue;
dep[y]=dep[x]+z;
dfs1(y,x),siz[x]+=siz[y],sd[x]+=sd[y];
son[x]=(siz[y]>siz[son[x]]?y:son[x]);
}
}
inline void dfs2(int x,int t)
{
top[x]=t,dfn[x]=++cur,pre[cur]=x;
if(son[x]) dfs2(son[x],t);
for(auto i:v[x])
{
int y=i.first;
if(y==fa[x] || y==son[x]) continue;
dfs2(y,y);
}
}
struct SegTree
{
ll t[Maxn*4+5],tag[Maxn*4+5],val[Maxn+5],sum[Maxn+5]; int opt;
inline void push_up(int p) {t[p]=t[ls(p)]+t[rs(p)];}
inline void f(int p,ll k,int len) {t[p]+=len*k,tag[p]+=k;}
inline void f2(int p,ll k,int l,int r) {t[p]+=(sum[r]-sum[l-1])*k,tag[p]+=k;}
inline void push_down(int l,int r,int p)
{
int mid=(l+r)>>1;
f(ls(p),tag[p],mid-l+1),f(rs(p),tag[p],r-mid),tag[p]=0;
}
inline void push_down2(int l,int r,int p)
{
int mid=(l+r)>>1;
f2(ls(p),tag[p],l,mid),f2(rs(p),tag[p],mid+1,r),tag[p]=0;
}
inline void Build(int l,int r,int p)
{
if(l==r) {t[p]=val[pre[l]]; return;}
int mid=(l+r)>>1;
Build(l,mid,ls(p)),Build(mid+1,r,rs(p)),push_up(p);
}
inline void Update(int nl,int nr,int l,int r,int p,ll k)
{
if(l<=nl && nr<=r) {t[p]+=(nr-nl+1)*k,tag[p]+=k; return;}
int mid=(nl+nr)>>1; push_down(nl,nr,p);
if(l<=mid) Update(nl,mid,l,r,ls(p),k);
if(r>mid) Update(mid+1,nr,l,r,rs(p),k);
push_up(p);
}
inline void Update2(int nl,int nr,int l,int r,int p,ll k)
{
if(l<=nl && nr<=r) {t[p]+=(sum[nr]-sum[nl-1])*k,tag[p]+=k; return;}
int mid=(nl+nr)>>1; push_down2(nl,nr,p);
if(l<=mid) Update2(nl,mid,l,r,ls(p),k);
if(r>mid) Update2(mid+1,nr,l,r,rs(p),k);
push_up(p);
}
inline ll Count(int nl,int nr,int l,int r,int p)
{
if(l<=nl && nr<=r) return t[p];
int mid=(nl+nr)>>1; ll res=0;
if(!opt) push_down(nl,nr,p);
else push_down2(nl,nr,p);
if(l<=mid) res+=Count(nl,mid,l,r,ls(p));
if(r>mid) res+=Count(mid+1,nr,l,r,rs(p));
push_up(p); return res;
}
inline void Modify(int x,int y,ll z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
Update(1,n,dfn[top[x]],dfn[x],1,z),x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
Update(1,n,dfn[x],dfn[y],1,z);
}
inline void Modify2(int x,int y,ll z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
Update2(1,n,dfn[top[x]],dfn[x],1,z),x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
Update2(1,n,dfn[x],dfn[y],1,z);
}
inline ll Find(int x,int y)
{
ll res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res+=Count(1,n,dfn[top[x]],dfn[x],1),x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
res+=Count(1,n,dfn[x],dfn[y],1);
return res;
}
} T[4];
inline int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return (dep[x]<dep[y]?x:y);
}
inline void Change(int x)
{
if(!typ[x])
{
T[0].Modify(x,1,1);
T[1].Modify(x,1,dep[x]);
T[2].Modify2(x,1,1);
T[3].Modify2(x,1,1);
typ[x]=1,all++;
}
else
{
T[0].Modify(x,1,-1);
T[1].Modify(x,1,-dep[x]);
T[2].Modify2(x,1,-1);
T[3].Modify2(x,1,-1);
typ[x]=0,all--;
}
}
inline ll Query(int x,int y)
{
int d=LCA(x,y);
ll sdd=T[1].Count(1,n,dfn[d],dfn[d],1),szd=T[0].Count(1,n,dfn[d],dfn[d],1);
ll s1=sdd-dep[d]*szd+T[3].Find(x,y)-T[2].Find(x,y)-szd*(dep[fa[d]]-dep[d]),s2=0;
if(fa[d])
{
ll sd1=T[1].Count(1,n,1,1,1);
s2=dep[d]*(all-szd)+sd1-sdd;
s2=s2+2ll*T[3].Find(1,d)-2ll*T[2].Find(1,fa[d]);
}
return s1+s2;
}
int main()
{
n=read(),m=read(),kazuha=read(),all=n;
T[0].opt=0,T[1].opt=0,T[2].opt=1,T[3].opt=1;
For(i,1,n-1)
{
int a=read(),b=read(),c=read();
v[a].push_back(make_pair(b,c)),v[b].push_back(make_pair(a,c));
}
For(i,1,n) typ[i]=read();
dfs1(1,0),dfs2(1,1);
For(i,1,n) T[2].sum[i]=T[2].sum[i-1]+dep[pre[i]];
For(i,1,n) T[3].sum[i]=T[3].sum[i-1]+dep[fa[pre[i]]];
For(i,0,3)
{
if(i==0) For(j,1,n) T[i].val[j]=siz[j];
if(i==1) For(j,1,n) T[i].val[j]=sd[j];
if(i==2) For(j,1,n) T[i].val[j]=1ll*siz[j]*dep[j];
if(i==3) For(j,1,n) T[i].val[j]=1ll*siz[j]*dep[fa[j]];
T[i].Build(1,n,1);
}
For(i,1,n) if(!typ[i]) typ[i]=1,Change(i);
while(m--)
{
int opt=read(),x=read(),y;
if(opt==1) Change(x);
else y=read(),printf("%lld\n",Query(x,y));
}
return 0;
}