题解:CF2062E2 The Game (Hard Version)

· · 题解

[CF2062E2] The Game (Hard Version)

再阅读这篇题解前,请保证你已经会 $ \rm E1 $ 的解法,否则可以移步至 [这里](https://www.luogu.com.cn/article/u1cq5304)。 假设你第一步删除了 $ x $,然后对手按照 $ \rm E1 $ 的做法找到了比 $ x $ 权值大的且先手必胜的点 $ y $,那你就炸了。因此我们需要保证:再删除 $ x $ 后,所有权值比 $ x $ 大的 $ y $,要么 $ y $ 在 $ x $ 的子树内,要么所有权值比 $ y $ 大的且不在 $ y $ 子树内的点 $ z $ 都在 $ x $ 的子树内。 那么当扫到 $ y $ 时,只需要找到所有 $ z $ 的 $ \rm LCA $,设为点 $ a $。那么 $ x $ 要么是 $ y $ 的祖先,要么是 $ a $ 的祖先。可以再用一个 $ \rm BIT $ 来判断这个条件。 关于如何找到 $ a $,只需要用一个 $ \rm set $ 维护所有权值比 $ y $ 大的点的 $ \rm dfs $ 序,然后找到不在 $ y $ 的子树内的 $ \rm dfs $ 序最小的点 $ l $ 和最大的点 $ r $,则 $ a $ 就是 $ l $ 和 $ r $ 的 $ \rm LCA $。 最后复杂度是 $ O(n \log n) $。 --- ```cpp #include<bits/stdc++.h> #define ll long long #define pn putchar('\n') #define mset(a,x) memset(a,x,sizeof a) #define mcpy(a,b) memcpy(a,b,sizeof a) #define all(a) a.begin(),a.end() #define fls() fflush(stdout) using namespace std; int re() { int x=0; bool t=1; char ch=getchar(); while(ch>'9'||ch<'0') t=ch=='-'?0:t,ch=getchar(); while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return t?x:-x; } const int maxn=1e6+5; int n,m; int dep[maxn]; int id[maxn],rid[maxn]; vector<int>a[maxn]; vector<int>e[maxn]; vector<int>ans; set<int>b; struct BIT { int tree[maxn]; void add(int x,int ad) { for(int i=x;i<=m;i+=i&-i) tree[i]+=ad; } int qry(int x) { int ret=0; for(int i=x;i;i-=i&-i) ret+=tree[i]; return ret; } int qry(int l,int r) { return qry(r)-qry(l-1); } }bit1,bit2; int st[20][maxn]; int min2(int u,int v) { return dep[u]<dep[v]?u:v; } int lca(int l,int r) { if(!l||!r) return st[0][l|r]; if(l>r) swap(l,r); int k=__lg(r-l+1); return min2(st[k][l],st[k][r-(1<<k)+1]); } void dfs(int pos,int fa) { m++; id[pos]=m; st[0][m]=pos; for(int v:e[pos]) { if(v==fa) continue; dep[v]=dep[pos]+1; dfs(v,pos); m++; st[0][m]=pos; } rid[pos]=m; } void solve() { n=re(); for(int i=1;i<=n;i++) { a[i].clear(); e[i].clear(); } for(int i=1;i<=n;i++) a[re()].push_back(i); for(int i=1;i<n;i++) { int u=re(),v=re(); e[u].push_back(v); e[v].push_back(u); } m=0; dfs(1,0); for(int j=1;j<=19;j++) { for(int i=1;i+(1<<j)-1<=m;i++) st[j][i]=min2(st[j-1][i],st[j-1][i+(1<<j-1)]); } for(int i=1;i<=m;i++) bit1.tree[i]=bit2.tree[i]=0; int cnt1=0,cnt2=0; b.clear(); ans.clear(); for(int k=n;k;k--) { for(int i:a[k]) { if(bit1.qry(id[i],rid[i])<cnt1&&bit2.qry(id[i],rid[i])==cnt2) ans.push_back(i); } for(int i:a[k]) { if(bit1.qry(id[i],rid[i])<cnt1) { auto l=b.lower_bound(id[i]),r=b.lower_bound(rid[i]); int t1=0,t2=0; if(l!=b.begin()) { l--; t1=lca(*b.begin(),*l); } if(r!=b.end()) t2=lca(*r,*b.rbegin()); int t=lca(id[t1],id[t2]); bit2.add(id[t],1); bit2.add(id[i],1); bit2.add(id[lca(id[t],id[i])],-1); cnt2++; } } for(int i:a[k]) { cnt1++; b.insert(id[i]); bit1.add(id[i],1); } } sort(all(ans)); printf("%d ",ans.size()); for(int i:ans) printf("%d ",i); pn; } signed main() { int T=re(); while(T--) solve(); return 0; } ```