P4362题解

· · 题解

P4362 [NOI2002]贪吃的九头龙-题解 --zsy

传送门

这个题显然是一个树形dp。我们先来总结一下树形dp的套路:

  1. 定义f数组意义,注意要考虑到题目的一些特殊要求(比如本题的大头),还要考虑到如何输出结果。

  2. 思考如何将一个点的所有f值由这个点的儿子节点转移过来,即我们常说的状态转移方程。

  3. 将方程放到dfs深搜中更新f值,最终输出答案。

所以,我们先来考虑这个题怎么样定义f数组来处理特殊要求。

一、 关于大头

我们整理题目可以发现,关于大头,本题大概有这样两个限制:

  1. 大头必须吃掉1号节点

  2. 大头必须吃掉k个节点

所以,我们大可以用f[i][j]表示对于i号节点,它与它的儿子们一共有j个节点被大头吃掉了。

这样有什么好处呢?我们可以发现,最后我们只需要输出f[1][k]就万事大吉了——等等,第一个限制是不是还没考虑?

这样,我们可以再把f开一维[0/1],用f[i][j][0/1]表示i号节点有没有被大头吃掉,0表示不是大头吃的,1表示是大头吃的。现在我们可以输出f[1][k][1],就完完全全考虑完了大头的限制啦!

二、转移方程?

对于每一个节点u来说,我们需要把f[u][0-k][0/1]全部更新才算更新完了所有状态。考虑什么情况下会增加难受值——如果有3个或以上的头,虽然可能有的边连的两个点都没有被大头吃掉,但是一定可以通过剩下的头让这条边不被吃掉(这个应该很显然吧)。但是如果只有两个头,我们会发现,如果这两个点都没被大头吃掉,那它们一定被另一个头吃了。所以,我们在判断的时候要格外注意m=2的情况。

对于大头不吃u点的情况,可能由v,u大头都不吃和吃v不吃u两种情况转移过来;吃u点同理。dp方程大概长这样,反正就是考虑一下各种可能就好啦。

f_{u,j,0}=min(f_{u,j,0},min(f_{v,t,0}+f_{u,j-t,0}+[m==2]* w,f_{v,t,1}+f_{u,j-t,0})) f_{u,j,1}=min(f_{u,j,1},min(f_{v,t,1}+f_{u,j-t,1}+w,f_{v,t,0}+f_{u,j-t,1}))

(u,v,w指的是一条边的父亲,儿子和边难受值)

这样我们就可以把它放在dfs里得出答案啦!

for(int j=0;j<=k;++j){
    for(int t=0;t<=j;++t){
        f[u][j][0]=min(f[u][j][0],min(f[v][t][0]+f[u][j-t][0]+(m==2)*w,f[v][t][1]+f[u][j-t][0]));
        f[u][j][1]=min(f[u][j][1],min(f[v][t][1]+f[u][j-t][1]+w,f[v][t][0]+f[u][j-t][1]));
    }
}

然而它挂了……我们思考一下为什么呢?原来,这题在更新f[u][j]的时候会被f[u][j-t]更新,所以这个东西就开始自己自己了……所以我们要用一个数组先去记录一下f[u],然后就可以放心做dp了。

完了吗?还是没有……我就是卡在了-1上。我们思考一下,什么情况下会无解呢?注意到题目中的一个条件——每个组都要有果子,也就是每个头都要吃到果子。如果大头吃剩下的比剩下的头数要少,是不是就无解了?所以我们只需要判断n-k<m-1就可以判掉无解情况了。

Code

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=1000+4;
struct Edge{int v,w,nxt;}e[N<<1];
int h[N],f[N][N][2],tmp[N][2];
int tot,m,n,k;
inline void add(int u,int v,int w){
    e[++tot]=(Edge){v,w,h[u]};
    h[u]=tot;
}
void dfs(int u,int fa){
    f[u][0][0]=f[u][1][1]=0;
    for(int i=h[u];i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==fa) continue;
        dfs(v,u);
        memcpy(tmp,f[u],sizeof(f[u]));
        memset(f[u],0x3f,sizeof(f[u]));
        for(int j=0;j<=k;++j){
            for(int t=0;t<=j;++t){
                f[u][j][0]=min(f[u][j][0],min(f[v][t][0]+tmp[j-t][0]+(m==2)*w,f[v][t][1]+tmp[j-t][0]));
                f[u][j][1]=min(f[u][j][1],min(f[v][t][1]+tmp[j-t][1]+w,f[v][t][0]+tmp[j-t][1]));
            }
        }
    }
}
int main(){
    memset(f,0x3f,sizeof(f));
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1,u,v,w;i<n;++i){
        scanf("%d%d%d",&u,&v,&w);
        add(u,v,w),add(v,u,w);
    }
    if(n-k<m-1){printf("-1\n");return 0;}
    dfs(1,0);
    printf("%d\n",f[1][k][1]);
    return 0;
}

完结撒花!