题解【P6556 The Forest】
yizhiming
·
·
题解
呃呃了,在以为其他题解做麻烦的前提下写了写发现假了,结果优化成了和其他人一样的做法。
题目大意
给定 n 个点的两棵树 A,B,求有多少个点集满足将点集内的点按照树上的边连边后,在 A 树上形成一个联通块,在 B 树上形成一条链。
$T = 3,1\leq n\leq10^5$。
### 题目分析
先考虑一个性质,对于一个树上的点集 $T$,若其内部边的数量为 $x$,那么这个点集的联通块数是 $|T|-x$,证明考虑一开始每个点都单独一个联通块,每次连一条边就是把两个联通块合并成一个。
有了这个性质如何做呢?这启发了我们维护联通块数。
首先考虑特殊性质,对于 $B$ 树是链,等价于要求点集是个区间,所以考虑扫描线,设 $s_i$ 表示在当前扫描线右端点在 $r$,左端点在 $i$ 时,这个区间点集在 $A$ 树上有几个联通块,答案显然是区间内 $1$ 的个数,由于区间最小值一定最小为 $1$,所以可以直接维护区间最小值个数。
如何转移,考虑由 $r$ 推到 $r+1$,此时对于 $[1,r+1]$ 来说都新加入了一个点,所以区间加 $1$,然后对于 $A$ 树上的一条边 $(u,r+1)$ 满足 $u<r+1$ 来说,$[1,u]$ 的 $s_i$ 对应的点集内一定有这条边,所以区间减 $1$ 即可,答案就是所有版本的 $1$ 的个数和。
考虑扩展到树上,如何将区间转换成链,不难想到令每个点作为根,求出每个点到根路径形成的点集在 $A$ 树上的联通块个数,不妨设 $f_i$ 表示这个,答案会算多,原因是对于一条合法的链 $(u,v)$ 在 $u,v$ 为根时都会计算一遍,所以要去掉,注意 $(u,u)$ 不会算重。
接下来的内容默认会换根意义下的区间加减,若不会请去[遥远的国度](https://www.luogu.com.cn/problem/P3979)。
假设当前根为 $u$,要换到他的儿子 $v$,如何转移 $s_i$,令 $W(x,y)$ 表示以 $x$ 为根时,$y$ 的子树表达的点集。
首先由于 $v$ 提到了根的位置,所以除了 $W(u,v)$ 以外的所有点,所对应的点集都插入了一个点,区间加,同理 $W(u,v)$ 整体少了一个点。
现在考虑新的边的贡献,$v$ 对于 $W(u,v)$ 的贡献在 $u$ 为根的时候已经统计过了,所以对于 $A$ 树边 $(v,x)$,若 $x \notin W(u,v)$ 那么就对 $W(v,x)$ 进行一次子树减,因为这部分都会被这条边影响。同理我们也要删除 $u$ 的在 $W(u,v)$ 内的 $A$ 树上邻居的贡献,但是发现每次换根都枚举一圈 $A$ 树的邻居,总的枚举个数就成了两树度数的平方。
注意到对于 $u$ 需要删掉的贡献只有在 $W(u,v)$ 内的,容易发现对于 $u$ 每个儿子,其子树区间不相交,所以我们可以将 $A$ 边按照 $B$ 树的 dfs 序排序,这样的话每个贡献只会增减各一次。
### Code
注意要做到从 $v$ 版本回溯到 $u$,所以记录下来操作反着做一遍即可。
对于最开始的 $1$ 号版本,可以暴力预处理出来初始情况。
```cpp
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <queue>
#define int long long
using namespace std;
int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int N = 1e5+5;
int n;
vector<int>in[N],ed[N];
int minx,cnt,rt;
struct seg{
struct aa{
int lc,rc,mi,sum,tag;
void clear(){
lc = rc = mi = sum = tag = 0;
}
}node[N*2];
void pushup(int u){
aa x = node[node[u].lc],y = node[node[u].rc];
node[u].mi = min(x.mi,y.mi);
node[u].sum = (x.mi==node[u].mi?x.sum:0)+(y.mi==node[u].mi?y.sum:0);
}
int tot;
int newnode(){
int u = ++tot;
node[u].clear();
return u;
}
void build(int &u,int l,int r){
u = newnode();
node[u].sum = (r-l+1);
if(l==r){
return;
}
int mid = (l+r)/2;
build(node[u].lc,l,mid);
build(node[u].rc,mid+1,r);
}
void lazy_tag(int u,int x){
node[u].mi+=x;
node[u].tag+=x;
}
void pushdown(int u){
if(!node[u].tag){
return;
}
lazy_tag(node[u].lc,node[u].tag);
lazy_tag(node[u].rc,node[u].tag);
node[u].tag = 0;
}
void upd(int u,int l,int r,int ll,int rr,int x){
if(l==ll&&r==rr){
lazy_tag(u,x);
return;
}
pushdown(u);
int mid = (l+r)/2;
if(rr<=mid){
upd(node[u].lc,l,mid,ll,rr,x);
}else if(ll>mid){
upd(node[u].rc,mid+1,r,ll,rr,x);
}else{
upd(node[u].lc,l,mid,ll,mid,x);
upd(node[u].rc,mid+1,r,mid+1,rr,x);
}
pushup(u);
}
void ask(int u,int l,int r,int ll,int rr){
if(l==ll&&r==rr){
if(node[u].mi<minx){
minx = node[u].mi;
cnt = node[u].sum;
}else if(node[u].mi==minx){
cnt+=node[u].sum;
}
return;
}
pushdown(u);
int mid = (l+r)/2;
if(rr<=mid){
ask(node[u].lc,l,mid,ll,rr);
}else if(ll>mid){
ask(node[u].rc,mid+1,r,ll,rr);
}else{
ask(node[u].lc,l,mid,ll,mid);
ask(node[u].rc,mid+1,r,mid+1,rr);
}
}
}T;
int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],tt;
bool cmp(int a,int b){
return dfn[a]<dfn[b];
}
void dfs1(int u,int f){
siz[u] = 1;
son[u] = 0;
for(auto x:in[u]){
if(x==f){
continue;
}
fa[x] = u;
dep[x] = dep[u]+1;
dfs1(x,u);
siz[u]+=siz[x];
if(siz[x]>siz[son[u]]){
son[u] = x;
}
}
}
void dfs2(int u,int t){
top[u] = t;
dfn[u] = ++tt;
if(!son[u]){
return;
}
dfs2(son[u],t);
for(auto x:in[u]){
if(x==fa[u]||x==son[u]){
continue;
}
dfs2(x,x);
}
}
int Lca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]){
swap(u,v);
}
u = fa[top[u]];
}
if(dep[u]<dep[v]){
swap(u,v);
}
return v;
}
int query(){
minx = 1e9;
cnt = 0;
T.ask(rt,1,n,1,n);
if(minx==1){
return cnt;
}else{
return 0;
}
}
struct bb{
int l,r,x;
};
vector<bb>op[N];
void add(int u,int l,int r,int x){
op[u].push_back((bb){l,r,x});
T.upd(rt,1,n,l,r,x);
}
int RT,res;
int get(int u,int x){
while(top[u]!=top[x]){
if(fa[top[u]]==x){
return top[u];
}
u = fa[top[u]];
}
return son[x];
}
void dfs(int u){
if(u!=1){
add(u,1,n,1);
add(u,dfn[u],dfn[u]+siz[u]-1,-2);
for(auto x:ed[u]){
if(dfn[u]<=dfn[x]&&dfn[x]<=dfn[u]+siz[u]-1){
continue;
}
if(dfn[x]<=dfn[u]&&dfn[u]<=dfn[x]+siz[x]-1){
int v = get(u,x);
add(u,1,n,-1);
add(u,dfn[v],dfn[v]+siz[v]-1,1);
}else{
add(u,dfn[x],dfn[x]+siz[x]-1,-1);
}
}
}
res+=query();
int r = 0;
int sz = ed[u].size();
while(r<sz){
int y = ed[u][r];
if(dfn[y]<dfn[u]||dfn[y]>dfn[u]+siz[u]-1){
r++;
}else{
break;
}
}
for(auto x:in[u]){
if(x==fa[u]){
continue;
}
int R = r;
while(r<sz){
int y = ed[u][r];
if(dfn[x]<=dfn[y]&&dfn[y]<=dfn[x]+siz[x]-1){
T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,1);
r++;
}else{
break;
}
}
dfs(x);
for(int i=R;i<r;i++){
int y = ed[u][i];
T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,-1);
}
}
for(auto x:op[u]){
T.upd(rt,1,n,x.l,x.r,-x.x);
}
}
int U[N],V[N];
void init(){
n = read();
T.tot = 0;rt = 0;res = 0;tt = 0;
for(int i=1;i<=n;i++){
in[i].clear();ed[i].clear();op[i].clear();
}
for(int i=1;i<n;i++){
int u,v;
u = read();v = read();
U[i] = u;V[i] = v;
}
for(int i=1;i<n;i++){
int u,v;
u = read();v = read();
in[u].push_back(v);
in[v].push_back(u);
}
dfs1(1,1);
dfs2(1,1);
T.build(rt,1,n);
for(int i=1;i<=n;i++){
T.upd(rt,1,n,dfn[i],dfn[i]+siz[i]-1,1);
}
for(int i=1;i<n;i++){
int u,v;
u = U[i];v = V[i];
ed[u].push_back(v);
ed[v].push_back(u);
if(dep[u]>dep[v]){
swap(u,v);
}
int L = Lca(u,v);
if(u==L){
T.upd(rt,1,n,dfn[v],dfn[v]+siz[v]-1,-1);
}
}
for(int i=1;i<=n;i++){
sort(ed[i].begin(),ed[i].end(),cmp);
sort(in[i].begin(),in[i].end(),cmp);
}
RT = 1;
dfs(1);
cout<<(res-n)/2+n<<"\n";
}
signed main(){
int T = read();
while(T--){
init();
}
return 0;
}
```