P7768 丑国传说 · 收税(En:Tax Collection) 题解

· · 题解

初学主席树的人(例如我这个蒟蒻)可以来刷一下这个题,使用主席树的话还是考察了一定思维的,虽然加了快读后最后一个点 993ms qwq。

Problem

题目大意:给定一棵树,每个节点有一个权值,多次询问,每次询问一个节点 x 的子树中深度小于等于 h 的权值异或和。

数据范围:n,q \leq 10^6,a_i \leq 10^9

Solution

首先我们可以想到,因为要处理这种子树中的问题,多半是要用到 dfs 序把子树中的问题转换成区间中的问题。

同时我们可以发现,我们需要统计的是深度,所以我们可以开一个可持久化权值线段树(主席树)来维护同一个深度上的点的异或值。然后我们考虑对于每一个 dfs 序开一棵主席树,可以发现如果 x 的子树的 dfs 序是 [dfn_x,ed_x] 的话,由 a\ xor\ a=0 可以得到,统计 x 子树中所有深度小于等于 h (以 x 为根,所以也就等价于深度在 [dep_x,dep_x+h] 之间)的异或和,也就是统计 [1,dfn_x-1] 之间满足条件的异或和,再异或上 [1,ed_x] 之间满足条件的异或和,有点类似于前缀和的原理,只不过是异或。

所以做法就很显然了,先处理出每一个点的 dfn_i,ed_i,dep_i,然后对于每一个 dfs 序开一棵主席树(注意这里是 dfs 序,而不是原本的序号)用来维护同一深度上的所有权值的异或值。查询的时候查找第 ed_x,dfn_x-1 两个版本在区间 dep_x,dep_x+h 之间的异或值,然后异或一下就可以得到答案了。

Code

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+7232;
//真正的飞速快读,但让并不绝对,一般都挺快的,不过这个题没有普通快读快
const int Mt=1e5;
inline char getc()
{
    static char buf[Mt],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,Mt,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    int r=0,f=1;char c=getc();
    while(!isdigit(c)){if(c=='-')f=-1;c=getc();}
    while(isdigit(c))r=(r<<1)+(r<<3)+(c^48),c=getc();
    return r*f;
}
int n,m,x,y,maax;
int a[N];
//连边
int h[N],cnt;
struct hl{
    int v,nxt;
}e[N];
void add(int u,int v)
{
    e[++cnt].v=v,e[cnt].nxt=h[u],h[u]=cnt;
}
//dfs部分,预处理出dfn,ed,dep,re(re[i]代表dfs序为i的数在原来序列中的编号,下面会用到)
int num,dfn[N],ed[N],dep[N],re[N];
int mx(int x,int y){return x>y?x:y;}
int mi(int x,int y){return x<y?x:y;}
void dfs(int x,int fx)
{
    dfn[x]=++num;dep[x]=dep[fx]+1;re[num]=x;
    maax=mx(maax,dep[x]);
    for(int i=h[x];i;i=e[i].nxt) dfs(e[i].v,x);
    ed[x]=num;
}
//主席树
int root[N],tot;
struct len{
    int l,r,sum;
}t[N<<4];
int insert(int pre,int l,int r,int x,int y)
{
    int p=++tot;
    t[p].l=t[pre].l,t[p].r=t[pre].r,t[p].sum=(t[pre].sum^y);
    if(l<r)
    {
        int mid=l+r>>1;
        if(x<=mid) t[p].l=insert(t[pre].l,l,mid,x,y);
        else t[p].r=insert(t[pre].r,mid+1,r,x,y);
    }
    return p;
}
int ask(int u,int v,int l,int r,int x,int y)
{
    if(l==x&&r==y) return t[u].sum^t[v].sum;
    int mid=l+r>>1;
    if(x>mid) return ask(t[u].r,t[v].r,mid+1,r,x,y);//当整个区间都在右边时
    else if(y<=mid) return ask(t[u].l,t[v].l,l,mid,x,y);//当整个区间都在左边时
    else return ask(t[u].l,t[v].l,l,mid,x,mid)^ask(t[u].r,t[v].r,mid+1,r,mid+1,y);//区间两边都有分布
}
int main()
{
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=2;i<=n;i++) x=read(),add(x,i);
    dfs(1,0);
    for(int i=1;i<=n;i++) root[i]=insert(root[i-1],1,maax,dep[re[i]],a[re[i]]);//每一次是根据dfs序建树,但是我们需要用到原序列中的dep,a值,所以要用到re数组
    for(int i=1;i<=m;i++)
    {
        x=read(),y=read();
        printf("%.3lf\n",0.001*ask(root[ed[x]],root[dfn[x]-1],1,maax,dep[x],mi(dep[x]+y,maax)));
        //                   两个版本异或一下,在ask函数中完成了    总区间    要查询的区间范围
    }
}