题解 P3233 【[HNOI2014]世界树】
虚树的关键不在于你看出它是虚树,而是建完虚树之后怎么搞
看到
然后就不是太会做了/kk
首先本题要计算整棵树中的答案,所以我们尝试把答案拆成 2 部分:
-
虚树上的点对答案的贡献
记
by_x 离x 最近的点的编号。by_x 显然可以通过两遍 dfs 求出(第一遍 dfs 求出x 子树内的贡献,用儿子更新父亲,第二遍 dfs 求出x 子树外的贡献,用父亲更新儿子,类似于换根 dp),然后令所有by_x 加1 即可。 -
不在虚树上的点对答案的贡献
这一部分比较复杂,我们不妨画个图来帮助我们理解。
如图,蓝色节点为关键点,绿色的边为虚树组成的边。显然,除了节点
1,2,6,7,8,17,20 之外其它节点都不在虚树上。我们又可把这些节点分为两类:
Ⅰ. 虚树上某个节点的儿子的子树(即图中的黄色节点)。假设我们考虑关键节点
u 这部分的贡献,那么这个贡献显然可以表示为sz_u-1-\sum\limits_{v\in son_u}sz_v ,也就是整个u 子树的大小扣掉在虚树上的儿子的子树大小。但是这里我们不能直接枚举虚树上的儿子。比方说我们要计算节点2 那部分黄色子树的贡献。节点2 在虚树上的唯一儿子为节点6 。而直接拿sz_2-1-sz_6 显然是不行的,正确的结果应该是sz_2-1-sz_3 。故这里我们要减去u 在v 方向的直接儿子。那么这个直接儿子怎么求呢?借鉴 CF916E 的套路,求出v 的dep_v-dep_u-1 级祖先就行了。Ⅱ. 虚树上某两个节点
u,v 之间的点及其子树内的节点,(即图中的粉色节点)继续分情况:
①.
by_u=by_v ,那么显然u,v 所有节点都属于by_v 。找出u 在v 方向的直接儿子s ,并令sz_{by_u} 加上sz_s-sz_v ②.
by_u\neq by_v ,显然by_u 与by_v 分别位于边(u,v) 的两侧。画张比较清晰的图:
二分枚举断点
真是道恶心的题啊
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
template<typename T> void read(T &x){
x=0;char c=getchar();T neg=1;
while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
while(isdigit(c)) x=x*10+c-'0',c=getchar();
x*=neg;
}
const int MAXN=3e5;
const int LOG_N=19;
int n,qu;
namespace graph{
int nxt[MAXN*2+5],to[MAXN*2+5],hd[MAXN+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int fa[MAXN+5][LOG_N+2],dep[MAXN+5],dfn[MAXN+5],tim=0,sz[MAXN+5];
void dfs(int x,int f){
dfn[x]=++tim;fa[x][0]=f;sz[x]=1;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f) continue;
dep[y]=dep[x]+1;dfs(y,x);sz[x]+=sz[y];
}
}
int getlca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=LOG_N;~i;i--) if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=LOG_N;~i;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int getfa(int x,int k){for(int i=LOG_N;~i;i--) if(k>>i&1) x=fa[x][i];return x;}
int getdis(int x,int y){return dep[x]+dep[y]-(dep[getlca(x,y)]<<1);}
void prework(){
dfs(1,0);
for(int i=1;i<=LOG_N;i++) for(int j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
}
}
using graph::dep;
using graph::sz;
using graph::getlca;
using graph::getfa;
using graph::getdis;
namespace virt{
bool cmp(int x,int y){return graph::dfn[x]<graph::dfn[y];}
int hd[MAXN+5],to[MAXN+5],nxt[MAXN+5],cst[MAXN+5],ec=0;
void adde(int u,int v){to[++ec]=v;cst[ec]=dep[v]-dep[u];nxt[ec]=hd[u];hd[u]=ec;}
int pt[MAXN+5],pn,ori[MAXN+5];bool mark[MAXN+5];
int stk[MAXN],tp=0;
void insert(int x){
if(!tp){stk[++tp]=x;return;}
int lc=getlca(x,stk[tp]);
while(tp>=2&&dep[stk[tp-1]]>dep[lc]) adde(stk[tp-1],stk[tp]),tp--;
if(tp&&dep[stk[tp]]>dep[lc]) adde(lc,stk[tp--]);
if(!tp||lc!=stk[tp]) stk[++tp]=lc;stk[++tp]=x;
}
void fin(){while(tp>=2) adde(stk[tp-1],stk[tp]),tp--;stk[tp--]=0;}
void build(){
sort(pt+1,pt+pn+1,cmp);if(!mark[1]) insert(1);
for(int i=1;i<=pn;i++) insert(pt[i]);fin();
}
int by[MAXN+5],ans[MAXN+5];
void dfs1(int x){
// printf("%d\n",x);
if(!mark[x]) by[x]=-1;
else by[x]=x;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];dfs1(y);
if(by[y]!=-1){
if(by[x]==-1) by[x]=by[y];
else{
int d1=getdis(x,by[y]),d2=getdis(x,by[x]);
if(d1<d2||(d1==d2&&by[x]>by[y])) by[x]=by[y];
}
}
}
}
void dfs2(int x){
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];
if(by[y]==-1) by[y]=by[x];
else{
int d1=getdis(y,by[x]),d2=getdis(y,by[y]);
if(d1<d2||(d1==d2&&by[y]>by[x])) by[y]=by[x];
} dfs2(y);
}
// printf("%d %d\n",x,by[x]);
}
void dfs3(int x){
ans[by[x]]+=sz[x];
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];ans[by[x]]-=sz[getfa(y,dep[y]-dep[x]-1)];
dfs3(y);
}
}
void dfs4(int x){
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];
if(by[x]==by[y]){
ans[by[x]]+=sz[getfa(y,dep[y]-dep[x]-1)]-sz[y];
} else {
int d1=getdis(x,by[x]),d2=getdis(y,by[y]),len=dep[y]-dep[x]-1;
int l=0,r=len,p=len+1;
while(l<=r){
int mid=(l+r)>>1;
if((d1+mid<d2+len+1-mid)||(d1+mid==d2+len+1-mid&&by[x]<by[y])) p=mid,l=mid+1;
else r=mid-1;
}
// printf("%d %d %d\n",x,y,p);
if(p==0) ans[by[y]]+=sz[getfa(y,len)]-sz[y];
else if(p==len+1) ans[by[x]]+=sz[getfa(y,len)]-sz[y];
else{
ans[by[x]]+=sz[getfa(y,len)]-sz[getfa(y,len-p)];
ans[by[y]]+=sz[getfa(y,len-p)]-sz[y];
}
} dfs4(y);
}
}
void clear(int x){
by[x]=ans[x]=0;
for(int e=hd[x];e;e=nxt[e]) clear(to[e]),ec--;
hd[x]=0;
}
void work(){
scanf("%d",&pn);for(int i=1;i<=pn;i++) scanf("%d",&pt[i]),ori[i]=pt[i],mark[pt[i]]=1;
build();dfs1(1);dfs2(1);dfs3(1);dfs4(1);
for(int i=1;i<=pn;i++) printf("%d ",ans[ori[i]]);printf("\n");clear(1);
for(int i=1;i<=pn;i++) mark[pt[i]]=0;
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v;scanf("%d%d",&u,&v);
graph::adde(u,v);graph::adde(v,u);
} graph::prework();scanf("%d",&qu);
while(qu--) virt::work();
return 0;
}
/*
7
1 2
2 3
2 4
1 5
4 6
5 7
1
3 2 5 6
*/