题解 P5298 【[PKUWC2018]Minimax】
由于一个点最多有两个儿子
所以我们分类讨论,用权值线段树合并得到每个权值的概率;
〇没有儿子:权值插入线段树,return;
①只有一个儿子:直接继承儿子的所有概率
②有两个儿子:
某个儿子的某个权值出现的概率
=P[作为最大值出现]+P[作为最小值出现]
=P[该儿子权值是ta]*P[取了最大值]*P[另一儿子权值比ta小]
+P[该儿子权值是ta]*P[取了最小值]*P[另一儿子权值比ta大]
递归合并到单一节点时打乘法标记就好了
#include"cstdio"
#include"cstring"
#include"iostream"
#include"algorithm"
using namespace std;
const int MAXN=3e5+5;
const int MOD=998244353;
int n,np,ct,cnt;
long long sum;
int h[MAXN],p[MAXN],deg[MAXN],so[2][MAXN];
bool nlf[MAXN];
int rt[MAXN],sn[2][MAXN<<5],siz[MAXN<<5],tag[MAXN<<5];
struct rpg{
int li,nx;
}a[MAXN];
int read()
{
int x=0;char ch=getchar();
while(ch<'0'||'9'<ch) ch=getchar();
while('0'<=ch&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x;
}
void add(int ls,int nx){a[++np]=(rpg){h[ls],nx};h[ls]=np;}
int Fpw(int a,int b)
{
int x=1;
while(b){
if(b&1) x=1ll*x*a%MOD;
b>>=1;a=1ll*a*a%MOD;
}return x;
}
void ins(int &k,int l,int r,int v)
{
if(!k) k=++ct;siz[k]=tag[k]=1;
if(l==r) return;
int mid=l+r>>1;
if(v<=mid) ins(sn[0][k],l,mid,v);
else ins(sn[1][k],mid+1,r,v);
}
void pushdown(int x)
{
if(tag[x]==1) return;
if(sn[0][x]){
tag[sn[0][x]]=1ll*tag[sn[0][x]]*tag[x]%MOD;
siz[sn[0][x]]=1ll*siz[sn[0][x]]*tag[x]%MOD;
}if(sn[1][x]){
tag[sn[1][x]]=1ll*tag[sn[1][x]]*tag[x]%MOD;
siz[sn[1][x]]=1ll*siz[sn[1][x]]*tag[x]%MOD;
}tag[x]=1;
return;
}
int un(int x,int y,int p,int srx,int sry)
{
if(!x||!y){
if(x){
int v=(1ll*(1-sry+MOD)*p+1ll*sry*(1-p+MOD))%MOD;
tag[x]=1ll*tag[x]*v%MOD;
siz[x]=1ll*siz[x]*v%MOD;
return x;
}if(y){
int v=(1ll*(1-srx+MOD)*p+1ll*srx*(1-p+MOD))%MOD;
tag[y]=1ll*tag[y]*v%MOD;
siz[y]=1ll*siz[y]*v%MOD;
return y;
}return 0;
}pushdown(x);pushdown(y);
int tp1=siz[sn[1][x]],tp2=siz[sn[1][y]];
sn[0][x]=un(sn[0][x],sn[0][y],p,(srx+tp1)%MOD,(sry+tp2)%MOD);
sn[1][x]=un(sn[1][x],sn[1][y],p,srx,sry);
siz[x]=siz[sn[0][x]]+siz[sn[1][x]];siz[x]%=MOD;
return x;
}
void dfs(int x)
{
if(!nlf[x]){ins(rt[x],1,1e9,p[x]);return;}
if(so[0][x]) dfs(so[0][x]);
if(so[1][x]) dfs(so[1][x]);
if(deg[x]==1) rt[x]=rt[so[0][x]];
else rt[x]=un(rt[so[0][x]],rt[so[1][x]],p[x],0,0);
}
void ddfs(int k,int l,int r)
{
pushdown(k);
if(l==r){
++cnt;
sum+=1ll*siz[k]*siz[k]%MOD*l%MOD*cnt%MOD;
return;
}int mid=l+r>>1;
if(sn[0][k]) ddfs(sn[0][k],l,mid);
if(sn[1][k]) ddfs(sn[1][k],mid+1,r);
return;
}
int main()
{
n=read();int inv=Fpw(10000,MOD-2);
for(int i=1;i<=n;++i){
int f=read();
if(i>1) add(f,i),nlf[f]=1,so[deg[f]++][f]=i;
}for(int i=1;i<=n;++i){
p[i]=read();
if(nlf[i]) p[i]=1ll*p[i]*inv%MOD;
}dfs(1);
ddfs(rt[1],1,1e9);
printf("%lld\n",sum%MOD);
return 0;
}