题解 P4074 【[WC2013]糖果公园】

· · 题解

来一发直接在树上暴力分块的题解!

看到题解里各位大佬写的什么括号序列,可是蒟蒻表示根本不会啊 QAQ
于是我写了一篇在树上分块的题解

还记得当时是怎么让莫队优化区间修改+查询问题的吗?就是分块+排序。现在问题跑到了树上,怎么分块呢?

我们可以考虑这样来分块:
从任意节点作为根,开始 \text{dfs},每经过一个点就把它压进栈里。

当一个节点到栈顶的长度大于分块大小时,就把的这部分的点全部弹出,并分为一块。
对于最后剩下来的点,单独再分一块。

具体是什么样的呢?这里借用了一张图来解释:

此处块的大小为3,按照我们刚才说的方法,以顶端那个节点为根,就分成了这样。  

具体代码实现如下:

void dfs(int u,int fa){
    int bottom = top; //top是一个全局变量,初始为0
    stack[++top] = u; //当前节点入栈
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==fa) continue;
        dfs(v,u);
        if(top-bottom>block){//栈顶到当前节点的长度大于分块大小
            ++idx;
            while(top!=bottom) be[stack[top--]] = idx; //弹出这部分元素,分为一块(be[i]表示i节点属于哪一块)
        }
    }
}

分块的问题解决了,排序又该怎么办?

跟普通序列上的莫队方法差不多:
以两端节点 所属的块 为第一、二关键字,以时间为第三关键字排序。

bool cmp(query a,query b){
    //u,v是询问路径的两端节点
    if(be[a.u]==be[b.u]){
        if(be[a.v]==be[b.v]) return a.t<b.t;
        return be[a.v]<be[b.v];
    }
    return be[a.u]<be[b.u];
}

最后,也是最难搞的问题:区间移动。
乍一看好像没什么难的,实际上也没什么难的。

一个节点要从 u 移到 v 的话,让它们不断地向上跳,直到跳到一起。这样做等价于一个点移到另一个点。

跳的过程中,用一个数组 \text{vis} 表示该节点在不在查询的路径上。
每到一个点,如果以前在路径上,那现在肯定就不在了;反之亦然。
随着这样不断地更新经过的节点就完成了区间移动。

结果你发现,如果移动节点时,要跨过 \text{lca}(u,v) 的话,还是会有问题,如下图:

当我们按上述步骤把 u 移到 u'v 移到 v'时,我们发现:

那这好办,最后再更新一下这两个节点,轻松解决问题。 对于更新操作,可以维护一个数组 $\text{num}$,记录路径上各种糖果有多少个。当路径上增/减第 $i

种糖果时,对结果的贡献是 \text{w}[\text{num}[i]]\times\text{v}[i],随便推一下就出来了。

最后,确定分块的 \text{dfs} 可以和求 \text{lca} 的预处理写在一起,减少码量。

时间复杂度 \text O (n^\frac53)

代码如下:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<vector>
#define N 100003
#define reg register
#define ll long long
using namespace std;

struct query{
    int u,v,t,id;
    query(int u=0,int v=0,int t=0,int id=0):u(u),v(v),t(t),id(id){}
};

struct change{
    int u,last,next;
    change(int u=0,int last=0,int next=0):u(u),last(last),next(next){}
};

query q[N];
change c[N];
int val[N],w[N],fa[N],depth[N],son[N],stack[N];
int sub[N],top[N],be[N],clr[N],sum[N];
bool vis[N];
int n,m,T,tp,block,idx;
ll res;
ll ans[N];
vector<int> adj[N];

inline void swap(int &aa,int &bb){
    aa ^= bb;
    bb ^= aa;
    aa ^= bb;
}

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(ll x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

void dfs1(int u,int f){
    int bt = tp;
    stack[++tp] = u;
    fa[u] = f;
    depth[u] = depth[f]+1;
    sub[u] = 1;
    int v,t = -1,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==f) continue;
        dfs1(v,u);
        if(tp-bt>block){
            ++idx;
            while(tp!=bt) be[stack[tp--]] = idx;
        }
        sub[u] += sub[v];
        if(sub[v]>t){
            t = sub[v];
            son[u] = v;
        }
    }
}

void dfs2(int u,int f){
    top[u] = f;
    if(son[u]==0) return;
    dfs2(son[u],f);
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]])
            swap(u,v);
        u = fa[top[u]];    
    }
    if(depth[u]<depth[v]) return u;
    return v;
}

inline bool cmp(query a,query b){
    if(be[a.u]==be[b.u]){
        if(be[a.v]==be[b.v]) return a.t<b.t;
        return be[a.v]<be[b.v];
    }
    return be[a.u]<be[b.u];
}

inline void add(int u){
    ++sum[u];
    res += (ll)w[sum[u]]*val[u];
}

inline void del(int u){
    res -= (ll)w[sum[u]]*val[u];
    --sum[u];
}

inline void update(int u){
    if(vis[u]) del(clr[u]),vis[u] = false;
    else add(clr[u]),vis[u] = true;
}

inline void modify(int u,int t){
    if(vis[u]){
        del(clr[u]);
        add(t);
    }
    clr[u] = t;
}

void move(int u,int v){
    if(depth[u]<depth[v]) swap(u,v);
    while(depth[u]>depth[v]){
        update(u);
        u = fa[u];
    }
    while(u!=v){
        update(u),update(v);
        u = fa[u],v = fa[v];
    }
}

int main(){
    int op,t,qc,u,v;
    read(n),read(m),read(T);
    block = pow(n,2.0/3);
    for(reg int i=1;i<=m;++i) read(val[i]);
    for(reg int i=1;i<=n;++i) read(w[i]);
    for(reg int i=1;i<n;++i){
        read(u),read(v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(reg int i=1;i<=n;++i){
        read(clr[i]);
        stack[i] = clr[i];
    }
    t = qc = 0;
    for(reg int i=1;i<=T;++i){
        read(op),read(u),read(v);
        if(op==0){
            c[++t] = change(u,stack[u],v);
            stack[u] = v;
        }else{
            ++qc;
            q[qc] = query(u,v,t,qc);
        }
    }
    memset(stack,0,sizeof(stack));
    dfs1(1,0);
    dfs2(1,1);
    while(tp>0) be[stack[tp--]] = idx;
    sort(q+1,q+1+qc,cmp);
    t = 0; 
    u = v = 1;
    update(1);
    for(int i=1;i<=qc;++i){
        while(t<q[i].t) modify(c[t+1].u,c[t+1].next),++t;
        while(t>q[i].t) modify(c[t].u,c[t].last),--t;
        update(lca(u,v));
        if(u!=q[i].u) move(u,q[i].u),u = q[i].u;
        if(v!=q[i].v) move(v,q[i].v),v = q[i].v;
        update(lca(u,v));
        ans[q[i].id] = res;
    }
    for(int i=1;i<=qc;++i){
        print(ans[i]);
        putchar('\n');
    }
    return 0;
}