题解 P5666 【树的重心【民间数据】】

wucstdio

2019-11-18 08:54:03

Solution

考场这道题的思考时间比前两道题都短…… 果然是学的太多思维僵化了。 首先,题目中给的式子可以等价为枚举一个点,然后查看这个点会作为多少棵子树的重心。 接下来我们的问题就是对每一个点 $x$ ,求出这个点会被统计多少次,也就是求出有多少割掉一棵子树的方案,使得 $x$ 成为重心。 首先拎出来一个重心当做树根,然后 dfs 一遍。 如果 $x$ 不是重心,那么就需要在它父亲的那颗子树中切掉一棵子树: ![](https://cdn.luogu.com.cn/upload/image_hosting/ygohn2wr.png) 我们设 $x$ 这棵子树大小为 $s$ , $x$ 的子节点中最大的子树大小为 $m$ ,那么切掉的点 $del$ 需要满足如下式子: $$\begin{cases}n-s-size[del]\le \dfrac{n-size[del]}{2}\\m\le \dfrac{n-size[del]}{2}\end{cases}$$ 化简可以得到 $$n-2s\le size[del]\le n-2m$$ 我们用一个树状数组记录这个点父节点的那颗子树中 $size=i$ 的子树数量。在 dfs 的过程中,每 dfs 到一个点 $x$ 就让 $c[size[x]]-1,c[n-size[x]]+1$ ,也就是当我们割掉 $f$ 与 $x$ 的边的时候切掉的子树大小从 $size[x]$ 变成了 $n-size[x]$ 。统计答案就是区间求和。 但是还有一个问题。这样一来我们统计的答案有可能会包含在 $x$ 的子树中删掉一棵子树的情况,这样是不满足条件的。 为了解决这个问题,我们可以再 dfs 一遍,这一次我们用线段树合并求出每一个点的子树中 $size=i$ 的子树数量。这里统计的是不合法的方案,从答案里面减去就行了。 最后还剩下 $x$ 是重心的情况。由于重心只有最多两个,我们可以先以重心为根对整棵树进行dfs,统计出 $x$ 这个点最大的子树 $m_1$ 和次大的子树 $m_2$ 。 接下来分情况讨论。如果我们是在最大的子树里面切割,就需要满足 $$m_2\le \dfrac{n-size[del]}{2}$$ 即 $$size[del]\le n-2m_2$$ 否则我们需要满足 $$m_1\le \dfrac{n-size[del]}{2}$$ 即 $$size[del]\le n-2m_1$$ 我们可以在线段树合并的过程中顺便求出这个答案。 总时间复杂度是 $O(n\log n)$ ,在考场电脑上不写读入优化 $3$s 左右,写读入优化 $2.5$s 。 下面是代码: ```cpp #include<cstdio> #include<algorithm> #include<cstring> #define ll long long #define lson tree[x].child[0] #define rson tree[x].child[1] #define mid (l+r)/2 using namespace std; struct Edge { int to; int nxt; }e[600005]; struct Tree { int child[2]; int sum; }tree[10000005]; int n,m,edgenum,tot,head[300005],pa[300005],size[300005],s[300005],root[300005]; ll ans; bool isroot[300005]; int read() { char c=(char)getchar(); while(c>'9'||c<'0')c=(char)getchar(); int t=0; while(c>='0'&&c<='9') { t=t*10+c-'0'; c=(char)getchar(); } return t; } void add(int u,int v) { e[++edgenum].to=v; e[edgenum].nxt=head[u]; head[u]=edgenum; } void dfs_pre(int node)//预处理,找出重心 { size[node]=1; for(int hd=head[node];hd;hd=e[hd].nxt) { int to=e[hd].to; if(to==pa[node])continue; pa[to]=node; dfs_pre(to); size[node]+=size[to]; if(size[to]>n/2)isroot[node]=0; } if(n-size[node]>n/2)isroot[node]=0; } void dfs1(int node)//求出所有的size { size[node]=1; for(int hd=head[node];hd;hd=e[hd].nxt) { int to=e[hd].to; if(to==pa[node])continue; pa[to]=node; dfs1(to); size[node]+=size[to]; } } void add(int p) { while(p<=n) { s[p]++; p+=p^(p&(p-1)); } } void dec(int p) { while(p<=n) { s[p]--; p+=p^(p&(p-1)); } } int sum(int p) { int ans=0; while(p) { ans+=s[p]; p&=p-1; } return ans; } int sum(int l,int r) { if(l>r)return 0; return sum(r)-sum(l-1); }//以上是树状数组 int merge(int x,int y) { if(!x||!y)return x+y; tree[x].sum+=tree[y].sum; lson=merge(lson,tree[y].child[0]); rson=merge(rson,tree[y].child[1]); return x; } void add(int x,int l,int r,int p) { tree[x].sum++; if(l==r)return; if(p<=mid) { if(!lson)lson=++tot; add(lson,l,mid,p); } else { if(!rson)rson=++tot; add(rson,mid+1,r,p); } } //void debug(int x,int l,int r) //{ // if(l==r) // { // printf("%lld ",tree[x].sum); // return; // } // debug(lson,l,mid); // debug(rson,mid+1,r); //} ll sum(int x,int l,int r,int from,int to) { if(from>to)return 0; if(!x)return 0; if(l>=from&&r<=to)return tree[x].sum; ll ans=0; if(from<=mid)ans+=sum(lson,l,mid,from,to); if(to>mid)ans+=sum(rson,mid+1,r,from,to); return ans; }//以上是线段树合并 void dfs2(int node) { // printf("%d:\n",node); int max1=0,max2=0; for(int hd=head[node];hd;hd=e[hd].nxt) { int to=e[hd].to; if(to==pa[node])continue; if(size[to]>max1)max2=max1,max1=size[to]; else if(size[to]>max2)max2=size[to]; } // printf("max1=%d,max2=%d\n",max1,max2); if(isroot[node]&&pa[node]) { if(n-size[node]>max1)max2=max1,max1=n-size[node]; else if(n-size[node]>max2)max2=n-size[node]; if(n-size[node]==max1)ans+=1ll*node*sum(1,n-2*max2); else ans+=1ll*node*sum(1,n-2*max1); } for(int hd=head[node];hd;hd=e[hd].nxt) { int to=e[hd].to; if(to==pa[node])continue; pa[to]=node; dec(size[to]); add(n-size[to]); dfs2(to); // printf("to=%d,root=%d,",to,root[to]); // debug(root[to],1,n); // printf("\n"); if(isroot[node]) { // printf("%d->%d,range=",node,to); // if(size[to]==max1)printf("%d %d,res=%lld\n",1,n-2*max2,sum(root[to],1,n,1,n-2*max2)); // else printf("%d %d,res=%lld\n",1,n-2*max1,sum(root[to],1,n,1,n-2*max1)); if(size[to]==max1)ans+=1ll*node*sum(root[to],1,n,1,n-2*max2); else ans+=1ll*node*sum(root[to],1,n,1,n-2*max1); } root[node]=merge(root[node],root[to]); add(size[to]); dec(n-size[to]); } if(!isroot[node]) { // printf("node=%d:tot=%lld-%lld\n",node,sum(max(1,n-2*size[node]),min(n,n-2*max1)),sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1))); ans+=1ll*node*sum(max(1,n-2*size[node]),min(n,n-2*max1)); ans-=1ll*node*sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1)); } if(isroot[node]&&pa[node]) { if(n-size[node]==max1)ans-=1ll*node*sum(root[node],1,n,1,n-2*max2); else ans-=1ll*node*sum(root[node],1,n,1,n-2*max1); } add(root[node]?root[node]:root[node]=++tot,1,n,size[node]); // printf("node=%d:\n",node); // printf("tree:"); // debug(root[node],1,n); // printf("\n"); // printf("array:"); // for(int i=1;i<=n;i++)printf("%lld ",sum(i)-sum(i-1)); // printf("\n"); } int t; int main() { // freopen("centroid.in","r",stdin); // freopen("centroid.out","w",stdout); t=read(); while(t--) { n=read(); ans=edgenum=0; for(int i=1;i<=tot;i++)tree[i].sum=tree[i].child[0]=tree[i].child[1]=0; for(int i=1;i<=n;i++)s[i]=head[i]=pa[i]=size[i]=root[i]=0,isroot[i]=1; tot=0; for(int i=1;i<n;i++) { int u=read(),v=read(); add(u,v); add(v,u); } dfs_pre(1); for(int i=1;i<=n;i++) { if(isroot[i]) { pa[i]=0; dfs1(i); for(int j=1;j<=n;j++) if(j!=i)add(size[j]); dfs2(i); break; } } printf("%lld\n",ans); } return 0; } ```