题解 P4242 【树上的毒瘤】

· · 题解

本来是想着好像可以在虚树上点分治来做一些不是很好计算的东西?

然后本来想出一道在虚树上点分治+FFT来求一下特定长度的路径?结果发现这东西套在虚树上复杂度不对。

然后就出点树形dp不好做的东西……然后由于套了虚树所以可以带修改……然后……

然后就变得非常码农

然后就有了这个题:树上多次询问,每次给定一些关键点,需要统计出每个关键点到其他关键点的所有路径上的颜色段数之和。

首先是询问的总点数和n同阶,马上就能看出是虚树。

然后就是在虚树上统计答案,上点分治。

考虑怎么统计经过当前分治中心的路径对答案的贡献。

对于分治中心的一棵子树,我们现在只统计经过中心的路径的贡献,那么我们可以求出从中心出发进入其他子树的路径的总贡献,然后我们dfs这棵子树,某个节点的答案只需要把中心那边的答案接上中心到当前点的那一段路径的贡献即可。这样就可以统计出所有点的答案。对于当前分治中心,它自己的答案要加上从它出发向它所在分治块内的路径的贡献。

在虚树上还要注意,只能统计询问点的答案,有一些询问点的lca也在虚树上但不能给其他询问点算贡献。然后由于虚树上的边带表原树上的一条链,所以虚树上的边权是原树上那条链上的颜色段数。

然后我们再写个树链剖分来求lca、链修改、链查询颜色段数。

然后就没有然后了,码吧

复杂度O((\sum m)\log^2n)

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

#define lc x<<1
#define rc x<<1|1
#define For(i,_beg,_end) for(int i=(_beg),i##end=(_end);i<=i##end;++i)
#define Rep(i,_beg,_end) for(int i=(_beg),i##end=(_end);i>=i##end;--i)

template<typename T>T Max(const T &x,const T &y){return x<y?y:x;}
template<typename T>T Min(const T &x,const T &y){return x<y?x:y;}
template<typename T>int chkmax(T &x,const T &y){return x<y?(x=y,1):0;}
template<typename T>int chkmin(T &x,const T &y){return x>y?(x=y,1):0;}
template<typename T>void read(T &x){
    T f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(x=0;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    x*=f;
}

typedef long long LL;
const int maxn=100010,inf=0x7fffffff;
struct edge{
    int to,nxt,len;
}e[maxn<<1];
struct Data{
    int cl,cr,sum;
    Data(){cl=cr=sum=0;}
    Data(int l,int r,int s){cl=l;cr=r;sum=s;}
    friend Data operator+(const Data &x,const Data &y){
        if(!x.cl)return y;
        if(!y.cl)return x;
        return Data(x.cl,y.cr,x.sum+y.sum-(x.cr==y.cl));
    }
}T[maxn<<2];
int n,m,q,num,head[maxn],sum,root,f[maxn],size[maxn];
int c[maxn],dfn[maxn],cnt,tp[maxn];
int anc[maxn],son[maxn],dep[maxn],rnk,pos[maxn],id[maxn];
int sta[maxn],top,h[maxn],sq[maxn],tag[maxn<<2];
LL ans[maxn],tot,all,path;
bool vis[maxn],is[maxn];

void addedge(int,int,int);
void Dfs1(int,int);
void Dfs2(int,int,int);
int Lca(int,int);
void Build(int,int,int);
void pushsame(int,int);
void pushdown(int);
Data Query(int,int,int,int,int);
void Modify(int,int,int,int,int,int);
int Query(int,int);
void Change(int,int,int);

bool comp(int,int);
void Build(void);

void get(int,int);
void get_sum(int,int);
void get_root(int,int);
void Dfs(int,int,int);
void Modify(int,int,int);
void Calc(int);
void Solve(int);

int main(){
    read(n);read(q);
    For(i,1,n) read(c[i]);
    For(i,1,n-1){
        int u,v;
        read(u);read(v);
        addedge(u,v,0);
    }
    Dfs1(1,0);Dfs2(1,0,1);
    Build(1,1,n);
    int op,u,v,col;
    while(q--){
        read(op);
        if(op==1){
            read(u);read(v);read(col);
            Change(u,v,col);
        }
        else{
            read(m);
            For(i,1,m){
                read(h[i]);
                is[h[i]]=true;
                sq[i]=h[i];
            }
            Build();Solve(1);
            For(i,1,m) is[sq[i]]=false;
            For(i,1,m) printf("%lld ",ans[sq[i]]),ans[sq[i]]=0;
            putchar(10);
        }
    }
    return 0;
}

void Calc(int x){
    get(x,0);path=size[x];
    Dfs(x,tot=0,1);all=tot;
    if(is[x]) ans[x]+=tot;
    for(int i=head[x];i;i=e[i].nxt) if(!vis[e[i].to]){
        tot=0;
        Dfs(e[i].to,x,e[i].len+1);
        all-=tot;path-=size[e[i].to];
        Modify(e[i].to,x,e[i].len);
        all+=tot;path+=size[e[i].to];
    }
    vis[x]=true;
}
void Solve(int x){
    Calc(x);
    for(int i=head[x];i;i=e[i].nxt) if(!vis[e[i].to]){
        f[sum=root=0]=inf;
        get_sum(e[i].to,0);
        get_root(e[i].to,0);
        Solve(root);
    }
    vis[x]=false;
}
void Modify(int x,int fa,int lst){
    if(is[x]) ans[x]+=all+1LL*lst*path;
    for(int i=head[x];i;i=e[i].nxt){
        if(e[i].to==fa||vis[e[i].to]) continue;
        Modify(e[i].to,x,lst+e[i].len);
    }
}
void Dfs(int x,int fa,int lst){
    if(is[x]) tot+=lst;
    for(int i=head[x];i;i=e[i].nxt){
        if(e[i].to==fa||vis[e[i].to]) continue;
        Dfs(e[i].to,x,lst+e[i].len);
    }
}
void get(int x,int fa){
    size[x]=is[x];
    for(int i=head[x];i;i=e[i].nxt){
        if(e[i].to==fa||vis[e[i].to]) continue;
        get(e[i].to,x);
        size[x]+=size[e[i].to];
    }
}
void get_sum(int x,int fa){
    sum++;
    for(int i=head[x];i;i=e[i].nxt){
        if(e[i].to==fa||vis[e[i].to]) continue;
        get_sum(e[i].to,x);
    }
}
void get_root(int x,int fa){
    f[x]=0;size[x]=1;
    for(int i=head[x];i;i=e[i].nxt){
        if(e[i].to==fa||vis[e[i].to]) continue;
        get_root(e[i].to,x);
        size[x]+=size[e[i].to];
        chkmax(f[x],size[e[i].to]);
    }
    chkmax(f[x],sum-size[x]);
    if(f[x]<f[root]) root=x;
}

void addedge(int u,int v,int w){
    e[++num].to=v;e[num].len=w;e[num].nxt=head[u];head[u]=num;
    e[++num].to=u;e[num].len=w;e[num].nxt=head[v];head[v]=num;
}
void Dfs1(int x,int fa){
    size[x]=1;dep[x]=dep[fa]+1;
    dfn[x]=++cnt;anc[x]=fa;
    for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=fa){
        Dfs1(e[i].to,x);
        size[x]+=size[e[i].to];
        if(size[e[i].to]>size[son[x]]) son[x]=e[i].to;
    }
}
void Dfs2(int x,int fa,int up){
    tp[x]=up;id[++rnk]=x;pos[x]=rnk;
    if(son[x]) Dfs2(son[x],x,up);
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=fa&&e[i].to!=son[x])
            Dfs2(e[i].to,x,e[i].to);
}
int Lca(int u,int v){
    int fx=tp[u],fy=tp[v];
    while(fx!=fy){
        if(dep[fx]>dep[fy]) fx=tp[u=anc[fx]];
        else fy=tp[v=anc[fy]];
    }
    return dep[u]<dep[v]?u:v;
}
void Build(int x,int l,int r){
    if(l==r)return T[x]=Data(c[id[l]],c[id[l]],1),void();
    int mid=(l+r)>>1;
    Build(lc,l,mid);Build(rc,mid+1,r);
    T[x]=T[lc]+T[rc];
}
void pushsame(int x,int col){
    tag[x]=col;
    T[x]=Data(col,col,1);
}
void pushdown(int x){
    if(tag[x]){
        pushsame(lc,tag[x]);
        pushsame(rc,tag[x]);
        tag[x]=0;
    }
}
Data Query(int x,int l,int r,int ql,int qr){
    if(l>=ql&&r<=qr)return T[x];
    int mid=(l+r)>>1;pushdown(x);
    if(ql<=mid&&qr>mid)return Query(lc,l,mid,ql,qr)+Query(rc,mid+1,r,ql,qr);
    else if(ql<=mid)return Query(lc,l,mid,ql,qr);
    else return Query(rc,mid+1,r,ql,qr);
}
void Modify(int x,int l,int r,int ql,int qr,int col){
    if(l>=ql&&r<=qr)return pushsame(x,col),void();
    int mid=(l+r)>>1;pushdown(x);
    if(ql<=mid) Modify(lc,l,mid,ql,qr,col);
    if(qr>mid) Modify(rc,mid+1,r,ql,qr,col);
    T[x]=T[lc]+T[rc];
}
int Query(int u,int v){
    Data res;
    int fx=tp[u],fy=tp[v];
    while(fx!=fy){
        res=Query(1,1,n,pos[fx],pos[u])+res;
        fx=tp[u=anc[fx]];
    }
    res=Query(1,1,n,pos[v],pos[u])+res;
    return res.sum-1;
}
void Change(int u,int v,int col){
    int fx=tp[u],fy=tp[v];
    while(fx!=fy){
        if(dep[fx]>dep[fy]){
            Modify(1,1,n,pos[fx],pos[u],col);
            fx=tp[u=anc[fx]];
        }
        else{
            Modify(1,1,n,pos[fy],pos[v],col);
            fy=tp[v=anc[fy]];
        }
    }
    if(dep[u]>dep[v]) swap(u,v);
    Modify(1,1,n,pos[u],pos[v],col);
}

bool comp(int x,int y){return dfn[x]<dfn[y];}
void Build(){
    sort(h+1,h+m+1,comp);
    num=head[sta[top=1]=1]=0;
    For(i,h[1]==1?2:1,m){
        int lca=Lca(h[i],sta[top]);
        if(lca!=sta[top]){
            while(dfn[sta[top-1]]>dfn[lca]){
                addedge(sta[top],sta[top-1],Query(sta[top],sta[top-1]));
                top--;
            }
            if(sta[top-1]!=lca){
                head[lca]=0;
                addedge(sta[top],lca,Query(sta[top],lca));
                sta[top]=lca;
            }
            else addedge(sta[top],lca,Query(sta[top],lca)),top--;
        }
        head[h[i]]=0;
        sta[++top]=h[i];
    }
    For(i,1,top-1) addedge(sta[i+1],sta[i],Query(sta[i+1],sta[i]));
}