题解 P5588 【小猪佩奇爬树】
ljc20020730
2019-10-14 19:22:47
本题可以直接无脑写一个虚树。
时间复杂度依然是优美的$O(n \ log_2 \ n)$。
考虑对于每一个颜色分别处理,建出$n$个颜色的$n$棵虚树。
设第$i$个颜色有$k_i$个节点,那么显然这棵虚树上的节点个数是在$k_i$这个级别上的。
虚树的根设为该颜色所有节点的共同$lca$,求多个节点的$lca$怕已经是$trick$题了。
直接找出$dfs$序最小和$dfs$序最大的两个节点,求一下$lca$即可。
首先,这棵树必然没有度数为$2$的节点,如果有,那么直接输出$0$
否则,按照根节点的度数是$1$还是$2$分类讨论。
如果是$1$ , 那么就是一个以$1$为根的链型,我们直接可以$dfs$找到虚树中最深的节点把它拿出来。而此时的树根必然是一个标记点。
如果是$2$ , 那么就是以$1$为根挂下两根链,我们直接$dfs$找出这两根链的最深处的两个节点。
相信只要找出这一条链,大家就不难考虑求出最终的答案了。
需要特殊考虑的是,只有$0$个节点或者$1$个节点的情况,这需要分类讨论。
对于$1$个节点的情况,可以考虑补集转化,即求出子树中的路径个数,用总路径个数减去。
这样,本题的时间复杂度就是美妙且暴力的$O(n \ log_2 \ n)$
> $Ps : $ 不知道蒟蒻写个这个虚树为什么跑的那么慢,下面的代码需要开$O2$才能cao过$3s$时限....
```cpp
# include<bits/stdc++.h>
# define int long long
using namespace std;
namespace fast_IO{
const int IN_LEN = 10000000, OUT_LEN = 10000000;
char ibuf[IN_LEN], obuf[OUT_LEN], *ih = ibuf + IN_LEN, *oh = obuf, *lastin = ibuf + IN_LEN, *lastout = obuf + OUT_LEN - 1;
inline char getchar_(){return (ih == lastin) && (lastin = (ih = ibuf) + fread(ibuf, 1, IN_LEN, stdin), ih == lastin) ? EOF : *ih++;}
inline void putchar_(const char x){if(oh == lastout) fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; *oh ++= x;}
inline void flush(){fwrite(obuf, 1, oh - obuf, stdout);}
int read(){
int x = 0; int zf = 1; char ch = ' ';
while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar_();
if (ch == '-') zf = -1, ch = getchar_();
while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar_(); return x * zf;
}
void write(int x){
if (x < 0) putchar_('-'), x = -x;
if (x > 9) write(x / 10);
putchar_(x % 10 + '0');
}
}
using namespace fast_IO;
const int N=1e6+10;
struct rec{ int pre,to;}a[N<<1];
vector<int>v[N],E[N];
stack<int>s;
bool ok[N];
int n,tot;
int head[N],size[N],is[N],dep[N],dfn[N],du[N];
int g[N][22];
void adde(int u,int v)
{
a[++tot].pre=head[u];
a[tot].to=v;
head[u]=tot;
}
void dfs1(int u,int fa) {
size[u]=1; g[u][0]=fa;
dep[u]=dep[fa]+1; dfn[u]=++dfn[0];
for (int i=head[u];i;i=a[i].pre) {
int v=a[i].to; if (v==fa) continue;
dfs1(v,u); size[u]+=size[v];
}
}
int lca(int u,int v) {
if (dep[u]<dep[v]) swap(u,v);
for (int i=21;i>=0;i--)
if (dep[g[u][i]]>=dep[v]) u=g[u][i];
if (u==v) return u;
for (int i=21;i>=0;i--)
if (g[u][i]!=g[v][i]) u=g[u][i],v=g[v][i];
return g[u][0];
}
int Ref;
void clear(int u,int fa) {
ok[u]=0; du[u]=0;
for (int i=0;i<(int)E[u].size();i++) {
int v=E[u][i]; if (v==fa) continue;
clear(v,u);
}
E[u].clear();
}
void dfs2(int u,int fa) {
if (ok[u]) Ref=u;
for (int i=0;i<(int)E[u].size();i++) {
int v=E[u][i]; if (v==fa) continue;
dfs2(v,u);
}
}
bool cmp(int a,int b) {
return dfn[a]<dfn[b];
}
bool work(int col) {
while (s.size()) s.pop(); int cnt=0;
for (int i=0;i<(int)v[col].size();i++)
is[++cnt]=v[col][i],ok[is[cnt]]=1;
int Min=is[1],Max=is[1];
for (int i=2;i<=cnt;i++) {
if (dfn[Min]>dfn[is[i]]) Min=is[i];
if (dfn[Max]<dfn[is[i]]) Max=is[i];
}
int rt=lca(Min,Max);
sort(is+1,is+1+cnt,cmp);
s.push(rt);
int res = 0; bool oo = true;
for (int i=1;i<=cnt;i++) if (is[i]!=rt) {
int u=is[i],l=lca(u,s.top());
while (s.top()!=l) {
int tmp=s.top();s.pop();
if (dfn[s.top()]<dfn[l]) s.push(l);
res++;
E[s.top()].push_back(tmp);
E[tmp].push_back(s.top());
du[s.top()]++; du[tmp]++;
res++;
if (du[s.top()]>2 || du[tmp]>2) {
oo=false;
}
}
s.push(u);
}
while (s.top()!=rt) {
int tmp=s.top();s.pop();
E[s.top()].push_back(tmp);
E[tmp].push_back(s.top());
res++;
du[s.top()]++; du[tmp]++;
if (du[s.top()]>2 || du[tmp]>2) {
oo=false;
}
}
if (!oo) {
clear(rt,0); return false;
}
if (E[rt].size()==1) {
int u=rt,v; dfs2(rt,0); v=Ref;
int s1=size[v];
for (int i=21;i>=0;i--)
if (dep[g[v][i]]>=dep[u]+1) v=g[v][i];
int s2=n-size[v];
write(s1*s2); putchar_('\n');
} else {
int u,v;
dfs2(E[rt][0],rt); u=Ref;
dfs2(E[rt][1],rt); v=Ref;
write(size[u]*size[v]); putchar_('\n');
}
clear(rt,0);
return true;
}
signed main()
{
n=read();
for (int i=1;i<=n;i++) {
int t=read();
v[t].push_back(i);
}
for (int i=2;i<=n;i++) {
int u=read(),v=read();
adde(u,v); adde(v,u);
}
dfs1(1,0);
for (int i=1;i<=21;i++)
for (int j=1;j<=n;j++)
g[j][i]=g[g[j][i-1]][i-1];
for (int c=1;c<=n;c++) {
if (v[c].size()==0) {
write(n*(n-1)/2); putchar_('\n');
continue;
}
if (v[c].size()==1) {
int u=v[c][0],tt=0,ret=n*(n-1)/2;
for (int i=head[u];i;i=a[i].pre) {
int v=a[i].to; if (v==g[u][0]) continue;
ret-=size[v]*(size[v]-1)/2;
tt+=size[v];
}
ret-=(n-tt-1)*(n-tt-2)/2;
write(ret); putchar_('\n');
continue;
}
if (!work(c)) write(0),putchar_('\n');
}
flush();
return 0;
}
```