题解 [ABC133F] Colorful Tree

· · 题解

考虑正常的 uv 之间的距离怎么去求:我们可以维护每个值到根的距离 dist_i ,然后通过计算 dist_u+dist_v-2 \times dist_{lca} 得到,这里就不过多赘述。

考虑本题,我们发现只要能够求出改变后的 dist_i 的值即可,不妨利用主席树维护每个点 i 到根节点 1 的每种颜色的出现次数 cnt 以及权值和 sum ,那么此时真正的 dist'_i=dist_i-sum+cnt \times y

而主席树的每个点的权值可以继承其父亲的权值。时空复杂度均为 O(N \log N)

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=1e5+5;
int n,q,x,y,z,w;
struct node{
    int nxt;int to;int dist;int col;
}e[N*2];
int head[N],tot;
struct Chair{
    int ls;int rs;int sum;int cnt;
}Tree[N*32];
int root[N],Tcnt;

int fa[N],dep[N],sz[N],hson[N],top[N];
int val[N],col[N],dist[N];
/*这里使用了树剖求lca*/

inline void read(int &x) 
{
    int f=1;char c;
    for(x=0,c=getchar();c<'0'||c>'9';c=getchar()) if(c=='-') f=-1;
    for(;c>='0'&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48); x*=f;
} 
inline int mn(int _x,int _y){return _x<_y?_x:_y;}
inline int mx(int _x,int _y){return _x>_y?_x:_y;}
inline int ab(int _x){return _x<0?-_x:_x;}
inline void add(int from,int to,int col,int dist)
{
    e[++tot].to=to;
    e[tot].dist=dist;e[tot].col=col;
    e[tot].nxt=head[from];
    head[from]=tot;
}

inline int New(int x)
{
    ++Tcnt;
    Tree[Tcnt]=Tree[x];
    return Tcnt;
}
inline int update(int x,int l,int r,int pos,int v)
{
    x=New(x);Tree[x].sum+=v;Tree[x].cnt++;
    if(l==r) return x;
    int mid=l+r>>1;
    if(pos<=mid) Tree[x].ls=update(Tree[x].ls,l,mid,pos,v);
    if(pos>mid) Tree[x].rs=update(Tree[x].rs,mid+1,r,pos,v);
    return x;
}
inline pair<int,int> query(int x,int l,int r,int pos)
{
    if(l==r) return make_pair(Tree[x].sum,Tree[x].cnt);
    int mid=l+r>>1;
    if(pos<=mid) return query(Tree[x].ls,l,mid,pos);
    if(pos>mid) return query(Tree[x].rs,mid+1,r,pos);
}
/*主席树的修改与查询*/
inline void dfs1(int x)
{
    sz[x]=1;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int v=e[i].to;
        if(v==fa[x]) continue;
        fa[v]=x;dep[v]=dep[x]+1;
        val[v]=e[i].dist;col[v]=e[i].col;
        dist[v]=dist[x]+e[i].dist;
        dfs1(v);
        sz[x]+=sz[v];
        if(sz[v]>sz[hson[x]]) hson[x]=v;
    }
    return ;
}
inline void dfs2(int x,int tp)
{
    top[x]=tp;
    if(x!=1) root[x]=update(root[fa[x]],1,n,col[x],val[x]);
    /*如果不是根节点1就要继承父亲并且加入自己的值*/
    if(hson[x]) dfs2(hson[x],tp);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int v=e[i].to;
        if(v==fa[x]||v==hson[x]) continue;
        dfs2(v,v);
    }   
}
inline int lca(int u,int v)
{
    while(top[u]!=top[v])
    {
        if(dep[top[u]]>dep[top[v]]) u=fa[top[u]];
        else v=fa[top[v]];
    }
    if(dep[u]<dep[v]) return u;
    else return v;
}

inline int getd(int x,int z,int w)
{
    pair<int,int> tmp=query(root[x],1,n,z);
    return dist[x]-tmp.first+w*tmp.second;//计算改变后的权值
}
int main()
{
    read(n);read(q);
    for(int i=1;i<n;i++)
    {
        read(x);read(y);read(z);read(w);
        add(x,y,z,w);add(y,x,z,w);
    }
    dep[1]=1;dfs1(1);dfs2(1,1);
    while(q--)
    {
        read(z);read(w);read(x);read(y);
        /这里注意读入的顺序 x,y才是查询的节点*/
        int lc=lca(x,y);
        printf("%d\n",getd(x,z,w)+getd(y,z,w)-2*getd(lc,z,w));
    }
    return 0;   
}