题解 AT4995 【[AGC034E] Complete Compress】

· · 题解

推销博客:https://www.cnblogs.com/syc233/p/13603919.html

注意到数据范围很小,所以我们可以枚举一个根,尽量让所有点移动到这个点上。

将两颗距离至少为 2 的棋子向中间移动一步,可能会出现两种情况:

显然第二种情况对于寻找最优解是不利的,为了让所有点移动到根上,每个点都应向上跳。所以我们只考虑以第一种情况操作。

zzy:

这里就涉及到一个经典的模型:有若干个元素被分成了若干个集合, 每次要找两个在不同集合中的元素匹配然后消掉。

设所有集合的总大小为 sum ,最大集合的大小为 max ,有两种情况:

再回到这道题。

考虑树上的一个点 u ,我们将它的子树中的点移动到 u 上。对于 u 子树中一个点 v ,我们必须对它进行 dis(u,v) 次操作才能将它移动到点 u 。这相当于将 u 子树中每个有棋子的点 v 拆成了一个有 dis(u,v) 个元素的集合,每次在两个集合中分别选择两个元素相消,则问题转化为了经典模型。

但是由于树结构的限制,有祖孙关系的两个点对应的集合不能同时选择。因此我们考虑树形DP,将问题转化为互不相交的子问题求解。

f_u 表示 u 的子树内最多能消去多少对元素,dis_u 表示 u 的子树中棋子拆成的集合的总大小,设 max 表示 u 的儿子对应的子树中棋子拆成的集合的最大的总大小,即 {\rm{max}}\{dis_v|v \in son_v\} 。则有转移:

root 为根时,当且仅当 f_{root}=\frac{dis_{root}}{2} ,即所有元素都消完时合法。

\text{Code}:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
#define maxn 2005
#define Rint register int
#define INF 0x3f3f3f3f
using namespace std;
typedef long long lxl;

template <typename T>
inline void read(T &x)
{
    x=0;T f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    x*=f;
}

struct edge
{
    int u,v,next;
}e[maxn<<1];

int head[maxn],k;

inline void add(int u,int v)
{
    e[k]=(edge){u,v,head[u]};
    head[u]=k++;
}

int n,a[maxn];
int siz[maxn],dis[maxn],f[maxn];

inline void dfs(int u,int fa)
{
    siz[u]=a[u];
    dis[u]=0;
    int son=-1;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa) continue;
        dfs(v,u);
        siz[u]+=siz[v];
        dis[u]+=(dis[v]+=siz[v]);
        if(!~son||dis[son]<dis[v]) son=v;
    }
    if(!~son) return void(f[u]=0);
    if(dis[u]-dis[son]>=dis[son]) f[u]=dis[u]/2;
    else f[u]=dis[u]-dis[son]+min(f[son]*2,2*dis[son]-dis[u])/2;
}

char s[maxn];

int main()
{
    // freopen("AT4995.in","r",stdin);
    read(n);
    scanf(" %s",s+1);
    for(int i=1;i<=n;++i) a[i]=s[i]-'0';
    memset(head,-1,sizeof(head));
    for(int i=1,u,v;i<n;++i)
    {
        read(u),read(v);
        add(u,v);add(v,u);
    }
    int ans=-1;
    for(int i=1;i<=n;++i)
    {
        dfs(i,0);
        if(f[i]*2==dis[i]) ans=(!~ans||ans>f[i])?f[i]:ans;
    }
    printf("%d\n",ans);
    return 0;
}