题解:CF1527D MEX Tree

· · 题解

致敬传奇分类讨论。

题意

给你一棵 n 个点的树,顶点编号 0,1,\dots,n-1。定义一条路径的权值是这个路径上的点的编号构成集合的 \operatorname{mex}。对于每一个 0 \le i \le n 求出权值恰好是 i 的路径数量。

分析

求出恰好为 i 时的答案不好做。考虑求出权值大于等于 i 时的答案再差分回去。

如果某一个路径权值大于等于 i 就说明这个路径经过了 0,1,\dots,i-1。如果为 i 的答案不是 0 就说明至少有一个路径能经过这 i 个点。

因此我们从小到大枚举 i,同时维护 l,r 表示树链的左右端点。如果 i,l,r 不能被任一树链覆盖那么 i 及以后的答案就全是 0 了。否则我们更新树链的左右端点。答案利用乘法原理,相当于分别在“r 的除了 l 方向以外的其他连通分量”和“l 的除了 r 方向以外的其他连通分量”各选一个点。

考虑如何维护 l,r。对于已有的 l,r。有两种情况:祖孙关系或不是祖孙关系。如果 l,r 不是祖孙关系,考虑 i 在哪里合法。只有三个地方:l 子树、r 子树、l,r 路径上。前两种情况简单,但是什么时候 i 位于 l,r 路径上呢?

这里给出引理:

证明:

\operatorname{LCA}(a,b)=T,故左侧集合等于 \{c,T\}

讨论 c 的位置:

如果 ca 子树里面,则右侧集合等于 \{T,a\}。此时等式不成立。cb 子树里面时同理。

如果 c 不在 T 子树里面,则右侧集合等于 \{\operatorname{LCA}(T,c) \cdot 2\}。等式不成立。

如果 c=T 则等式显然成立。

如果 ca,b 路径上,不妨假设 \operatorname{LCA}(a,c)=T,那么右侧集合等于 \{T,c\}。等式成立。

其他情况下假设 \operatorname{LCA}(a,c)=T,那么右侧集合等于 \{T,\operatorname{LCA}(b,c)\}。由于 c 不在 a,b 路径上,因此等式不成立。

综上所述,等式成立的充要条件就是 ca,b 路径上。证毕。

这样一来我们就能够解决 l,r 不是祖孙关系的情形了。

然后讨论 l,r 呈祖孙关系的情况。这里设 rl 的祖先。这里 i 的合法位置就剩下了:l 子树、l,r 路径上以及 r 的除 l 方向的子树和 r 子树外这四种情况。这时前两种和第四种是简单的,主要来看第三种。

这里比较巧妙。如果 i 是第三种情况,那么 ir 子树中,因此 \operatorname{LCA}(r,i)=r。同时 i 不在“rl 方向子树”里面。换句话说就是 l 不在 i 子树里面呗。于是 \operatorname{LCA}(l,i)=r。这两个条件就完美约束了 i 的所有范围。

维护好了 l,r 该计算答案了。依然分讨两种情况:祖孙关系或不是祖孙关系。后者依旧简单,ans=sz_l \times sz_r

对于 l,r 呈祖孙关系的情形,设 rl 的祖先。但是我没有想到巧妙办法,只能够把“rl 方向子树”到底是谁求出来。可以写 dfn 序二分,而使用重剖的人就可以依靠“重链 dfn 序连续”的性质直接爬链即可。

假设“rl 方向子树”是 u,那么就有 ans=sz_l \times (n-sz_u)

最终复杂度:O(n+q\log n)。瓶颈在于求 LCA 和求定向子树。注意特判掉 \operatorname{mex}=0,1 的情况。

代码

#include<bits/stdc++.h>
using namespace std;
#define int long long
int n;
int dfstime;
struct tree_heavy{
    int hson,fa,sz,d;
    int dfn,top;
}t[500005];
int num[500005];
vector<int> g[500005];
int ans[500005];
void dfs1(int x,int fa){
    t[x].fa=fa;
    t[x].sz=1;
    t[x].d=t[t[x].fa].d+1;
    t[x].hson=-1;
    for(int i=0;i<g[x].size();i++){
        int v=g[x][i];
        if(v==fa){
            continue;
        }
        dfs1(v,x);
        t[x].sz+=t[v].sz;
        if(t[x].hson==-1||t[t[x].hson].sz<t[v].sz){
            t[x].hson=v;
        }
    }
}
void dfs2(int x,int tp){
    t[x].top=tp;
    dfstime++;
    t[x].dfn=dfstime;
    num[dfstime]=x;
    if(t[x].hson!=-1){
        dfs2(t[x].hson,tp);
    }
    for(int i=0;i<g[x].size();i++){
        int v=g[x][i];
        if(v==t[x].fa||v==t[x].hson){
            continue;
        }
        dfs2(v,v);
    }
}
int LCA(int x,int y){
    while(t[x].top!=t[y].top){
        if(t[t[x].top].d<t[t[y].top].d){
            swap(x,y);
        }
        x=t[t[x].top].fa;
    }
    if(t[x].d<t[y].d){
        swap(x,y);
    }
    return y;
}
int LZH(int x,int y){
    while(t[x].top!=t[y].top){
        if(t[t[x].top].d<t[t[y].top].d){
            swap(x,y);
        }
        if(t[t[x].top].fa==y){
            return t[x].top;
        }
        x=t[t[x].top].fa;
    }
    if(t[x].d<t[y].d){
        swap(x,y);
    }
    return num[t[y].dfn+1];
}
int l,r;
int calc(){
    int C=LCA(l,r);
    if(C!=l&&C!=r){
        return t[l].sz*t[r].sz;
    }
    if(C==l){
        swap(l,r);
    }
    return t[l].sz*(n-t[LZH(l,r)].sz);
} 
bool check(int x){
    int C=LCA(l,r);
    if(C!=l&&C!=r){
        int cl=LCA(x,l);
        int cr=LCA(x,r);
        if(cl==x&&cr==C||cr==x&&cl==C){
            return 1;
        } 
        if(cl==l&&cr==C){
            l=x;
            return 1;
        }
        if(cr==r&&cl==C){
            r=x;
            return 1;
        }
        return 0;
    }
    if(C==l){
        swap(l,r);
    }
    int cl=LCA(x,l);
    int cr=LCA(x,r);
    if(cl==l){
        l=x;
        return 1;
    }
    if(cl==x&&cr==r){
        return 1;
    }
    if(cr==cl){
        r=x;
        return 1;
    }
    return 0;
}
void slv(){
    dfstime=0;
    for(int i=1;i<=n;i++){
        g[i].clear();
        t[i]={0,0,0,0,0,0};
    }
    cin>>n;
    for(int i=1;i<n;i++){
        int x,y;
        cin>>x>>y;
        x++;
        y++;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(1,0);
    dfs2(1,1);
    l=1,r=2;
    ans[0]=ans[1]=n*(n-1)/2;
    for(int v:g[1]){
        ans[1]-=t[v].sz*(t[v].sz-1)/2;
    }
    ans[2]=calc();
    ans[n+1]=0;
    int fl=0;
    for(int i=3;i<=n;i++){
        if(fl||!check(i)){
            fl=1;
            ans[i]=0;
        }
        else{
            ans[i]=calc();
        }
    }
    for(int i=0;i<n;i++){
        ans[i]-=ans[i+1];
    }
    for(int i=0;i<=n;i++){
        cout<<ans[i]<<" ";
    }
    cout<<"\n";
}
signed main(){
    int t;
    cin>>t;
    while(t--){
        slv();
    }
    return 0;
}

附图: