题解 P6729 【[WC2020] 有根树】

skydogli

2020-08-07 17:00:01

Solution

### update:感谢EI,证明出了更精确的上界。 机房神仙写了个做法几乎考场切掉了此题,我肝4.5h的分块只有70,自闭了/kk 这是一个时间复杂度上界 $O(n\sqrt {n} \log n)$,下界 $O(n\log^2 n)$ 的做法,但是可以通过此题。 其实现十分简单,使用树剖维护**dfn区间**上的最大最小值以及被标记点的数量,每次查询时只需判断当前解是否合法或当前解减一是否合法即可。而判断当前解 $res$ 合法事实上就是 $$\sum_{x\in S,\operatorname{siz_x}> res} \le res$$ 这本质是统计问题,而该做法统计的方式相当暴力:遍历所有线段树区间,如果当前区间无须往下递归则直接结束,无须往下递归的条件有(满足其一即可): 1. 当前统计到的数量大于所求 2. 当前区间没有标记点 3. 当前区间最小值大于要求 4. 当前区间最大值小于要求 不难发现,卡满其复杂度的难度极大且常数极小,而该复杂度的粗略上界 $O(n\sqrt{n\log n}\log n)$ 的证明如下: 首先,当我们找到一条链时,它是单调的,所以爆搜的复杂度是$O(\log n)$,因此我们考虑找到的链的数量,下面是一个粗略的证明 - 当 $res\le \sqrt{n\log n}$ 时,最多只会被找到$\sqrt{n\log n}$个区间就能达到条件1,就可以退出了。 - 当 $res\ge \sqrt{n\log n}$ 时,我们会查询若干个重链,现在就要证明我们查到的重链数量不超过$\sqrt {n\log n}$,考虑只有链顶的子树大小大于 $\sqrt{n\log n}$ 才需要统计次数(否则我们就退出了),而一个节点上方最多有 $\log$ 条重链,也就是所有链顶的子树大小之和的上界是 $O(n\log n)$ ,因此符合条件的链不超过 $\sqrt{n\log n}$ 该证明并不精确,因为我们直接认为所有链顶的siz大小相等,但我们证明所有链顶子树和$n\log n$是依据```上方log条重链```的,每往上条一次轻边siz是乘二的,显然不能满足链顶siz都相等的条件。 于是上U群求救,多项式&DS之神EI做出了 > 重链剖分中,子树大小大于$B$的链顶数量不超过$\frac{n}{B}$个 的证明: > EI原话: 就是我们把原来sz >B的结点保留,这个“macro tree”我们证明原树树链剖分的链顶在这颗树中被保留的数量就是结果这颗树的叶子数? >每个叶子显然需要在一个链中,刚刚展示了每个非叶子结点在重链中延伸的下一个结点必然是macro tree的某个孩子 蒟蒻的理解:意思是我们只留下子树siz大于$B$的点,这棵树上有多少链顶就是所求。 考虑重链剖分的性质,在新树中,每个链顶要么没有儿子,要么重儿子还在这棵树上,也就是肯定会跑到叶子节点,不可能在中途停下。而一个叶子要么是链顶要么是重儿子,所以新树中链顶数量等于叶子数量。 同时,叶子和叶子之间不可能共享子孙,也就是说新树的叶子数量不超过$\frac{n}{B}$ 因此,对于这题,可以仿照上面的证明得到复杂度上界为$O(n\sqrt{n}\log n)$ 关于实际运行效率,它跑得非常快(大概也因为出题人没卡),可以轻松获得95分的好成绩,实现不丑可以在loj获得满分,使用各位高超的卡常技巧可以在洛谷通过。我常数比较大,使用了标记永久化的办法卡常。 ```cpp #include<bits/stdc++.h> using namespace std; int read(){ int a=0;char c=getchar(); while(c>'9'||c<'0')c=getchar(); while('0'<=c&&c<='9'){ a=a*10+c-48; c=getchar(); } return a; } #define MN 500005 int n,q,deg[MN]; int siz[MN],fa[MN],dep[MN],id[MN],top[MN],w[MN],cnt,ID[MN]; int TOT; vector<int>edge[MN]; void dfs1(int x){ siz[x]=1; for(int i=0;i<edge[x].size();++i){ int v=edge[x][i]; if(fa[x]==v)continue; dep[v]=dep[x]+1; fa[v]=x; dfs1(v); siz[x]+=siz[v]; if(siz[w[x]]<siz[v])w[x]=v; } } void dfs2(int x){ id[x]=++cnt;ID[cnt]=x; if(w[x]){top[w[x]]=top[x];dfs2(w[x]);} for(int i=0;i<edge[x].size();++i){ int v=edge[x][i]; if(fa[x]==v||w[x]==v)continue; top[v]=v; dfs2(v); } } #define Ls (x<<1) #define Rs (x<<1|1) #define mid ((l+r)>>1) struct data{ int sum,Max,Min,tag; }W[MN<<2]; void ptag(int x,const int &v){ W[x].tag+=v; W[x].Max+=v; W[x].Min+=v; } void pd(int x){ if(W[x].tag){ if(W[Ls].sum){ ptag(Ls,W[x].tag); } if(W[Rs].sum){ ptag(Rs,W[x].tag); } W[x].tag=0; } } void pushup(int x){ W[x].Max=max(W[Ls].Max,W[Rs].Max)+W[x].tag; W[x].Min=min(W[Ls].Min,W[Rs].Min)+W[x].tag; } int ask(int x,int l,int r,int b,int e){ if(b<=l&&r<=e)return W[x].sum; if(l>e||r<b||!W[x].sum)return 0; pd(x); int res=0; if(b<=mid)res=ask(Ls,l,mid,b,e); if(e>mid)res+=ask(Rs,mid+1,r,b,e); return res; } int ned; void Ins(int x,int l,int r,int loc,int v){ W[x].sum+=v; if(l==r){ if(v==1){ W[x].Max=W[x].Min=ask(1,1,n,l,id[ID[l]]+siz[ID[l]]-1); } else{ W[x].Max=-1e9; W[x].Min=1e9; } return; } pd(x); if(loc<=mid)Ins(Ls,l,mid,loc,v); else Ins(Rs,mid+1,r,loc,v); pushup(x); } void change(int x,int l,int r,int b,int e,int v){ if(!W[x].sum)return; if(b<=l&&r<=e){ ptag(x,v); return; } if(b<=mid&&W[Ls].sum)change(Ls,l,mid,b,e,v); if(e>mid&&W[Rs].sum)change(Rs,mid+1,r,b,e,v); pushup(x); } void search(int x,int l,int r,int v){ if(!W[x].sum||TOT>=ned||W[x].Max<v)return; if(W[x].Min>=v){TOT+=W[x].sum;return;} search(Ls,l,mid,v-W[x].tag); search(Rs,mid+1,r,v-W[x].tag); } int main(){ n=read(); for(int i=1;i<=n*4;++i)W[i].Max=-1e9; for(int i=1;i<=n*4;++i)W[i].Min=1e9; bool flag=1; for(int i=1;i<n;++i){ int x=read(),y=read(); deg[x]++;deg[y]++; edge[x].push_back(y); edge[y].push_back(x); } q=read(); dep[1]=1;dfs1(1); top[1]=1;dfs2(1); int res=0,upd=0; for(int i=1;i<=q;++i){ int tp=read(),v=read(); if(tp==1){ Ins(1,1,n,id[v],1); v=fa[v]; while(v){ change(1,1,n,id[top[v]],id[v],1); v=fa[top[v]]; } TOT=0;ned=res+1; search(1,1,n,res+1); if(TOT>res) res++; } else{ Ins(1,1,n,id[v],-1); v=fa[v]; while(v){ change(1,1,n,id[top[v]],id[v],-1); v=fa[top[v]]; } TOT=0;ned=res; search(1,1,n,res); if(TOT<=res-1)res--; } printf("%d\n",res); } return 0; } ```