题解 P3233 【[HNOI2014]世界树】

· · 题解

虚树的关键不在于你看出它是虚树,而是建完虚树之后怎么搞

看到 \sum m_i\leq 10^5,果断建出虚树。

然后就不是太会做了/kk

首先本题要计算整棵树中的答案,所以我们尝试把答案拆成 2 部分:

  1. 虚树上的点对答案的贡献

    by_xx 最近的点的编号。by_x 显然可以通过两遍 dfs 求出(第一遍 dfs 求出 x 子树内的贡献,用儿子更新父亲,第二遍 dfs 求出 x 子树外的贡献,用父亲更新儿子,类似于换根 dp),然后令所有 by_x1 即可。

  2. 不在虚树上的点对答案的贡献

    这一部分比较复杂,我们不妨画个图来帮助我们理解。

    如图,蓝色节点为关键点,绿色的边为虚树组成的边。显然,除了节点 1,2,6,7,8,17,20 之外其它节点都不在虚树上。

    我们又可把这些节点分为两类:

    Ⅰ. 虚树上某个节点的儿子的子树(即图中的黄色节点)。假设我们考虑关键节点 u 这部分的贡献,那么这个贡献显然可以表示为 sz_u-1-\sum\limits_{v\in son_u}sz_v,也就是整个 u 子树的大小扣掉在虚树上的儿子的子树大小。但是这里我们不能直接枚举虚树上的儿子。比方说我们要计算节点 2 那部分黄色子树的贡献。节点 2 在虚树上的唯一儿子为节点 6。而直接拿 sz_2-1-sz_6 显然是不行的,正确的结果应该是 sz_2-1-sz_3。故这里我们要减去 uv 方向的直接儿子。那么这个直接儿子怎么求呢?借鉴 CF916E 的套路,求出 vdep_v-dep_u-1 级祖先就行了。

    Ⅱ. 虚树上某两个节点 u,v 之间的点及其子树内的节点,(即图中的粉色节点)

    继续分情况:

    ①. by_u=by_v,那么显然 u,v 所有节点都属于 by_v。找出 uv 方向的直接儿子 s,并令 sz_{by_u} 加上 sz_s-sz_v

    ②. by_u\neq by_v,显然 by_uby_v 分别位于边 (u,v) 的两侧。画张比较清晰的图:

​ 二分枚举断点 p,qp 及上面部分都属于 by_uq 及下面部分都属于 by_v。那么这条路径上的点对 by_u 贡献 就是 sz_s-sz_q,对 by_v 的贡献就是 sz_q-sz_v

