题解:P6118 [JOI 2019 Final] 独特的城市 / Unique Cities

· · 题解

P6118 [JOI 2019 Final] 独特的城市 / Unique Cities

考虑对于一个点来说,独特的城市有什么性质。独特的城市等价于深度唯一,那么如果我们找到一个涵盖深度最多的链,即最长链,那么独特的城市一定在这条链上。反证即可,不在最长链上的点一定能在最长链上找到对应深度。

那么我们只需要考察最长链上的点有哪些会成为独特的城市。有经典性质,每个点出发的最长链一定有一端是直径端点。所以如果我们以直径的两端分别跑一遍,那么每次最长链都可以看做是这个点到根的路径,根据性质这样一定可以统计到真正的最长链。

我们考虑从上到下遍历,用栈维护从根到当前点的独特点的集合。假设遍历到点 u,定义 dep_u 表示 u 的距离,mx_u,sec_u 分别表示从 u 往下的最长链和次长链的长度,用 son_u 表示 u 的长儿子:

注意到一个点自身的限制是比其子树中的点要强的,因此遍历到的点则

  1. 排除对自己和子树中都无用的决策。
  2. 处理长儿子。
  3. 排除对自己无用的决策,统计答案。
  4. 处理其他儿子。

先考虑有哪些点对于 u 和其子树中的点都是没有意义的,在次长链上的点可以消掉 u 上方路径中与 u 距离 \le sec_u 的点,不难发现次长链可以限制到整棵子树,因此与 u 距离 \le sec_u 的点对于其子树中的点都是没有用的,这些点可以直接扔掉。

然后对于 u 的儿子 v\ne son_u,同理不难发现最长链 mx_u 可以限制到其子树中的所有点,因此处理这些儿子时可以直接删掉与 u 距离 \le mx_u 的点。但是注意对于 son_umx_u 是无法限制的,因此 son_u 需要在排除 mx_u 之前先被处理。

所以整个算法流程就是:

  1. 把所有 dis(u,x)\le sec_u 的点都扔掉。
  2. 处理长儿子 son_u
  3. 把所有 dis(u,x)\le mx_u 的点扔掉,并统计 u 处的答案。
  4. 处理其他儿子。

时间复杂度 O(n)

代码:

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N=200010;
int n,m;
int h[N],e[N<<1],ne[N<<1],idx;
void add(int u,int v) { e[idx]=v,ne[idx]=h[u],h[u]=idx++; }
int c[N];

int dis[N];
void dfs_dis(int u,int fat)
{
    dis[u]=dis[fat]+1;
    for (int i=h[u];i!=-1;i=ne[i])
    {
        int v=e[i];
        if (v^fat) dfs_dis(v,u);
    }
}

int dep[N],mx[N],sec[N],son[N];
void dfs_d(int u,int fat)
{
    dep[u]=dep[fat]+1,mx[u]=0,sec[u]=0,son[u]=0;
    for (int i=h[u];i!=-1;i=ne[i])
    {
        int v=e[i];
        if (v==fat) continue;
        dfs_d(v,u);

        int dist=mx[v]+1;
        if (dist>mx[u]) { sec[u]=mx[u],mx[u]=dist,son[u]=v; }
        else if (dist>sec[u]) { sec[u]=dist; }
    }
}

int ans[N];
int stk[N],top;
int buc[N],res;
void del(int x) { if (!--buc[c[x]]) res--; }
void add(int x) { if (!buc[c[x]]++) res++; }

void dfs(int u,int fat)
{
    if (fat) { stk[++top]=fat, add(fat); };
    while (top && dep[u]-dep[stk[top]]<=sec[u]) del(stk[top--]);
    if (son[u]) dfs(son[u],u);
    while (top && dep[u]-dep[stk[top]]<=mx[u]) del(stk[top--]);
    ans[u]=max(ans[u],res);
    for (int i=h[u];i!=-1;i=ne[i])
    {
        int v=e[i];
        if (v==fat || v==son[u]) continue;
        dfs(v,u);
    }
    if (fat && stk[top]==fat) { top--; del(fat); }
}

int main()
{
    cin >> n >> m;
    memset(h,-1,sizeof(h));
    for (int i=1;i<n;i++)
    {
        int u,v;
        cin >> u >> v;
        add(u,v),add(v,u);
    }
    for (int i=1;i<=n;i++) cin >> c[i];

    dfs_dis(1,0);
    int s=0;
    for (int i=1;i<=n;i++) s=(dis[i]>dis[s] ? i : s);
    dfs_dis(s,0);
    int t=0;
    for (int i=1;i<=n;i++) t=(dis[i]>dis[t] ? i : t);

    top=0,memset(buc,0,sizeof(buc)),res=0,dfs_d(s,0),dfs(s,0);
    top=0,memset(buc,0,sizeof(buc)),res=0,dfs_d(t,0),dfs(t,0);

    for (int i=1;i<=n;i++) cout << ans[i] << "\n";

    return 0;
}