题解 P6074 【最小路径】

· · 题解

很好的一道01分数规划套长链剖分优化树形dp题(不愧是zjk大佬出的题)

题解:

首先是01分数规划的套路,二分答案k,然后每个点的点权转为val_i=a_i-k\times b_i,当前二分答案合法的条件是存在一条长度为m的树上路径的点权和小于等于0

于是当前问题转化为:给你一棵树,问你长度为m的树上路径的最小点权和

考虑树形dp,设f[x][i]表示x向下长度为i的链中的最小点权和

考虑转移,f[x][i]=val_i+\min\{f[y][i-1]\| y\in son_x\},f[x][0]=val_i

考虑答案,有两种:

考虑优化,发现转移方程里的val_i很难去掉

考虑转变dp状态,改设f[x][i]表示根到xi级儿子的链的最小点权和

考虑转移,f[x][i]=\min\{f[y][i-1]\| y\in son_x\},f[x][0]=pre_x

(其中pre_x表示根到x路径上的val的和)

考虑答案,类似原先的,只不过要注意减去祖先的冗余贡献

考虑优化,发现现在的转移方程就好看多了,是个经典的可以用长剖优化的方程

于是套用长链剖分,重链利用指针直接转移,轻链暴力转移

注意此时第二种情况的答案可以在暴力转移时同时处理

时间复杂度均摊一只log,空间复杂度线性

代码:

#pragma GCC optimize(3,"Ofast","inline")
#pragma GCC target("avx,avx2")
#include <bits/stdc++.h>
using namespace std;
template<class t> inline t read(t &x){
    char c=getchar();bool f=0;x=0;
    while(!isdigit(c)) f|=c=='-',c=getchar();
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar();
    if(f) x=-x;return x;
}
template<class t> inline void write(t x){
    if(x<0) putchar('-'),write(-x);
    else{if(x>9) write(x/10);putchar('0'+x%10);}
}

const double eps=1e-4;
const int N=2e5+5;
double memory[N],ans=-1,res,*f[N],*cur,val[N];
int n,s[N],dep[N],a[N],b[N],m;
vector<int> g[N];

void dfs1(int x,int fa){ //长剖
    for(int y:g[x]) if(y^fa){
        dfs1(y,x);
        if(dep[y]>dep[s[x]]) s[x]=y;
    }
    dep[x]=dep[s[x]]+1;
}

void dfs2(int x,int fa){
    val[x]+=val[fa]; //处理val前缀和
    if(s[x]) f[s[x]]=f[x]+1,dfs2(s[x],x); //重链直接复制
    f[x][0]=val[x];
    for(int y:g[x]) if(y!=fa&&y!=s[x]){
        f[y]=cur;
        cur+=dep[y];
        dfs2(y,x);
        for(int i=0;i<dep[y];i++) if(m-i-1>=0&&dep[x]>m-i-1) res=min(res,f[x][m-i-1]+f[y][i]-val[x]-val[fa]); //处理第二种答案
        for(int i=0;i<dep[y];i++) f[x][i+1]=min(f[x][i+1],f[y][i]); //暴力转移
    }
    if(dep[x]>m) res=min(res,f[x][m]-val[fa]); //第一种答案
}

bool check(double k){
    res=1e18;
    for(int i=1;i<=n;i++) val[i]=a[i]-k*b[i];
    for(int i=1;i<=n;i++) memory[i]=1e18; //还原内存数组
    f[1]=cur=memory;
    cur+=dep[1];
    dfs2(1,0);
    return res<=0;
}

signed main(){
    read(n);read(m);
    for(int i=1;i<=n;i++) read(a[i]);
    for(int i=1;i<=n;i++) read(b[i]);
    for(int i=1,x,y;i<n;i++){
        read(x);read(y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(1,0);
    double l=0,r=2000; //冷静分析,最大值只有2000
    while(l<=r-eps){
        double mid=(l+r)/2.0;
        if(check(mid)) ans=mid,r=mid-eps;
        else l=mid+eps;
    }
    if(ans<0) write(-1); //注意无解
    else printf("%.2lf",ans);
}