题解 P5298 【[PKUWC2018]Minimax】

· · 题解

由于概率不为0且不为1,所以所有点权都可能取到。

f_i,g_i分别表示当前节点的左儿子、右儿子点权为i的概率。

i在左儿子中,则当前节点取到i的概率为:

f_i\times(p\times(\sum_{j=1}^i g_j)+(1-p)\times(\sum_{j=i+1}^m g_j))

右儿子同理。

考虑线段树合并,合并到叶子节点的时候显然只会有一个点有值。

边向下传递,边维护左儿子的前缀、后缀,右儿子的前缀、后缀概率和,然后按照上面的方法类似转移即可(此时可能转移到的是一个区间的值)。

线段树上打乘法标记即可。

时间复杂度O(n\log n)

由于合并时需要动态开点(否则前缀、后缀和就会出现问题),所以空间复杂度也是O(n\log n)

Code:

#include<iostream>
#include<algorithm>
#include<vector>
typedef long long LL;
const int md=998244353,N=3e5+5,iv1e4=796898467;
struct edge{
    int to,nxt;
}e[N];
int n,fa[N],head[N],cnt,p[N],rt[N];
std::vector<int>vec;
int ls[N<<6],rs[N<<6],g[N<<6],tag[N<<6],nd,ans;
inline void upd(int&a){a+=a>>31&md;}
inline int mod(int a){return(a>>31&md)+a;}
void pushdown(int o){
    if(tag[o]!=1){
        if(ls[o])
            tag[ls[o]]=(LL)tag[ls[o]]*tag[o]%md,g[ls[o]]=(LL)g[ls[o]]*tag[o]%md;
        if(rs[o])
            tag[rs[o]]=(LL)tag[rs[o]]*tag[o]%md,g[rs[o]]=(LL)g[rs[o]]*tag[o]%md;
        tag[o]=1;
    }
}
void add(int&o,int l,int r,const int&pos){
    if(!o)o=++nd,tag[o]=1;
    upd(g[o]+=1-md);
    if(l<r){
        const int mid=l+r>>1;
        if(pos<=mid)add(ls[o],l,mid,pos);else add(rs[o],mid+1,r,pos);
    }
}
int merge(int ld,int rd,int pl,int pr,int sl,int sr,const int&P){
    if(!ld&&!rd)return 0;
    pushdown(ld),pushdown(rd);
    int o=++nd;tag[o]=1;
    if(!ld){
        g[o]=((LL)P*pl+(md+1LL-P)*sl)%md*g[rd]%md;
        tag[o]=((LL)P*pl+(md+1LL-P)*sl)%md*tag[rd]%md;
        ls[o]=ls[rd],rs[o]=rs[rd];
        return o;
    }
    if(!rd){
        g[o]=((LL)P*pr+(md+1LL-P)*sr)%md*g[ld]%md;
        tag[o]=((LL)P*pr+(md+1LL-P)*sr)%md*tag[ld]%md;
        ls[o]=ls[ld],rs[o]=rs[ld];
        return o;
    }
    ls[o]=merge(ls[ld],ls[rd],pl,pr,mod(sl+g[rs[ld]]-md),mod(sr+g[rs[rd]]-md),P);
    rs[o]=merge(rs[ld],rs[rd],mod(pl+g[ls[ld]]-md),mod(pr+g[ls[rd]]-md),sl,sr,P);
    upd(g[o]=g[ls[o]]+g[rs[o]]-md);
    return o;
}
void dfs(int now){
    if(head[now])
        for(int i=head[now];i;i=e[i].nxt){
            dfs(e[i].to);
            if(!rt[now])rt[now]=rt[e[i].to];else
                rt[now]=merge(rt[now],rt[e[i].to],0,0,0,0,p[now]);
        }
    else
        add(rt[now],1,n,p[now]);
}
void getans(int o,int l,int r){
    if(!o)return;
    if(l==r){
        ans=(ans+(LL)l*vec[l]%md*g[o]%md*g[o])%md;
        return;
    }
    pushdown(o);
    const int mid=l+r>>1;
    getans(ls[o],l,mid),getans(rs[o],mid+1,r);
}
int main(){
    std::ios::sync_with_stdio(0),std::cin.tie(0);
    std::cin>>n;
    for(int i=1;i<=n;++i)
        std::cin>>fa[i],e[++cnt]=(edge){i,head[fa[i]]},head[fa[i]]=cnt;
    for(int i=1;i<=n;++i){
        std::cin>>p[i];
        if(head[i])p[i]=(LL)p[i]*iv1e4%md;else vec.push_back(p[i]);
    }
    vec.push_back(0);
    std::sort(vec.begin(),vec.end());
    vec.erase(std::unique(vec.begin(),vec.end()),vec.end());
    for(int i=1;i<=n;++i)if(!head[i])p[i]=std::lower_bound(vec.begin(),vec.end(),p[i])-vec.begin();
    dfs(1);
    getans(rt[1],1,n);
    std::cout<<ans<<std::endl;
    return 0;
}