题解:P10107 [GDKOI2023 提高组] 树

· · 题解

CSP 模拟赛(?)搬了,有同学过了,太强了。

很巧妙的题,这里介绍第一篇题解的做法。

距离是假的其实是深度之差,下面都这样认为。

考虑跑个 dfs 序,记 low_u 为子树 u 中的最大时间戳,dfn_uu 的时间戳。

那么考虑这样一件事情,如果查询 (x,k),我们会将其理解为满足 dep_u\in [dep_x,dep_x+k],dfn_u\in[dfn_x,low_x) 的点的贡献。

不妨差分一下,记 pre_u 为和 u 同层的点中,它的 dfn 序上的前驱。

满足 dep_u\in [dep_x,dep_x+k],low_{pre_u}<dfn_u\le low_x 的点的贡献,依然是没问题的,这个可以分为两段前缀,下面我们考虑求 dep_u\in [dep_x,dep_x+k],dfn_u\le low_x 点的贡献,因为这样把 x 换成 pre_x 再求一遍,一减就是答案。

(这是个典 trick 吗?我想不通为啥题解区都觉得这个很显然,问同学,同学:“做多题就知道了。”)

那么我们就转变问题了,但是 dep_u\in [dep_x,dep_x+k] 依然不好算。而且深度这一维不好做相同的操作。

那假设 k 正好是 2^p-1 呢?考虑前 2^{p-1}-1,是不变的,而且后面 2^{p-1} 相当于给他异或上 2^{p-1} 这里记下这一块有多少个 p-1 位上为 1,那么也可以计算。

欸,我们发现上面过程也可以拓展到非 2^p-1 的形式,且我们考虑记下 f_{i,x},问题为 (x,2^i-1) 时候的答案。(不是原问题哦,而是转化完之后的)。

考虑如何去求,记 rs_uu 最后遍历的儿子标号,没儿子就是 rs_{pr_u},那么考虑类似倍增 LCA 的求解方式,同理设 rs_{i,u}u \gets rs_u 这个操作重复 2^i 后的节点。

这是可减的,所以说考虑记录 $cnt_u$ 表示满足 $dep_x\le dep_u,dfn_x\le low_u$ 的点数,减一下即可解决,对于每一位也可以用类似的方法。 求这个可以用类似前缀和的方法。 做完了,时空都为 $O(n\log n)$。 代码: ```cpp #include<bits/stdc++.h> using namespace std; int n; int dfn[1000005],tot,a[1000005]; int pr[1000005]; long long f[25][1000005]; int rs[25][1000005],dep[1000005]; int fa[1000005],low[1000005],cnt[1000005],res[35][1000005],sum[35][1000005],sum2[1000005]; vector<int>vec[1000005]; int last[1000005]; void dfs(int u){ low[u]=dfn[u]=++tot; pr[u]=last[dep[u]]; last[dep[u]]=u; sum2[u]=cnt[u]=sum2[pr[u]]+1; for(int i=0;i<32;i++)sum[i][u]=((a[u]>>i)&1)+sum[i][pr[u]],res[i][u]=((a[u]>>i)&1)+sum[i][pr[u]]; int lastv=0; for(int i=0;i<vec[u].size();i++){ int v=vec[u][i]; dep[v]=dep[u]+1; lastv=v; dfs(v); low[u]=max(low[u],low[v]); } if(lastv==0)rs[0][u]=rs[0][pr[u]]; else rs[0][u]=lastv; for(int i=0;i<32;i++)res[i][u]+=res[i][rs[0][u]]; f[0][u]=f[0][pr[u]]+a[u]; cnt[u]+=cnt[rs[0][u]]; } long long solve(int x,int k){ long long ans=0,pos=x; k++; for(int i=20;i>=0;i--)if(((k>>i)&1))pos=rs[i][pos]; for(int i=20;i>=0;i--){ if(!((k>>i)&1))continue; k^=(1<<i); ans+=f[i][x]+1ll*(1<<i)*(cnt[rs[i][x]]-cnt[pos]-2*(res[i][rs[i][x]]-res[i][pos])); x=rs[i][x]; } return ans; } int main(){ cin>>n; for(int i=1;i<=n;i++)cin>>a[i]; for(int i=2;i<=n;i++){ cin>>fa[i]; vec[fa[i]].push_back(i); } dep[1]=1; dfs(1); for(int i=1;i<=20;i++){ for(int j=1;j<=n;j++){ rs[i][j]=rs[i-1][rs[i-1][j]]; f[i][j]=f[i-1][j]+f[i-1][rs[i-1][j]]+1ll*(1<<(i-1))*(cnt[rs[i-1][j]]-cnt[rs[i][j]]-2*(res[i-1][rs[i-1][j]]-res[i-1][rs[i][j]])); } } int q; cin>>q; while(q--){ int x,k; cin>>x>>k; cout<<solve(x,k)-solve(pr[x],k)<<'\n'; } } ```