题解 P5305 【[GXOI/GZOI2019]旧词】
CyanSineSin · · 题解
这道题目非常简洁:要求 (都写在题面里了)
要直接解决这个问题是有点困难的,那么——
我们先看它的弱化版:[LNOI2014]LCA
要求
少了个k次方呢!
首先转化一下
可以用前缀和来解决这个问题。
把每个询问拆成
那么我们可以按顺序把
那么现在考虑加到第
先上张图举个例子~
其中涂成红色的点就是已经加入的点。
我们先计算一下每个节点的子树中有多少已经加入的点,用
那么
现在考虑计算当前加入的节点到4的LCA深度和。
即加入的节点中在
把式子拆开化简得到:
于是我们发现其实答案就是该节点到根路径的节点子树大小之和。
普遍形式:设查询节点到根的路径为
由于是到根节点的路径,所以
于是化简得:
就此得出结论:答案就是该节点到根路径的节点子树大小之和。
查询直接查询要查询的节点到根路径上的节点的子树大小之和。
加入节点时只有它到根的路径上的节点的子树大小加了一。
使用树链剖分加线段树维护即可。
回到本题
本题没有
那么我们抬出之前的式子改成这道题的样子。
化简一下,得到:
可以发现,
所以对于
设对于点i的预处理答案为
则答案为:
就是在线段树维护时多加一个权值的问题了。
在线段树的每个节点上附加一个权值,表示这个区间里所有点的
这样就可以区间修改和下传标记了。
跟上面那道弱化版差得不多吧~
代码:
//直接用 LCA 代码改的,可能有点迷惑(?
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
vector<long long> e[50010];
struct node
{
long long bas,sum,tag;//bas为附加的权值
}nodes[400010];
struct ask
{
long long pos,val,ID,nag;
}q[100010];
bool cmp(ask one,ask ano)
{
return one.pos<ano.pos;
}
long long n,m,s,opl,opr,opz,f,cnt,tot,waste,ans[50010],kkk,dep[50010],fa[50010],son[50010],siz[50010],hb[50010],dfn[50010],ton[50010],power[50010],bruh;
const long long mod=998244353;
long long qpow(long long bas,long long tim)//快速幂用来处理val
{
long long res=1,fur=bas;
while(tim)
{
if(tim&1) res=(res*fur)%mod;
fur=(fur*fur)%mod;
tim>>=1;
}
return res;
}
void dfs(long long x,long long las)
{
fa[x]=las;
dep[x]=dep[las]+1;
siz[x]=1;
long long b=0,s=0;
for(long long i=0;i<e[x].size();++i)
{
long long y=e[x][i];
dfs(y,x);
siz[x]+=siz[y];
if(siz[y]>b)
{
b=siz[y];
s=y;
}
}
son[x]=s;
}
void dfs2(long long x,long long las,long long heavy)//树剖dfs
{
if(heavy) hb[x]=hb[las];
else hb[x]=x;
dfn[x]=++cnt;
ton[cnt]=x;
if(son[x]) dfs2(son[x],x,1);
for(long long i=0;i<e[x].size();++i)
{
long long y=e[x][i];
if(y^son[x]) dfs2(y,x,0);
}
}
void pushdown(long long x)
{
if(nodes[x].tag)
{
//(siz[x]+a)*val[x]+(siz[y]+a)*val[y]+...+(siz[z]+a)*val[z]
//=siz[x]*val[x]+siz[y]*val[y]+...+siz[z]*val[z]+a*(val[x]+val[y]+...+val[z])
nodes[x].sum+=(nodes[x].bas*nodes[x].tag);
nodes[x].sum%=mod;
nodes[x<<1].tag+=nodes[x].tag;
nodes[x<<1].tag%=mod;
nodes[x<<1|1].tag+=nodes[x].tag;
nodes[x<<1|1].tag%=mod;
nodes[x].tag=0;
}
}
void build(long long l,long long r,long long x)//预处理val区间和
{
if(l^r)
{
long long mid=(l+r)>>1;
build(l,mid,x<<1);
build(mid+1,r,x<<1|1);
nodes[x].bas=(nodes[x<<1].bas+nodes[x<<1|1].bas)%mod;
}
else nodes[x].bas=power[dep[ton[l]]]-power[dep[ton[l]]-1];
// printf(" %lld %lld %lld\n",l,r,nodes[x].bas);
}
void update(long long l,long long r,long long x,long long fr,long long ba)
{
if(l>ba||r<fr) return;
if(l>=fr&&r<=ba) nodes[x].tag=(nodes[x].tag+1)%mod;
else
{
pushdown(x);
long long mid=(l+r)>>1;
update(l,mid,x<<1,fr,ba);
update(mid+1,r,x<<1|1,fr,ba);
pushdown(x<<1);
pushdown(x<<1|1);
nodes[x].sum=(nodes[x<<1].sum+nodes[x<<1|1].sum)%mod;
}
}
long long find(long long l,long long r,long long x,long long fr,long long ba)
{
if(l>ba||r<fr) return 0;
pushdown(x);
if(l>=fr&&r<=ba) return nodes[x].sum;
else
{
long long mid=(l+r)>>1;
return (find(l,mid,x<<1,fr,ba)+find(mid+1,r,x<<1|1,fr,ba))%mod;
}
}
void output(long long l,long long r,long long x)
{
pushdown(x);
printf(" %lld %lld %lld\n",l,r,nodes[x].sum);
if(l^r)
{
long long mid=(l+r)>>1;
output(l,mid,x<<1);
output(mid+1,r,x<<1|1);
}
}
void update_LCA(long long x,long long y)
{
long long fx=hb[x],fy=hb[y];
while(fx^fy)
{
if(dep[fx]<dep[fy])
{
swap(fx,fy);
swap(x,y);
}
update(1,s,1,dfn[fx],dfn[x]);
x=fa[fx];
fx=hb[x];
}
update(1,s,1,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]));
}
long long find_LCA(long long x,long long y)
{
long long res=0;
long long fx=hb[x],fy=hb[y];
while(fx^fy)
{
if(dep[fx]<dep[fy])
{
swap(fx,fy);
swap(x,y);
}
res+=find(1,s,1,dfn[fx],dfn[x]);
res%=mod;
x=fa[fx];
fx=hb[x];
}
res+=find(1,s,1,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]));
res%=mod;
return res;
}
int main()
{
scanf("%lld %lld %lld",&n,&m,&bruh);
s=n;
for(long long i=1;i<=n;++i) power[i]=qpow(i,bruh);
for(long long i=2;i<=n;++i)
{
scanf("%lld",&f);
e[f].push_back(i);
}
dfs(1,1);
dfs2(1,1,0);
build(1,s,1);
for(long long i=1;i<=m;++i)
{
scanf("%lld %lld",&opr,&opz);
q[++tot].ID=i;
q[tot].nag=1;
q[tot].pos=opr;
q[tot].val=opz;
}
sort(q+1,q+1+tot,cmp);
while(q[kkk].pos<=0&&kkk<=tot)
{
ans[q[kkk].ID]+=(q[kkk].nag*0);
++kkk;
}
for(long long i=1;i<=n;++i)
{
update_LCA(1,i);
while(q[kkk].pos<=i&&kkk<=tot)
{
ans[q[kkk].ID]+=(q[kkk].nag*find_LCA(1,q[kkk].val));
ans[q[kkk].ID]%=mod;
++kkk;
}
// puts("");
// output(1,s,1);
// puts("\n");
}
for(long long i=1;i<=m;++i) printf("%lld\n",((ans[i]%mod)+mod)%mod);
return 0;
}