vectorwyx 的博客

vectorwyx 的博客

My left brain has nothing right, my right brain has nothing left.

题解 P4362 【[NOI2002] 贪吃的九头龙】

posted on 2021-01-12 11:41:28 | under 题解 |

大头所吃掉的果子一定会将原树分割成若干棵子树,若小头的数量 $m-1$ 等于 $2$,对分割出来的子树进行黑白染色,一定存在一种方案使得每一对相邻的点颜色不同,因此小头与小头之间的边不会对答案产生贡献。若 $m-1>2$,我们只需要从这些子树上随便选出 $m-3$ 个点与任意 $m-3$ 个小头一一对应,再对其余的点进行黑白染色,这样的话就与前面一样了。

综上,若 $m\ge 3$,那么小头与小头之间的边一定不会产生贡献。把 $1$ 作为根结点,问题就转化成在原树上取包括根结点在内的 $k$ 个点使得难受值最小,令 $dp_{i,j,0/1}$ 表示从以 $i$ 为根的子树中取不包含/包含 $i$ 的 $j$ 个点所能达到的最小难受值,转移可以参考树形背包,即再做一个合并子树的 $dp$,令 $f_{i,j}$ 表示不选根结点时前 $i$ 棵子树选 $j$ 个的最小难受值。转移方程为 $f_{i,j}=\min(f_{i-1,q}+\min(dp_{p,j-q,0},dp_{p,j-q,1}))$,其中 $0\le j,q\le siz$, $p$ 为第 $i$ 棵子树的跟结点, $siz$ 为整棵树的结点个数。选根结点时的转移方程同理。最终 $dp_{1,k,1}$ 就是我们所求的答案。

若 $m=2$,那小头与小头之间的边一定会产生贡献,对转移方程稍作修改即可。

代码如下:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define ll long long
#define fo(i,x,y) for(int i=x;i<=y;++i)
#define go(i,x,y) for(int i=x;i>=y;--i)
using namespace std;
inline int read(){ int x=0,fh=1; char ch=getchar(); while(!isdigit(ch)){ if(ch=='-') fh=-1; ch=getchar(); } while(isdigit(ch)){ x=(x<<1)+(x<<3)+ch-'0'; ch=getchar(); } return x*fh; }

const int N=305;
struct Edge{
    int to,next,val;
}e[N<<1];
int head[N],tot,n,m,k,dp[N][N][2],f0[N],f1[N],siz[N];

void connect(int x,int y,int v){
    e[++tot]=(Edge){y,head[x],v};
    head[x]=tot; 
}

void dfs(int now,int from){
    for(int i=head[now];i;i=e[i].next){
        int p=e[i].to;
        if(p==from) continue;
        dfs(p,now);
    }
    memset(f0,0x3f,sizeof f0);
    memset(f1,0x3f,sizeof f1);
    siz[now]=1;
    f0[0]=f1[0]=0; 
    for(int i=head[now];i;i=e[i].next){
        int p=e[i].to;
        if(p==from) continue;
        siz[now]+=siz[p];
        fo(w,0,siz[p]) f0[siz[now]]=min(f0[siz[now]],f0[siz[now]-w]+min(dp[p][w][0]+(m==2)*e[i].val,dp[p][w][1]));
        go(j,siz[now]-1,0){
            f0[j]=f0[j]+(m==2)*e[i].val+dp[p][0][0];
            f1[j]=f1[j]+dp[p][0][0];
            fo(w,1,min(j,siz[p])){
                int q=j-w;
                f0[j]=min(f0[j],f0[q]+min(dp[p][w][0]+(m==2)*e[i].val,dp[p][w][1]));
                f1[j]=min(f1[j],f1[q]+min(dp[p][w][0],dp[p][w][1]+e[i].val));
            }           
        }
    }
    dp[now][0][0]=f0[0];
    fo(i,1,siz[now]) dp[now][i][0]=f0[i],dp[now][i][1]=f1[i-1];
}

int main(){
    cin>>n>>m>>k;
    if(k+m-1>n){
        cout<<-1;
        return 0;
    }
    fo(i,1,n-1){
        int x=read(),y=read(),v=read();
        connect(x,y,v);
        connect(y,x,v);
    }
    dfs(1,0);
    cout<<dp[1][k][1];
    return 0;
}

点个赞再走吧QAQ,谢谢!