P3874 [TJOI2010]砍树

· · 题解

前置知识:

树形DP

其实无所谓,和普通DP差不多

大致思路:

看到价值和重量不管三七二十一,直接考虑背包

然后再看一下数据范围,value和weight的范围极大,不能用传统的以价值作为下标的背包

但是n很小,仅到100,所以考虑以件数作为下标

转移方程:

设f[u][i]表示:
在以u为根的子树中,选择i个节点,所能得到的最大平均值

当然,记录最大平均值的同时,也要记录权值和与重量和,方便转移

那么自然有:

对于u的每棵子树v:
f[u][j]=f[v][k]+f[u][j-k];
其中,j-k>=1,因为要保证联通,所以u这个节点是必选的

具体做法:

首先,我开了一个结构体,用于存储状态


struct node{
    int v,w;double ave;//转移状态的权值和,重量和,平均值
    node(){}
    node(int a,int b,double p):v(a),w(b),ave(p){}
}f[E][E];

然后dfs

void dfs(int u,int fa){
    f[u][1]=node(val[u],wei[u],(double)val[u]/wei[u]);//初始化
    for(int i=head[u];i;i=nex[i])
        if(to[i]!=fa){//不能走回父亲
            dfs(to[i],u);//先遍历子树
            for(int j=n;j>1;j--){//注意,这里一定要逆序,否则重复选择该棵子树的点
                for(int k=1;k<=j;k++){
                    node p=f[to[i]][j-k];
                    node q=f[u][k];
                    double ave=(double)(p.v+q.v)/(p.w+q.w);
                    if(ave>=f[u][j].ave){
                        f[u][j]=node(p.v+q.v,p.w+q.w,ave);
                    }
                }
            }
        }
}

好了,完了

最终的代码

复杂度O(n^3)

#include<bits/stdc++.h>
#define E 209
#define inf 1e8
using namespace std;
inline int read(){
    int x=0;char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) x=x*10+c-48,c=getchar();
    return x;
}
int n,kk,m,wei[E],val[E];
int head[E],nex[E],to[E],cnt;
inline void add(int u,int v){
    nex[++cnt]=head[u];
    head[u]=cnt;to[cnt]=v;
}
struct node{
    int v,w;double ave;
    node(){}
    node(int a,int b,double p):v(a),w(b),ave(p){}
}f[E][E];
void dfs(int u,int fa){
    f[u][1]=node(val[u],wei[u],(double)val[u]/wei[u]);
    for(int i=head[u];i;i=nex[i])
        if(to[i]!=fa){
            dfs(to[i],u);
            for(int j=n;j>1;j--){
                for(int k=1;k<=j;k++){
                    node p=f[to[i]][j-k];
                    node q=f[u][k];
                    double ave=(double)(p.v+q.v)/(p.w+q.w);
                    if(ave>=f[u][j].ave){
                        f[u][j]=node(p.v+q.v,p.w+q.w,ave);
                    }
                }
            }
        }
}
int main(){
    n=read(),kk=read();
    for(int i=1;i<=n;i++) val[i]=read();
    for(int i=1;i<=n;i++) wei[i]=read();
    for(int i=1,v,u;i<n;i++){
        u=read(),v=read();
        add(u,v);add(v,u);
    }
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            f[i][j]=node(-inf,inf,-inf);
    dfs(1,0);double ans=0.0;
    for(int i=1;i<=n;i++)
        for(int j=kk;j<=n;j++){
//          cout<<i<<" "<<j<<" "<<f[i][j].v<<" "<<f[i][j].w<<endl;
            ans=max(ans,f[i][j].ave);
        }
    printf("%.2f",ans);
    return 0;
}