题解 P5298 【[PKUWC2018]Minimax】

· · 题解

由于一个点最多有两个儿子

所以我们分类讨论,用权值线段树合并得到每个权值的概率;

〇没有儿子:权值插入线段树,return;

①只有一个儿子:直接继承儿子的所有概率

②有两个儿子:


某个儿子的某个权值出现的概率

=P[作为最大值出现]+P[作为最小值出现]

=P[该儿子权值是ta]*P[取了最大值]*P[另一儿子权值比ta小]

+P[该儿子权值是ta]*P[取了最小值]*P[另一儿子权值比ta大]

递归合并到单一节点时打乘法标记就好了

#include"cstdio"
#include"cstring"
#include"iostream"
#include"algorithm"
using namespace std;

const int MAXN=3e5+5;
const int MOD=998244353;

int n,np,ct,cnt;
long long sum;
int h[MAXN],p[MAXN],deg[MAXN],so[2][MAXN];
bool nlf[MAXN];
int rt[MAXN],sn[2][MAXN<<5],siz[MAXN<<5],tag[MAXN<<5];
struct rpg{
    int li,nx;
}a[MAXN];

int read()
{
    int x=0;char ch=getchar();
    while(ch<'0'||'9'<ch) ch=getchar();
    while('0'<=ch&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x;
}

void add(int ls,int nx){a[++np]=(rpg){h[ls],nx};h[ls]=np;}
int Fpw(int a,int b)
{
    int x=1;
    while(b){
        if(b&1) x=1ll*x*a%MOD;
        b>>=1;a=1ll*a*a%MOD;
    }return x;
}

void ins(int &k,int l,int r,int v)
{
    if(!k) k=++ct;siz[k]=tag[k]=1;
    if(l==r) return;
    int mid=l+r>>1;
    if(v<=mid) ins(sn[0][k],l,mid,v);
    else ins(sn[1][k],mid+1,r,v);
}

void pushdown(int x)
{
    if(tag[x]==1) return;
    if(sn[0][x]){
        tag[sn[0][x]]=1ll*tag[sn[0][x]]*tag[x]%MOD;
        siz[sn[0][x]]=1ll*siz[sn[0][x]]*tag[x]%MOD;
    }if(sn[1][x]){
        tag[sn[1][x]]=1ll*tag[sn[1][x]]*tag[x]%MOD;
        siz[sn[1][x]]=1ll*siz[sn[1][x]]*tag[x]%MOD;
    }tag[x]=1;
    return;
}

int un(int x,int y,int p,int srx,int sry)
{
    if(!x||!y){
        if(x){
            int v=(1ll*(1-sry+MOD)*p+1ll*sry*(1-p+MOD))%MOD;
            tag[x]=1ll*tag[x]*v%MOD;
            siz[x]=1ll*siz[x]*v%MOD;
            return x;
        }if(y){
            int v=(1ll*(1-srx+MOD)*p+1ll*srx*(1-p+MOD))%MOD;
            tag[y]=1ll*tag[y]*v%MOD;
            siz[y]=1ll*siz[y]*v%MOD;
            return y;
        }return 0;
    }pushdown(x);pushdown(y);
    int tp1=siz[sn[1][x]],tp2=siz[sn[1][y]];
    sn[0][x]=un(sn[0][x],sn[0][y],p,(srx+tp1)%MOD,(sry+tp2)%MOD);
    sn[1][x]=un(sn[1][x],sn[1][y],p,srx,sry);
    siz[x]=siz[sn[0][x]]+siz[sn[1][x]];siz[x]%=MOD;
    return x;
}

void dfs(int x)
{
    if(!nlf[x]){ins(rt[x],1,1e9,p[x]);return;}
    if(so[0][x]) dfs(so[0][x]);
    if(so[1][x]) dfs(so[1][x]);
    if(deg[x]==1) rt[x]=rt[so[0][x]];
    else rt[x]=un(rt[so[0][x]],rt[so[1][x]],p[x],0,0);
}

void ddfs(int k,int l,int r)
{
    pushdown(k);
    if(l==r){
        ++cnt;
        sum+=1ll*siz[k]*siz[k]%MOD*l%MOD*cnt%MOD;
        return;
    }int mid=l+r>>1;
    if(sn[0][k]) ddfs(sn[0][k],l,mid);
    if(sn[1][k]) ddfs(sn[1][k],mid+1,r);
    return;
}

int main()
{
    n=read();int inv=Fpw(10000,MOD-2);
    for(int i=1;i<=n;++i){
        int f=read();
        if(i>1) add(f,i),nlf[f]=1,so[deg[f]++][f]=i;
    }for(int i=1;i<=n;++i){
        p[i]=read();
        if(nlf[i]) p[i]=1ll*p[i]*inv%MOD;
    }dfs(1);
    ddfs(rt[1],1,1e9);
    printf("%lld\n",sum%MOD);
    return 0;
}