题解 P6773 【[NOI2020]命运】
题意
给定一棵
题解
神仙题,但是大多数题解直接摆结论所以对于我这种还带有 whk 后遗症的菜鸡来说理解起来有点困难,所以写一篇自认为比较清晰的。
首先第一个问题是怎么 DP。
注意到有一个 key observation 是这样的:对于某一个节点
同时,对于
考虑转移,如果将
第一个和式表示
这个时候设
然后我们就有
这个时候考虑线段树合并。合并的过程中借用一点 cdq 分治的思路,也就是说先合并左子树顺便更新
代码
#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=5e5+51,MOD=998244353;
ll n,m,u,v,totn;
ll depth[MAXN],ls[MAXN<<5],rs[MAXN<<5],sm[MAXN<<5],tag[MAXN<<5],rt[MAXN];
vector<ll>vg[MAXN],link[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline void update(ll node)
{
sm[node]=(sm[ls[node]]+sm[rs[node]])%MOD;
}
inline void spread(ll node)
{
if(tag[node]!=1)
{
sm[ls[node]]=(li)tag[node]*sm[ls[node]]%MOD;
sm[rs[node]]=(li)tag[node]*sm[rs[node]]%MOD;
tag[ls[node]]=(li)tag[node]*tag[ls[node]]%MOD;
tag[rs[node]]=(li)tag[node]*tag[rs[node]]%MOD;
tag[node]=1;
}
}
inline void change(ll l,ll r,ll pos,ll val,ll &node)
{
!node?tag[node=++totn]=1:1;
if(l==r)
{
return (void)(tag[node]=1,sm[node]=val);
}
ll mid=(l+r)>>1;
spread(node);
pos<=mid?change(l,mid,pos,val,ls[node]):change(mid+1,r,pos,val,rs[node]);
update(node);
}
inline ll query(ll l,ll r,ll ql,ll qr,ll node)
{
if(ql<=l&&qr>=r)
{
return sm[node];
}
ll mid=(l+r)>>1,res=0;
spread(node);
res=(res+(ql<=mid?query(l,mid,ql,qr,ls[node]):0))%MOD;
res=(res+(qr>mid?query(mid+1,r,ql,qr,rs[node]):0))%MOD;
return res;
}
inline ll merge(ll x,ll y,ll l,ll r,ll &su,ll &sv)
{
if(!x&&!y)
{
return 0;
}
if(!x)
{
sv=(sv+sm[y])%MOD,tag[y]=(li)tag[y]*su%MOD;
return sm[y]=(li)sm[y]*su%MOD,y;
}
if(!y)
{
su=(su+sm[x])%MOD,tag[x]=(li)tag[x]*sv%MOD;
return sm[x]=(li)sm[x]*sv%MOD,x;
}
if(l==r)
{
ll cu=sm[x],cv=sm[y];
sv=(sv+cv)%MOD,sm[x]=((li)sm[x]*sv+(li)sm[y]*su)%MOD;
return su=(su+cu)%MOD,x;
}
ll mid=(l+r)>>1;
spread(x),spread(y);
ls[x]=merge(ls[x],ls[y],l,mid,su,sv);
rs[x]=merge(rs[x],rs[y],mid+1,r,su,sv);
update(x);
return x;
}
inline void dfs(ll node,ll fa)
{
depth[node]=depth[fa]+1;
ll mxd=0,su,sv;
for(register int i:link[node])
{
mxd=max(mxd,depth[i]);
}
change(0,n,mxd,1,rt[node]);
for(register int i:vg[node])
{
if(i!=fa)
{
dfs(i,node),su=0,sv=query(0,n,0,depth[node],rt[i]);
rt[node]=merge(rt[node],rt[i],0,n,su,sv);
}
}
}
int main()
{
n=read();
for(register int i=0;i<n-1;i++)
{
u=read(),v=read(),vg[u].emplace_back(v),vg[v].emplace_back(u);
}
m=read();
for(register int i=1;i<=m;i++)
{
u=read(),v=read(),link[v].push_back(u);
}
dfs(1,0),printf("%d\n",query(0,n,0,0,rt[1]));
}