题解 P3523 【[POI2011]DYN-Dynamite】

· · 题解

从我认真记录以来第 3 道由于读错题面而GG的题目(

观察题目,显然有一个外层二分的模型,二分距离的最大值作为上界。

问题就转化成了:一棵树上选 m 个点,使得这 m 个点与关键节点的最小距离不超过mid

然后发现上面那个问题是不好解决的,根据二分判定,问题还可以进一步转化:一棵树上选 tot 个点,使得这 tot 个点与关键节点的最小距离不超过mid,最小化 tot

这大概就是一个最小点覆盖的问题,用最少的点覆盖所有关键节点。

下面是我对于这道题目给出的一些定义:

称点 x 覆盖点 y,当且仅当点 x 与点 y的距离不超过给定的距离限制 mid。(以下就直接使用这两个概念,不再重复)

而最小点覆盖的定义是,对于每个关键节点 y ,选定的 tot 个点中有任意一个点 x 到点 y 的距离不大于给定的距离限制 mid,令 tot 最小。

选定节点为你选择的节点,关键节点同题目定义。

接下来怎么解决呢?我们发现,如果一棵子树的根 root 到最远的关键节点的距离小于 mid ,那意味着这整棵子树是可以被这个 root 覆盖的.并且如果以这种方式自底而上递归,一定能够得出一种方案来解决这个问题,但这个方案很明显不是最优解。

为什么呢?经过观察,我们可以发现,如果每次贪心的采取这种方式,我们每次会重复覆盖一些点,这样就会得不到最优解,那怎么办呢?我们可以记录下每次未被之前覆盖的关键节点到 root 的最远距离,这样就可以减少掉多余的覆盖。

为什么说叫做减少多余的覆盖,而不是消除所有多余的覆盖呢?经过思考,我们可以发现,我们覆盖整颗子树 root 时,不一定会选定 root 作为用来覆盖这棵子树的点,我们可能会使用 root 下的子孙 y 来覆盖整颗子树,从而得到更优解。

根据上面的两个思路,我们就可以想到设 f_{root} 表示距 root 最远的未被覆盖的关键节点到 root 的距离,g_{root}root 到该子树下的被选定节点的最小距离。

DP过程就为:

这样就结束了么?是不是很好奇为什么没有用到d_i

我们还需要特判几种情况以及统计最小点覆盖。

以上DP过程大概解释一下,当我们将 f_x 置为 -\infty 时,说明该子树已被完全覆盖,就不会影响其祖先的 f 值。

一定要记得特判根啊!!!

Show the Code

#include<cstdio>
//f[i] 以i为根的子树未被覆盖的最远距离
//g[i] 以i为根的子树被选择的点的最近距离
#define max(a,b) ((a)>(b)? (a):(b))
#define min(a,b) ((a)<(b)? (a):(b))
typedef long long ll;
const ll inf=1e8;
int n,m,res=0,cnt=0,tot=0;
int b[300005];
ll f[300005],g[300005];
int h[300005],to[600005],ver[600005];
inline int read() {
    register int x=0,f=1;register char s=getchar();
    while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline void add(int x,int y) {
    to[++cnt]=y;
    ver[cnt]=h[x];
    h[x]=cnt;
}
void dfs(int x,int fa,int mid) {
    f[x]=-inf;g[x]=inf;
    for(register int i=h[x];i;i=ver[i]) {
        int y=to[i];
        if(y==fa) continue;
        dfs(y,x,mid);
        f[x]=max(f[x],f[y]+1);
        g[x]=min(g[x],g[y]+1);
    }
    if(f[x]+g[x]<=mid) f[x]=-inf;//不需要再被覆盖,直接置为-inf
    if(g[x]>mid&&b[x]==1) f[x]=max(f[x],0);//当前无法覆盖,需要父亲帮忙
    if(f[x]==mid) f[x]=-inf,g[x]=0,++tot;//不需要再被覆盖
}
int check(int x) {
    tot=0;
    dfs(1,-1,x);
    if(f[1]>=0) ++tot;
    return tot<=m;
}
int main() {
    n=read(),m=read();
    for(register int i=1;i<=n;++i) b[i]=read();
    for(register int i=1;i<n;++i) {
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    int L=0,R=n;
    while(L<=R) {
        int mid=L+R>>1;
        if(check(mid)) {R=mid-1;res=mid;}
        else {L=mid+1;}
    }
    printf("%d\n",res);
    return 0;
}