题解:P10107 [GDKOI2023 提高组] 树
sheepchat
·
·
题解
CSP 模拟赛(?)搬了,有同学过了,太强了。
很巧妙的题,这里介绍第一篇题解的做法。
距离是假的其实是深度之差,下面都这样认为。
考虑跑个 dfs 序,记 low_u 为子树 u 中的最大时间戳,dfn_u 为 u 的时间戳。
那么考虑这样一件事情,如果查询 (x,k),我们会将其理解为满足 dep_u\in [dep_x,dep_x+k],dfn_u\in[dfn_x,low_x) 的点的贡献。
不妨差分一下,记 pre_u 为和 u 同层的点中,它的 dfn 序上的前驱。
满足 dep_u\in [dep_x,dep_x+k],low_{pre_u}<dfn_u\le low_x 的点的贡献,依然是没问题的,这个可以分为两段前缀,下面我们考虑求 dep_u\in [dep_x,dep_x+k],dfn_u\le low_x 点的贡献,因为这样把 x 换成 pre_x 再求一遍,一减就是答案。
(这是个典 trick 吗?我想不通为啥题解区都觉得这个很显然,问同学,同学:“做多题就知道了。”)
那么我们就转变问题了,但是 dep_u\in [dep_x,dep_x+k] 依然不好算。而且深度这一维不好做相同的操作。
那假设 k 正好是 2^p-1 呢?考虑前 2^{p-1}-1,是不变的,而且后面 2^{p-1} 相当于给他异或上 2^{p-1} 这里记下这一块有多少个 p-1 位上为 1,那么也可以计算。
欸,我们发现上面过程也可以拓展到非 2^p-1 的形式,且我们考虑记下 f_{i,x},问题为 (x,2^i-1) 时候的答案。(不是原问题哦,而是转化完之后的)。
考虑如何去求,记 rs_u 为 u 最后遍历的儿子标号,没儿子就是 rs_{pr_u},那么考虑类似倍增 LCA 的求解方式,同理设 rs_{i,u} 为 u \gets rs_u 这个操作重复 2^i 后的节点。
这是可减的,所以说考虑记录 $cnt_u$ 表示满足 $dep_x\le dep_u,dfn_x\le low_u$ 的点数,减一下即可解决,对于每一位也可以用类似的方法。
求这个可以用类似前缀和的方法。
做完了,时空都为 $O(n\log n)$。
代码:
```cpp
#include<bits/stdc++.h>
using namespace std;
int n;
int dfn[1000005],tot,a[1000005];
int pr[1000005];
long long f[25][1000005];
int rs[25][1000005],dep[1000005];
int fa[1000005],low[1000005],cnt[1000005],res[35][1000005],sum[35][1000005],sum2[1000005];
vector<int>vec[1000005];
int last[1000005];
void dfs(int u){
low[u]=dfn[u]=++tot;
pr[u]=last[dep[u]];
last[dep[u]]=u;
sum2[u]=cnt[u]=sum2[pr[u]]+1;
for(int i=0;i<32;i++)sum[i][u]=((a[u]>>i)&1)+sum[i][pr[u]],res[i][u]=((a[u]>>i)&1)+sum[i][pr[u]];
int lastv=0;
for(int i=0;i<vec[u].size();i++){
int v=vec[u][i];
dep[v]=dep[u]+1;
lastv=v;
dfs(v);
low[u]=max(low[u],low[v]);
}
if(lastv==0)rs[0][u]=rs[0][pr[u]];
else rs[0][u]=lastv;
for(int i=0;i<32;i++)res[i][u]+=res[i][rs[0][u]];
f[0][u]=f[0][pr[u]]+a[u];
cnt[u]+=cnt[rs[0][u]];
}
long long solve(int x,int k){
long long ans=0,pos=x;
k++;
for(int i=20;i>=0;i--)if(((k>>i)&1))pos=rs[i][pos];
for(int i=20;i>=0;i--){
if(!((k>>i)&1))continue;
k^=(1<<i);
ans+=f[i][x]+1ll*(1<<i)*(cnt[rs[i][x]]-cnt[pos]-2*(res[i][rs[i][x]]-res[i][pos]));
x=rs[i][x];
}
return ans;
}
int main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=2;i<=n;i++){
cin>>fa[i];
vec[fa[i]].push_back(i);
}
dep[1]=1;
dfs(1);
for(int i=1;i<=20;i++){
for(int j=1;j<=n;j++){
rs[i][j]=rs[i-1][rs[i-1][j]];
f[i][j]=f[i-1][j]+f[i-1][rs[i-1][j]]+1ll*(1<<(i-1))*(cnt[rs[i-1][j]]-cnt[rs[i][j]]-2*(res[i-1][rs[i-1][j]]-res[i-1][rs[i][j]]));
}
}
int q;
cin>>q;
while(q--){
int x,k;
cin>>x>>k;
cout<<solve(x,k)-solve(pr[x],k)<<'\n';
}
}
```