P9194 [USACO23OPEN] Triples of Cows P 题解

· · 题解

Link

由于每次删点之后新增的边数太多,所以考虑拆边。用一个白点表示一条边,并将原来树中的点称为黑点,编号不变。对于第 i 条边 (u_i,v_i),在 (u_i,n+i),(v_i,n+i) 间连边,得到一棵树 T

删除点 i 时,我们将点 i 相邻的所有白点合并为一个点,然后删除 i。容易归纳证明:每次删除后 T 仍然是树;真实的 u,v 间有边等价于树 T 中存在一个白点与 u,v 均相邻。

因为依次删除 1,2,...,n,所以在树 T 中,可以把点 n 作为根。这样每次删除时,就可以把 i 的所有相邻白点合并到 i 的父亲白点上去。然后,用并查集维护每个白点被合并到了哪个点。

fa_u 表示在初始的 Tu 的父亲结点,p_u 表示某个时刻白点 u 被合并到的点。那么,在这一时刻,黑点 u 的父亲结点是 p_{fa_u},白点 u 的父亲节点是 fa_{p_u}

下面来考虑如何统计答案。对白点 u,记 s_uu 的儿子个数。

对于一个符合要求的 (a,b,c),设 a,b 通过白点 x 相连,b,c 通过白点 y 相连。

  1. 如果 x=y:固定 x,在 x 的邻点中任选 3 个,则对答案的贡献为 \sum_{x} (s_x+1)s_x(s_x-1)

    求和的条件是 x 是白点。

  2. 如果 x\not=y,且 x,y 都是 b 的子节点:固定 b,先任取 b 的两个子结点 x,y,此时贡献 s_xs_y。则总的贡献为 \sum_{b} (\sum_{x} s_x) ^2 - \sum_{x} s_x^2

    第一个求和的条件是 b 是黑点,后两个求和的条件是 xb 的儿子。 注意到后一项拆出来就是对所有白点 xs_x^2 的和,那就拆出来吧。

  3. 如果 x\not=y,且 x,y 一个是 b 的子结点,另一个是 b 的父亲结点:不妨 xb 的父亲结点,固定 xbx 的一个子结点,y 又是 b 的一个子结点,则对答案的贡献为 \sum_{x} 2\times s_x \times \sum_{b} \sum_{y} s_y

    三个求和的条件分别为 x 是白点,bx 的子结点,yb 的结点。

列出式子后,我们发现需要维护以下数据:

  1. 白点 u 的儿子数目 s_u
  2. 黑点 u 的儿子的 s 值之和,也可以存到数组 s
  3. 白点 u 的儿子的 s 值之和 t_u

答案是 \sum_{x}(f(s_x)+2s_xt_x)+\sum_{y} s_y^2,其中 f(x)=x^3-x^2-x,两个求和的条件分别是 x 是黑点,y 是白点。

删除点 u 时,枚举它初始的的儿子(一定没有被合并过),在并查集中将其合并到 u 的父亲中。然后清零 uu 的儿子的值,更新 u 的三层祖先的值,并更新答案。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e5+5;
int n,fa[N],p[N];ll s[N],t[N],ans;
int head[N],ver[N<<1],nxt[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void dfs(int u){
    for(int i=head[u],v;i;i=nxt[i])
        if((v=ver[i])!=fa[u]){
            fa[v]=u;dfs(v);
            if(u<=n)s[u]+=s[v];
            else ++s[u],t[u]+=s[v];
        }
}
int find(int x){return (x==p[x]?x:(p[x]=find(p[x])));}
ll f(ll x){return x*x*x-x*x-x;}
int main(){
    scanf("%d",&n);
    for(int i=1,u,v;i<n;i++){
        scanf("%d%d",&u,&v);
        add(u,n+i);add(v,n+i);add(n+i,u);add(n+i,v);
    }
    dfs(n);
    for(int i=1;i<2*n;i++)p[i]=i;
    for(int i=1;i<=n;i++)ans+=s[i]*s[i];
    for(int i=n+1;i<2*n;i++)ans+=f(s[i])+2*s[i]*t[i];
    for(int u=1;u<=n;u++){
        printf("%lld\n",ans);
        int g=find(fa[u]),w=fa[g];ll del=-1;
        ans-=f(s[g])+2*s[g]*t[g]+s[w]*s[w];s[w]-=s[g];--s[g];
        t[g]-=s[u];ans-=s[u]*s[u];s[u]=0;
        for(int i=head[u],v;i;i=nxt[i])
            if((v=ver[i])!=fa[u]){
                p[v]=g;s[g]+=s[v];t[g]+=t[v];del+=s[v];
                ans-=f(s[v])+2*s[v]*t[v];s[v]=t[v]=0;
            }
        s[w]+=s[g];ans+=f(s[g])+2*s[g]*t[g]+s[w]*s[w];
        t[w=find(fa[fa[g]])]+=del;ans+=2*s[w]*del;
    }
    return 0;
}

Upd:修改了格式。