题解:P4866 Zrz_orz Loves Secondary Element

· · 题解

题解摘自做题记录。

分析

不难发现,只有最多 2m 个点是有用的。考虑先对这 m 个点建立虚树。

对于一种行走方案,因为每条经过过的边都走了 2 遍,所以就相当于将每条边的边权乘 2 之后的边权和。下面任何表示的距离都是原树上距离的 2 倍。

发现 m 十分小,考虑暴力 DP。定义状态函数 f_{i,j} 表示 i 为根的子树中,花费 j 的时间最多能得到的价值。因为我们是路径,所以需要保证我们选择的点形成连通块。那么对于 u 来说,它就是必选的了。那么对于选择 u 的一个儿子 v 的子树,u\to v 这条边也是必须经过的。有:f_{u,x}=\max\limits_{y=0}^{x-w} f_{u,y}+f_{v,x-y-w}。这个直接树上背包的时间复杂度是 O(mV) 的。

然后对于询问,预处理前缀 \max 可以做到 O(1)。算上建虚树的复杂度,就是 O(m\log m+mV)

代码

il void dfs1(int u,int fa){
    dfn[u]=++cnt;
    f[u][0]=fa,dep[u]=dep[fa]+1;
    for(re int i=1;i<22;++i) f[u][i]=f[f[u][i-1]][i-1];
    for(auto v:e[u])
    if(v.x!=fa){
        dis[v.x]=dis[u]+2*v.y;
        dfs1(v.x,u);
    }
    return ;
} 
il int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(re int i=21;i>=0;--i) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
    if(x==y) return x;
    for(re int i=21;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
il void build(){
    len=m;
    b[++len]=1;
    int len_=len;
    sort(b+1,b+len+1,[](int x,int y){
        return dfn[x]<dfn[y];
    });
    for(re int i=2;i<=len_;++i) b[++len]=lca(b[i],b[i-1]);
    sort(b+1,b+len+1,[](int x,int y){
        return dfn[x]<dfn[y];
    });
    len=unique(b+1,b+len+1)-(b+1);
    for(re int i=1;i<=len;++i) id[b[i]]=i,rev[i]=b[i];
    for(re int i=1;i< len;++i) E[id[lca(b[i],b[i+1])]].push_back({i+1,dis[b[i+1]]-dis[lca(b[i],b[i+1])]});
    return ;
}
il void dfs2(int u,int s){
    siz[u]=0;
    if(s==0){
        dp1[u][0]=val[rev[u]];
        for(auto v:E[u]){
            dfs2(v.x,s);
            for(re int w=min(M,siz[u]+v.y+siz[v.x]);w>=0;--w)
            for(re int x=0;x<=min(w-v.y,siz[u]);++x)
                dp1[u][w]=max(dp1[u][w],dp1[u][x]+dp1[v.x][w-v.y-x]);
            siz[u]+=v.y+siz[v.x];
        }
        for(re int i=1;i<=M;++i)
            dp1[u][i]=max(dp1[u][i],dp1[u][i-1]);
    }
    else{
        dp2[u][0]=val[rev[u]];
        for(auto v:E[u]){
            dfs2(v.x,s);
            for(re int w=min(M,siz[u]+v.y+siz[v.x]);w>=0;--w)
            for(re int x=0;x<=min(w-v.y,siz[u]);++x)
                dp2[u][w]=max(dp2[u][w],dp2[u][x]+dp2[v.x][w-v.y-x]);
            siz[u]+=v.y+siz[v.x];
        }   
        for(re int i=1;i<=M;++i)
            dp2[u][i]=max(dp2[u][i],dp2[u][i-1]);
    }
    return ;
}

il void solve(){
    n=rd,m=rd,t=rd;
    if(m<=20) M=100000;
    else M=100;
    for(re int i=1;i< n;++i){
        int u=rd,v=rd,w=rd;
        e[u].push_back({v,w}),
        e[v].push_back({u,w});
    }
    for(re int i=1;i<=m;++i){
        int x=rd;
        b[i]=x,val[x]=rd;
    }
    dfs1(1,0),build(),dfs2(1,(M==100));
    while(t--){
        int x=rd;
        if(M==100) printf("%lld\n",dp2[1][x]);
        else printf("%lld\n",dp1[1][x]);
    }
    return ;
}