题解 P5588 【小猪佩奇爬树】

ljc20020730

2019-10-14 19:22:47

Solution

本题可以直接无脑写一个虚树。 时间复杂度依然是优美的$O(n \ log_2 \ n)$。 考虑对于每一个颜色分别处理,建出$n$个颜色的$n$棵虚树。 设第$i$个颜色有$k_i$个节点,那么显然这棵虚树上的节点个数是在$k_i$这个级别上的。 虚树的根设为该颜色所有节点的共同$lca$,求多个节点的$lca$怕已经是$trick$题了。 直接找出$dfs$序最小和$dfs$序最大的两个节点,求一下$lca$即可。 首先,这棵树必然没有度数为$2$的节点,如果有,那么直接输出$0$ 否则,按照根节点的度数是$1$还是$2$分类讨论。 如果是$1$ , 那么就是一个以$1$为根的链型,我们直接可以$dfs$找到虚树中最深的节点把它拿出来。而此时的树根必然是一个标记点。 如果是$2$ , 那么就是以$1$为根挂下两根链,我们直接$dfs$找出这两根链的最深处的两个节点。 相信只要找出这一条链,大家就不难考虑求出最终的答案了。 需要特殊考虑的是,只有$0$个节点或者$1$个节点的情况,这需要分类讨论。 对于$1$个节点的情况,可以考虑补集转化,即求出子树中的路径个数,用总路径个数减去。 这样,本题的时间复杂度就是美妙且暴力的$O(n \ log_2 \ n)$ > $Ps : $ 不知道蒟蒻写个这个虚树为什么跑的那么慢,下面的代码需要开$O2$才能cao过$3s$时限.... ```cpp # include<bits/stdc++.h> # define int long long using namespace std; namespace fast_IO{ const int IN_LEN = 10000000, OUT_LEN = 10000000; char ibuf[IN_LEN], obuf[OUT_LEN], *ih = ibuf + IN_LEN, *oh = obuf, *lastin = ibuf + IN_LEN, *lastout = obuf + OUT_LEN - 1; inline char getchar_(){return (ih == lastin) && (lastin = (ih = ibuf) + fread(ibuf, 1, IN_LEN, stdin), ih == lastin) ? EOF : *ih++;} inline void putchar_(const char x){if(oh == lastout) fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; *oh ++= x;} inline void flush(){fwrite(obuf, 1, oh - obuf, stdout);} int read(){ int x = 0; int zf = 1; char ch = ' '; while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar_(); if (ch == '-') zf = -1, ch = getchar_(); while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar_(); return x * zf; } void write(int x){ if (x < 0) putchar_('-'), x = -x; if (x > 9) write(x / 10); putchar_(x % 10 + '0'); } } using namespace fast_IO; const int N=1e6+10; struct rec{ int pre,to;}a[N<<1]; vector<int>v[N],E[N]; stack<int>s; bool ok[N]; int n,tot; int head[N],size[N],is[N],dep[N],dfn[N],du[N]; int g[N][22]; void adde(int u,int v) { a[++tot].pre=head[u]; a[tot].to=v; head[u]=tot; } void dfs1(int u,int fa) { size[u]=1; g[u][0]=fa; dep[u]=dep[fa]+1; dfn[u]=++dfn[0]; for (int i=head[u];i;i=a[i].pre) { int v=a[i].to; if (v==fa) continue; dfs1(v,u); size[u]+=size[v]; } } int lca(int u,int v) { if (dep[u]<dep[v]) swap(u,v); for (int i=21;i>=0;i--) if (dep[g[u][i]]>=dep[v]) u=g[u][i]; if (u==v) return u; for (int i=21;i>=0;i--) if (g[u][i]!=g[v][i]) u=g[u][i],v=g[v][i]; return g[u][0]; } int Ref; void clear(int u,int fa) { ok[u]=0; du[u]=0; for (int i=0;i<(int)E[u].size();i++) { int v=E[u][i]; if (v==fa) continue; clear(v,u); } E[u].clear(); } void dfs2(int u,int fa) { if (ok[u]) Ref=u; for (int i=0;i<(int)E[u].size();i++) { int v=E[u][i]; if (v==fa) continue; dfs2(v,u); } } bool cmp(int a,int b) { return dfn[a]<dfn[b]; } bool work(int col) { while (s.size()) s.pop(); int cnt=0; for (int i=0;i<(int)v[col].size();i++) is[++cnt]=v[col][i],ok[is[cnt]]=1; int Min=is[1],Max=is[1]; for (int i=2;i<=cnt;i++) { if (dfn[Min]>dfn[is[i]]) Min=is[i]; if (dfn[Max]<dfn[is[i]]) Max=is[i]; } int rt=lca(Min,Max); sort(is+1,is+1+cnt,cmp); s.push(rt); int res = 0; bool oo = true; for (int i=1;i<=cnt;i++) if (is[i]!=rt) { int u=is[i],l=lca(u,s.top()); while (s.top()!=l) { int tmp=s.top();s.pop(); if (dfn[s.top()]<dfn[l]) s.push(l); res++; E[s.top()].push_back(tmp); E[tmp].push_back(s.top()); du[s.top()]++; du[tmp]++; res++; if (du[s.top()]>2 || du[tmp]>2) { oo=false; } } s.push(u); } while (s.top()!=rt) { int tmp=s.top();s.pop(); E[s.top()].push_back(tmp); E[tmp].push_back(s.top()); res++; du[s.top()]++; du[tmp]++; if (du[s.top()]>2 || du[tmp]>2) { oo=false; } } if (!oo) { clear(rt,0); return false; } if (E[rt].size()==1) { int u=rt,v; dfs2(rt,0); v=Ref; int s1=size[v]; for (int i=21;i>=0;i--) if (dep[g[v][i]]>=dep[u]+1) v=g[v][i]; int s2=n-size[v]; write(s1*s2); putchar_('\n'); } else { int u,v; dfs2(E[rt][0],rt); u=Ref; dfs2(E[rt][1],rt); v=Ref; write(size[u]*size[v]); putchar_('\n'); } clear(rt,0); return true; } signed main() { n=read(); for (int i=1;i<=n;i++) { int t=read(); v[t].push_back(i); } for (int i=2;i<=n;i++) { int u=read(),v=read(); adde(u,v); adde(v,u); } dfs1(1,0); for (int i=1;i<=21;i++) for (int j=1;j<=n;j++) g[j][i]=g[g[j][i-1]][i-1]; for (int c=1;c<=n;c++) { if (v[c].size()==0) { write(n*(n-1)/2); putchar_('\n'); continue; } if (v[c].size()==1) { int u=v[c][0],tt=0,ret=n*(n-1)/2; for (int i=head[u];i;i=a[i].pre) { int v=a[i].to; if (v==g[u][0]) continue; ret-=size[v]*(size[v]-1)/2; tt+=size[v]; } ret-=(n-tt-1)*(n-tt-2)/2; write(ret); putchar_('\n'); continue; } if (!work(c)) write(0),putchar_('\n'); } flush(); return 0; } ```