真是道恶心的题啊

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
template<typename T> void read(T &x){
    x=0;char c=getchar();T neg=1;
    while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
    while(isdigit(c)) x=x*10+c-'0',c=getchar();
    x*=neg;
}
const int MAXN=3e5;
const int LOG_N=19;
int n,qu;
namespace graph{
    int nxt[MAXN*2+5],to[MAXN*2+5],hd[MAXN+5],ec=0;
    void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
    int fa[MAXN+5][LOG_N+2],dep[MAXN+5],dfn[MAXN+5],tim=0,sz[MAXN+5];
    void dfs(int x,int f){
        dfn[x]=++tim;fa[x][0]=f;sz[x]=1;
        for(int e=hd[x];e;e=nxt[e]){
            int y=to[e];if(y==f) continue;
            dep[y]=dep[x]+1;dfs(y,x);sz[x]+=sz[y];
        }
    }
    int getlca(int x,int y){
        if(dep[x]<dep[y]) swap(x,y);
        for(int i=LOG_N;~i;i--) if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
        if(x==y) return x;
        for(int i=LOG_N;~i;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
        return fa[x][0];
    }
    int getfa(int x,int k){for(int i=LOG_N;~i;i--) if(k>>i&1) x=fa[x][i];return x;}
    int getdis(int x,int y){return dep[x]+dep[y]-(dep[getlca(x,y)]<<1);}
    void prework(){
        dfs(1,0);
        for(int i=1;i<=LOG_N;i++) for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];
    }
}
using graph::dep;
using graph::sz;
using graph::getlca;
using graph::getfa;
using graph::getdis;
namespace virt{
    bool cmp(int x,int y){return graph::dfn[x]<graph::dfn[y];}
    int hd[MAXN+5],to[MAXN+5],nxt[MAXN+5],cst[MAXN+5],ec=0;
    void adde(int u,int v){to[++ec]=v;cst[ec]=dep[v]-dep[u];nxt[ec]=hd[u];hd[u]=ec;}
    int pt[MAXN+5],pn,ori[MAXN+5];bool mark[MAXN+5];
    int stk[MAXN],tp=0;
    void insert(int x){
        if(!tp){stk[++tp]=x;return;}
        int lc=getlca(x,stk[tp]);
        while(tp>=2&&dep[stk[tp-1]]>dep[lc]) adde(stk[tp-1],stk[tp]),tp--;
        if(tp&&dep[stk[tp]]>dep[lc]) adde(lc,stk[tp--]);
        if(!tp||lc!=stk[tp]) stk[++tp]=lc;stk[++tp]=x;
    }
    void fin(){while(tp>=2) adde(stk[tp-1],stk[tp]),tp--;stk[tp--]=0;}
    void build(){
        sort(pt+1,pt+pn+1,cmp);if(!mark[1]) insert(1);
        for(int i=1;i<=pn;i++) insert(pt[i]);fin();
    }
    int by[MAXN+5],ans[MAXN+5];
    void dfs1(int x){
//      printf("%d\n",x);
        if(!mark[x]) by[x]=-1;
        else by[x]=x;
        for(int e=hd[x];e;e=nxt[e]){
            int y=to[e];dfs1(y);
            if(by[y]!=-1){
                if(by[x]==-1) by[x]=by[y];
                else{
                    int d1=getdis(x,by[y]),d2=getdis(x,by[x]);
                    if(d1<d2||(d1==d2&&by[x]>by[y])) by[x]=by[y];
                }
            }
        }
    }
    void dfs2(int x){
        for(int e=hd[x];e;e=nxt[e]){
            int y=to[e]; 
            if(by[y]==-1) by[y]=by[x];
            else{
                int d1=getdis(y,by[x]),d2=getdis(y,by[y]);
                if(d1<d2||(d1==d2&&by[y]>by[x])) by[y]=by[x];
            } dfs2(y);
        }
//      printf("%d %d\n",x,by[x]);
    }
    void dfs3(int x){
        ans[by[x]]+=sz[x];
        for(int e=hd[x];e;e=nxt[e]){
            int y=to[e];ans[by[x]]-=sz[getfa(y,dep[y]-dep[x]-1)];
            dfs3(y);
        }
    }
    void dfs4(int x){
        for(int e=hd[x];e;e=nxt[e]){
            int y=to[e];
            if(by[x]==by[y]){
                ans[by[x]]+=sz[getfa(y,dep[y]-dep[x]-1)]-sz[y];
            } else {
                int d1=getdis(x,by[x]),d2=getdis(y,by[y]),len=dep[y]-dep[x]-1;
                int l=0,r=len,p=len+1;
                while(l<=r){
                    int mid=(l+r)>>1;
                    if((d1+mid<d2+len+1-mid)||(d1+mid==d2+len+1-mid&&by[x]<by[y])) p=mid,l=mid+1;
                    else r=mid-1;
                }
//              printf("%d %d %d\n",x,y,p);
                if(p==0) ans[by[y]]+=sz[getfa(y,len)]-sz[y];
                else if(p==len+1) ans[by[x]]+=sz[getfa(y,len)]-sz[y];
                else{
                    ans[by[x]]+=sz[getfa(y,len)]-sz[getfa(y,len-p)];
                    ans[by[y]]+=sz[getfa(y,len-p)]-sz[y];
                }
            } dfs4(y);
        }
    }
    void clear(int x){
        by[x]=ans[x]=0;
        for(int e=hd[x];e;e=nxt[e]) clear(to[e]),ec--;
        hd[x]=0;
    }
    void work(){
        scanf("%d",&pn);for(int i=1;i<=pn;i++) scanf("%d",&pt[i]),ori[i]=pt[i],mark[pt[i]]=1;
        build();dfs1(1);dfs2(1);dfs3(1);dfs4(1);
        for(int i=1;i<=pn;i++) printf("%d ",ans[ori[i]]);printf("\n");clear(1);
        for(int i=1;i<=pn;i++) mark[pt[i]]=0;
    }
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int u,v;scanf("%d%d",&u,&v);
        graph::adde(u,v);graph::adde(v,u);
    } graph::prework();scanf("%d",&qu);
    while(qu--) virt::work();
    return 0;
}
/*
7
1 2
2 3
2 4
1 5
4 6
5 7
1
3 2 5 6
*/