题解 P5305 【[GXOI/GZOI2019]旧词】

· · 题解

这道题目非常简洁:要求 \sum_{i\le x}depth(lca(i,y))^k (都写在题面里了)

要直接解决这个问题是有点困难的,那么——

我们先看它的弱化版:[LNOI2014]LCA

要求 \sum_{i=l}^rdepth(lca(i,z))

少了个k次方呢!

首先转化一下 \sum_{i=l}^rdepth(lca(i,z))=\sum_{i=1}^rdepth(lca(i,z))-\sum_{i=1}^{l-1}depth(lca(i,z))

可以用前缀和来解决这个问题。

把每个询问拆成 (1,l-1)(1,r) 分别解决。

那么我们可以按顺序把 1n 的点加进来,同时计算加到每个点时的询问答案。

那么现在考虑加到第 i 个点时如何计算贡献。

先上张图举个例子~

其中涂成红色的点就是已经加入的点。

我们先计算一下每个节点的子树中有多少已经加入的点,用 siz[x] 表示。没有加入的节点对当前询问肯定没有贡献。所以我们直接统计加入的节点就好了。

那么 siz[1]=4 , siz[2]=2, siz[3]=1, siz[4]=1 , siz[5]=1 , siz[6]=0

现在考虑计算当前加入的节点到4的LCA深度和。

ans=siz[4]\times dep[4]+(siz[2]-siz[4])\times dep[2]+(siz[1]-siz[2])\times dep[1]

即加入的节点中在 4 的子树中的到 4LCA 肯定为 4。在 2 的子树但不在 4 的子树中的到 4LCA 肯定为 2,在 1 的子树但不在 2 的子树中的到 4LCA 肯定为 1

把式子拆开化简得到:

ans=siz[4]\times (dep[4]-dep[2])+siz[2]\times (dep[2]-dep[1])+siz[1] ans=siz[4]+siz[2]+siz[1]

于是我们发现其实答案就是该节点到根路径的节点子树大小之和。

普遍形式:设查询节点到根的路径为 \{v_1,v_2,v_3\dots v_k\}

ans=siz[v_1]\times dep[v_1]+(siz[v_2]-siz[v_1])\times dep[v_2]+(siz[v_3]-siz[v_2])\times dep[v_3]+\dots +(siz[v_k]-siz[v_{k-1}])\times dep[v_k]

由于是到根节点的路径,所以 dep[v_i]=dep[v_{i+1}]+1

于是化简得:

ans=siz[v_1]\times (dep[v_1]-dep[v_2])+siz[v_2]\times (dep[v_2]-dep[v_3])+siz[v_3]\times (dep[v_3]-dep[v_4])+\dots +siz[v_k]\times dep[v_k] ans=siz[v_1]+siz[v_2]+siz[v_3]+\dots +siz[v_k]

就此得出结论:答案就是该节点到根路径的节点子树大小之和。

查询直接查询要查询的节点到根路径上的节点的子树大小之和。

加入节点时只有它到根的路径上的节点的子树大小加了一。

使用树链剖分加线段树维护即可。

回到本题

本题没有 l 的限制,所以我们不用前缀和拆询问处理了。

那么我们抬出之前的式子改成这道题的样子。

ans=siz[v1]\times dep[v1]^k+(siz[v2]-siz[v1])\times dep[v2]^k+(siz[v3]-siz[v2])\times dep[v3]^k+\dots+(siz[vk]-siz[vk-1])\times dep[vk]^k

化简一下,得到:

ans=siz[v_1]\times (dep[v_1]^k-dep[v_2]^k)+siz[v_2]\times (dep[v_2]^k-dep[v_3]^k)+siz[v_3]\times (dep[v_3]^k-dep[v4]^k)+\dots siz[v_k]\times dep[v_k]^k

可以发现,dep[v_2] 总等于 dep[v_1]-1(因为是从查询节点到根节点的路径嘛)。

所以对于(dep[v_i]^k-dep[v_{i+1}]^k)这种东西,我们可以对它进行预处理。

设对于点i的预处理答案为 val[i]

则答案为:

ans=siz[v_1]\times val[v_1]+siz[v_2]\times val[v_2]+siz[v_3]\times val[v_3]+...siz[v_k]\times val[v_k]

就是在线段树维护时多加一个权值的问题了。

在线段树的每个节点上附加一个权值,表示这个区间里所有点的 val 之和。

这样就可以区间修改和下传标记了。

跟上面那道弱化版差得不多吧~

代码:

