P4074题解

· · 题解

这题蒟蒻调了整整一天半,实在是恶心,但是做完感觉自己对树上莫队的理解加深了许多(我是用手重新打的)。

题目大意

这题给的是一颗挺明显的树(n 个节点,n-1 条边),然后要求路径上的一些数据(路径上的糖果类型的出现次数对应的 w* 糖果对应的快乐值的和),还带修改。

思路

这题是棵树,要求路径,支持修改,可以离线,答案是可以在统计过程中计算的(也只有在统计过程中计算),这不明显的莫队吗?还是带修的树上莫队,板板在手,天下我有,这里讲解一下怎么实现。

因为是路径,所以 dfs 序显然没用了,这里需要用到欧拉序和 LCA。欧拉序就是在 dfs 中每次进入和离开 dfs 时都加入序列,这序列的方便之处就在于将两个点之间重复的节点抹去,剩下的就是路径了,当然,如果要绕过他们的 LCA,的话就需要单独处理一下(也只需要单独处理 LCA)。求 LCA 时就可以初始化欧拉序。最后就是莫队了,记得要处理重复的节点(开一个数组,每次异或 1)。

带修莫队加进去其实影响不大,甚至都不需要你去改树上莫队的板子(但我还是重新打了一遍...),就是这里的对应关系十分复杂,推荐画一下图然后进行记录,不然就会像我一样傻看着代码搞一天不晓得那里错了...

讲完了,上代码

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=100005;
inline void qr(register int &ret){//手打快读 
    register int x=0,f=0;
    register char ch=getchar();
    while(ch<'0'||ch>'9') f|=ch=='-',ch=getchar();
    while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    ret=f?-x:x;
} 
int n,m,q,w[MAXN],v[MAXN],c[MAXN],ttt,fr,op,x,y,LCA;//单纯的杂项(啥都有) 
int fa[MAXN],ord[MAXN<<1],fir[MAXN],las[MAXN],son[MAXN],sze[MAXN],head[MAXN],nxt[MAXN<<1],to[MAXN<<1],dep[MAXN],top[MAXN],ccnt,nord;//LCA 相关 
int cntr,cntp,idx[MAXN<<1],s,l,r,t,tot,cnt[MAXN],vis[MAXN],L,R,T,ans[MAXN];//莫队相关 
struct node1{//查询 
    int l,r,id,ti,lca;
    bool operator<(const node1 x)const{//奇怪的双重奇偶排序 
        return idx[l]==idx[x.l]?(idx[l]&1?(idx[r]==idx[x.r]?(idx[r]&1?ti<x.ti:ti>x.ti):idx[r]<idx[x.r]):(idx[r]==idx[x.r]?(idx[r]&1?ti<x.ti:ti>x.ti):r>x.r)):l<x.l;
    }
}p[MAXN];
struct node2{//修改 
    int i,from,to;
}ch[MAXN];
void add(int x){
    ++cnt[x];
    tot+=v[x]*w[cnt[x]];
}
void del(int x){//这里的对应关系用人脑压不下!! 
    tot-=v[x]*w[cnt[x]];
    --cnt[x];
}
void work(int x){//这里是树上莫队的处理(没学过的理解一下) 
    vis[x]?del(c[x]):add(c[x]);
    vis[x]^=1;
}
void dealadd(int x){//时间的修改 1 
    if(vis[ch[x].i]) add(ch[x].to),del(ch[x].from);//如果修改的节点对答案有影响,就顺手处理一手影响 
    c[ch[x].i]=ch[x].to;                           
}
void dealdel(int x){//时间的修改 2
    if(vis[ch[x].i]) del(ch[x].to),add(ch[x].from);//同上 
    c[ch[x].i]=ch[x].from;//修改原数组 
}
void adde(int fr,int To){//加边 
    nxt[++ccnt]=head[fr],head[fr]=ccnt,to[ccnt]=To;
}
void dfs1(int i){//这里我用的是树剖求 LCA 看不懂的同学可以去康康其他的 
                 //树剖的初始化 1 
    dep[i]=dep[fa[i]]+1;
    sze[i]=1;
    ord[++nord]=i;
    fir[i]=nord;
    for(int j=head[i];j;j=nxt[j]){
        if(to[j]==fa[i]) continue;
        fa[to[j]]=i;
        dfs1(to[j]);
        sze[i]+=sze[to[j]];
        if(son[i]==1||sze[to[j]]>sze[son[i]]) son[i]=to[j];
    }
    ord[++nord]=i;
    las[i]=nord;
    return;
}
void dfs2(int i,int Top){//树剖的初始化 2
    top[i]=Top;
    if(son[i]) dfs2(son[i],Top);
    for(int j=head[i];j;j=nxt[j]){
        if(to[j]==fa[i]||to[j]==son[i]) continue;
        dfs2(to[j],to[j]);
    }
    return;
}
int getlca(int x,int y){//跳祖先 
    while(top[x]^top[y]){
        if(dep[top[x]]>=dep[top[y]]) x=fa[top[x]];
        else y=fa[top[y]];
    }
    return dep[x]>dep[y]?y:x;
}
signed main() {
    qr(n),qr(m),qr(q);
    for(int i=1;i<=m;++i) qr(v[i]);
    for(int i=1;i<=n;++i) qr(w[i]);
    for(int i=1;i<n;++i){
        qr(fr),qr(ttt);
        adde(fr,ttt),adde(ttt,fr);
    }
    fa[1]=1;
    dfs1(1),dfs2(1,1);//初始化 
    for(int i=1;i<=n;++i) qr(c[i]);
    s=pow(nord,2.0/3);//处理分块 1(这里是 2/3 次方!!!) 
    for(int i=1;i<=nord;++i) idx[i]=(i+s-1)/s;//处理分块 2
    for(int i=1;i<=q;++i){
        qr(op),qr(x),qr(y);
        if(op){
            ++cntp;
            if(fir[x]>fir[y]) swap(x,y);
            LCA=getlca(x,y);
            if(LCA==x){//如果两点在一条链上 
                p[cntp].l=fir[x],p[cntp].r=fir[y];
                p[cntp].id=cntp,p[cntp].ti=cntr;
            }else{//不在一条链上 
                p[cntp].l=las[x],p[cntp].r=fir[y];
                p[cntp].id=cntp,p[cntp].ti=cntr;
                p[cntp].lca=LCA;
            }
        }else{
            ++cntr;
            ch[cntr].i=x,ch[cntr].from=c[x],ch[cntr].to=c[x]=y;//记得修改原数组 
        }
    }
    for(int i=cntr;i>=1;--i) c[ch[i].i]=ch[i].from;//记得还原 
    sort(p+1,p+cntp+1);//排序 
    l=1,r=t=0;
    for(int i=1;i<=cntp;++i){//莫队板板(如果这都不会,你还来做这题?) 
        L=p[i].l,R=p[i].r,T=p[i].ti;
        while(t>T) dealdel(t--);
        while(t<T) dealadd(++t);
        while(l>L) work(ord[--l]);//这里是 ord[(l 或 r)]!!!调了半天的我就是因为这里...
        while(r<R) work(ord[++r]);
        while(l<L) work(ord[l++]);
        while(r>R) work(ord[r--]);
        if(p[i].lca) work(p[i].lca);//如果不在一条链上,就处理一下 LCA 
        ans[p[i].id]=tot;
        if(p[i].lca) work(p[i].lca);
    }
    for(int i=1;i<=cntp;++i) printf("%lld\n",ans[i]);
    return 0;
}