题解【P6556 The Forest】

· · 题解

呃呃了,在以为其他题解做麻烦的前提下写了写发现假了,结果优化成了和其他人一样的做法。

题目大意

给定 n 个点的两棵树 A,B,求有多少个点集满足将点集内的点按照树上的边连边后,在 A 树上形成一个联通块,在 B 树上形成一条链。

$T = 3,1\leq n\leq10^5$。 ### 题目分析 先考虑一个性质,对于一个树上的点集 $T$,若其内部边的数量为 $x$,那么这个点集的联通块数是 $|T|-x$,证明考虑一开始每个点都单独一个联通块,每次连一条边就是把两个联通块合并成一个。 有了这个性质如何做呢?这启发了我们维护联通块数。 首先考虑特殊性质,对于 $B$ 树是链,等价于要求点集是个区间,所以考虑扫描线,设 $s_i$ 表示在当前扫描线右端点在 $r$,左端点在 $i$ 时,这个区间点集在 $A$ 树上有几个联通块,答案显然是区间内 $1$ 的个数,由于区间最小值一定最小为 $1$,所以可以直接维护区间最小值个数。 如何转移,考虑由 $r$ 推到 $r+1$,此时对于 $[1,r+1]$ 来说都新加入了一个点,所以区间加 $1$,然后对于 $A$ 树上的一条边 $(u,r+1)$ 满足 $u<r+1$ 来说,$[1,u]$ 的 $s_i$ 对应的点集内一定有这条边,所以区间减 $1$ 即可,答案就是所有版本的 $1$ 的个数和。 考虑扩展到树上,如何将区间转换成链,不难想到令每个点作为根,求出每个点到根路径形成的点集在 $A$ 树上的联通块个数,不妨设 $f_i$ 表示这个,答案会算多,原因是对于一条合法的链 $(u,v)$ 在 $u,v$ 为根时都会计算一遍,所以要去掉,注意 $(u,u)$ 不会算重。 接下来的内容默认会换根意义下的区间加减,若不会请去[遥远的国度](https://www.luogu.com.cn/problem/P3979)。 假设当前根为 $u$,要换到他的儿子 $v$,如何转移 $s_i$,令 $W(x,y)$ 表示以 $x$ 为根时,$y$ 的子树表达的点集。 首先由于 $v$ 提到了根的位置,所以除了 $W(u,v)$ 以外的所有点,所对应的点集都插入了一个点,区间加,同理 $W(u,v)$ 整体少了一个点。 现在考虑新的边的贡献,$v$ 对于 $W(u,v)$ 的贡献在 $u$ 为根的时候已经统计过了,所以对于 $A$ 树边 $(v,x)$,若 $x \notin W(u,v)$ 那么就对 $W(v,x)$ 进行一次子树减,因为这部分都会被这条边影响。同理我们也要删除 $u$ 的在 $W(u,v)$ 内的 $A$ 树上邻居的贡献,但是发现每次换根都枚举一圈 $A$ 树的邻居,总的枚举个数就成了两树度数的平方。 注意到对于 $u$ 需要删掉的贡献只有在 $W(u,v)$ 内的,容易发现对于 $u$ 每个儿子,其子树区间不相交,所以我们可以将 $A$ 边按照 $B$ 树的 dfs 序排序,这样的话每个贡献只会增减各一次。 ### Code 注意要做到从 $v$ 版本回溯到 $u$,所以记录下来操作反着做一遍即可。 对于最开始的 $1$ 号版本,可以暴力预处理出来初始情况。 ```cpp #include <iostream> #include <algorithm> #include <cstdio> #include <cstring> #include <cmath> #include <queue> #define int long long using namespace std; int read(){ int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } const int N = 1e5+5; int n; vector<int>in[N],ed[N]; int minx,cnt,rt; struct seg{ struct aa{ int lc,rc,mi,sum,tag; void clear(){ lc = rc = mi = sum = tag = 0; } }node[N*2]; void pushup(int u){ aa x = node[node[u].lc],y = node[node[u].rc]; node[u].mi = min(x.mi,y.mi); node[u].sum = (x.mi==node[u].mi?x.sum:0)+(y.mi==node[u].mi?y.sum:0); } int tot; int newnode(){ int u = ++tot; node[u].clear(); return u; } void build(int &u,int l,int r){ u = newnode(); node[u].sum = (r-l+1); if(l==r){ return; } int mid = (l+r)/2; build(node[u].lc,l,mid); build(node[u].rc,mid+1,r); } void lazy_tag(int u,int x){ node[u].mi+=x; node[u].tag+=x; } void pushdown(int u){ if(!node[u].tag){ return; } lazy_tag(node[u].lc,node[u].tag); lazy_tag(node[u].rc,node[u].tag); node[u].tag = 0; } void upd(int u,int l,int r,int ll,int rr,int x){ if(l==ll&&r==rr){ lazy_tag(u,x); return; } pushdown(u); int mid = (l+r)/2; if(rr<=mid){ upd(node[u].lc,l,mid,ll,rr,x); }else if(ll>mid){ upd(node[u].rc,mid+1,r,ll,rr,x); }else{ upd(node[u].lc,l,mid,ll,mid,x); upd(node[u].rc,mid+1,r,mid+1,rr,x); } pushup(u); } void ask(int u,int l,int r,int ll,int rr){ if(l==ll&&r==rr){ if(node[u].mi<minx){ minx = node[u].mi; cnt = node[u].sum; }else if(node[u].mi==minx){ cnt+=node[u].sum; } return; } pushdown(u); int mid = (l+r)/2; if(rr<=mid){ ask(node[u].lc,l,mid,ll,rr); }else if(ll>mid){ ask(node[u].rc,mid+1,r,ll,rr); }else{ ask(node[u].lc,l,mid,ll,mid); ask(node[u].rc,mid+1,r,mid+1,rr); } } }T; int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],tt; bool cmp(int a,int b){ return dfn[a]<dfn[b]; } void dfs1(int u,int f){ siz[u] = 1; son[u] = 0; for(auto x:in[u]){ if(x==f){ continue; } fa[x] = u; dep[x] = dep[u]+1; dfs1(x,u); siz[u]+=siz[x]; if(siz[x]>siz[son[u]]){ son[u] = x; } } } void dfs2(int u,int t){ top[u] = t; dfn[u] = ++tt; if(!son[u]){ return; } dfs2(son[u],t); for(auto x:in[u]){ if(x==fa[u]||x==son[u]){ continue; } dfs2(x,x); } } int Lca(int u,int v){ while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]]){ swap(u,v); } u = fa[top[u]]; } if(dep[u]<dep[v]){ swap(u,v); } return v; } int query(){ minx = 1e9; cnt = 0; T.ask(rt,1,n,1,n); if(minx==1){ return cnt; }else{ return 0; } } struct bb{ int l,r,x; }; vector<bb>op[N]; void add(int u,int l,int r,int x){ op[u].push_back((bb){l,r,x}); T.upd(rt,1,n,l,r,x); } int RT,res; int get(int u,int x){ while(top[u]!=top[x]){ if(fa[top[u]]==x){ return top[u]; } u = fa[top[u]]; } return son[x]; } void dfs(int u){ if(u!=1){ add(u,1,n,1); add(u,dfn[u],dfn[u]+siz[u]-1,-2); for(auto x:ed[u]){ if(dfn[u]<=dfn[x]&&dfn[x]<=dfn[u]+siz[u]-1){ continue; } if(dfn[x]<=dfn[u]&&dfn[u]<=dfn[x]+siz[x]-1){ int v = get(u,x); add(u,1,n,-1); add(u,dfn[v],dfn[v]+siz[v]-1,1); }else{ add(u,dfn[x],dfn[x]+siz[x]-1,-1); } } } res+=query(); int r = 0; int sz = ed[u].size(); while(r<sz){ int y = ed[u][r]; if(dfn[y]<dfn[u]||dfn[y]>dfn[u]+siz[u]-1){ r++; }else{ break; } } for(auto x:in[u]){ if(x==fa[u]){ continue; } int R = r; while(r<sz){ int y = ed[u][r]; if(dfn[x]<=dfn[y]&&dfn[y]<=dfn[x]+siz[x]-1){ T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,1); r++; }else{ break; } } dfs(x); for(int i=R;i<r;i++){ int y = ed[u][i]; T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,-1); } } for(auto x:op[u]){ T.upd(rt,1,n,x.l,x.r,-x.x); } } int U[N],V[N]; void init(){ n = read(); T.tot = 0;rt = 0;res = 0;tt = 0; for(int i=1;i<=n;i++){ in[i].clear();ed[i].clear();op[i].clear(); } for(int i=1;i<n;i++){ int u,v; u = read();v = read(); U[i] = u;V[i] = v; } for(int i=1;i<n;i++){ int u,v; u = read();v = read(); in[u].push_back(v); in[v].push_back(u); } dfs1(1,1); dfs2(1,1); T.build(rt,1,n); for(int i=1;i<=n;i++){ T.upd(rt,1,n,dfn[i],dfn[i]+siz[i]-1,1); } for(int i=1;i<n;i++){ int u,v; u = U[i];v = V[i]; ed[u].push_back(v); ed[v].push_back(u); if(dep[u]>dep[v]){ swap(u,v); } int L = Lca(u,v); if(u==L){ T.upd(rt,1,n,dfn[v],dfn[v]+siz[v]-1,-1); } } for(int i=1;i<=n;i++){ sort(ed[i].begin(),ed[i].end(),cmp); sort(in[i].begin(),in[i].end(),cmp); } RT = 1; dfs(1); cout<<(res-n)/2+n<<"\n"; } signed main(){ int T = read(); while(T--){ init(); } return 0; } ```