//直接用 LCA 代码改的,可能有点迷惑(? 
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
vector<long long> e[50010];
struct node
{
    long long bas,sum,tag;//bas为附加的权值
}nodes[400010];
struct ask
{
    long long pos,val,ID,nag;
}q[100010];
bool cmp(ask one,ask ano)
{
    return one.pos<ano.pos;
}
long long n,m,s,opl,opr,opz,f,cnt,tot,waste,ans[50010],kkk,dep[50010],fa[50010],son[50010],siz[50010],hb[50010],dfn[50010],ton[50010],power[50010],bruh;
const long long mod=998244353;
long long qpow(long long bas,long long tim)//快速幂用来处理val 
{
    long long res=1,fur=bas;
    while(tim)
    {
        if(tim&1)   res=(res*fur)%mod;
        fur=(fur*fur)%mod;
        tim>>=1;
    }
    return res;
}
void dfs(long long x,long long las)
{
    fa[x]=las;
    dep[x]=dep[las]+1;
    siz[x]=1;
    long long b=0,s=0;
    for(long long i=0;i<e[x].size();++i)
    {
        long long y=e[x][i];
        dfs(y,x);
        siz[x]+=siz[y];
        if(siz[y]>b)
        {
            b=siz[y];
            s=y;
        }
    }
    son[x]=s;
}
void dfs2(long long x,long long las,long long heavy)//树剖dfs
{
    if(heavy)   hb[x]=hb[las];
    else    hb[x]=x;
    dfn[x]=++cnt;
    ton[cnt]=x;
    if(son[x])  dfs2(son[x],x,1);
    for(long long i=0;i<e[x].size();++i)
    {
        long long y=e[x][i];
        if(y^son[x])    dfs2(y,x,0);
    }
}
void pushdown(long long x)
{
    if(nodes[x].tag)
    {
        //(siz[x]+a)*val[x]+(siz[y]+a)*val[y]+...+(siz[z]+a)*val[z]
        //=siz[x]*val[x]+siz[y]*val[y]+...+siz[z]*val[z]+a*(val[x]+val[y]+...+val[z])
        nodes[x].sum+=(nodes[x].bas*nodes[x].tag);
        nodes[x].sum%=mod;
        nodes[x<<1].tag+=nodes[x].tag;
        nodes[x<<1].tag%=mod;
        nodes[x<<1|1].tag+=nodes[x].tag;
        nodes[x<<1|1].tag%=mod;
        nodes[x].tag=0;
    }
}
void build(long long l,long long r,long long x)//预处理val区间和 
{
    if(l^r)
    {
        long long mid=(l+r)>>1;
        build(l,mid,x<<1);
        build(mid+1,r,x<<1|1);
        nodes[x].bas=(nodes[x<<1].bas+nodes[x<<1|1].bas)%mod;
    }
    else    nodes[x].bas=power[dep[ton[l]]]-power[dep[ton[l]]-1];
//  printf(" %lld %lld %lld\n",l,r,nodes[x].bas);
} 
void update(long long l,long long r,long long x,long long fr,long long ba)
{
    if(l>ba||r<fr)  return;
    if(l>=fr&&r<=ba)    nodes[x].tag=(nodes[x].tag+1)%mod;
    else
    {
        pushdown(x);
        long long mid=(l+r)>>1;
        update(l,mid,x<<1,fr,ba);
        update(mid+1,r,x<<1|1,fr,ba);
        pushdown(x<<1);
        pushdown(x<<1|1);
        nodes[x].sum=(nodes[x<<1].sum+nodes[x<<1|1].sum)%mod;
    }
}
long long find(long long l,long long r,long long x,long long fr,long long ba)
{
    if(l>ba||r<fr)  return 0;
    pushdown(x);
    if(l>=fr&&r<=ba)    return nodes[x].sum;
    else
    {
        long long mid=(l+r)>>1;
        return (find(l,mid,x<<1,fr,ba)+find(mid+1,r,x<<1|1,fr,ba))%mod;
    }
}
void output(long long l,long long r,long long x)
{
    pushdown(x);
    printf(" %lld %lld %lld\n",l,r,nodes[x].sum);
    if(l^r)
    {
        long long mid=(l+r)>>1;
        output(l,mid,x<<1);
        output(mid+1,r,x<<1|1);
    }
}
void update_LCA(long long x,long long y)
{
    long long fx=hb[x],fy=hb[y];
    while(fx^fy)
    {
        if(dep[fx]<dep[fy])
        {
            swap(fx,fy);
            swap(x,y);
        }
        update(1,s,1,dfn[fx],dfn[x]);
        x=fa[fx];
        fx=hb[x];
    }
    update(1,s,1,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]));
}
long long find_LCA(long long x,long long y)
{
    long long res=0;
    long long fx=hb[x],fy=hb[y];
    while(fx^fy)
    {
        if(dep[fx]<dep[fy])
        {
            swap(fx,fy);
            swap(x,y);
        }
        res+=find(1,s,1,dfn[fx],dfn[x]);
        res%=mod;
        x=fa[fx];
        fx=hb[x];
    }
    res+=find(1,s,1,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]));
    res%=mod;
    return res;
}
int main()
{
    scanf("%lld %lld %lld",&n,&m,&bruh);
    s=n;
    for(long long i=1;i<=n;++i) power[i]=qpow(i,bruh);
    for(long long i=2;i<=n;++i)
    {
        scanf("%lld",&f);
        e[f].push_back(i);
    }
    dfs(1,1);
    dfs2(1,1,0);
    build(1,s,1);
    for(long long i=1;i<=m;++i)
    {
        scanf("%lld %lld",&opr,&opz);
        q[++tot].ID=i;
        q[tot].nag=1;
        q[tot].pos=opr;
        q[tot].val=opz;
    }
    sort(q+1,q+1+tot,cmp);
    while(q[kkk].pos<=0&&kkk<=tot)
    {
        ans[q[kkk].ID]+=(q[kkk].nag*0);
        ++kkk;
    }
    for(long long i=1;i<=n;++i)
    {
        update_LCA(1,i);
        while(q[kkk].pos<=i&&kkk<=tot)
        {
            ans[q[kkk].ID]+=(q[kkk].nag*find_LCA(1,q[kkk].val));
            ans[q[kkk].ID]%=mod;
            ++kkk;
        }
//      puts("");
//      output(1,s,1);
//      puts("\n");
    }
    for(long long i=1;i<=m;++i) printf("%lld\n",((ans[i]%mod)+mod)%mod);
    return 0;
}