题解 P5298 【[PKUWC2018]Minimax】
由于概率不为0且不为1,所以所有点权都可能取到。
设
若
右儿子同理。
考虑线段树合并,合并到叶子节点的时候显然只会有一个点有值。
边向下传递,边维护左儿子的前缀、后缀,右儿子的前缀、后缀概率和,然后按照上面的方法类似转移即可(此时可能转移到的是一个区间的值)。
线段树上打乘法标记即可。
时间复杂度
由于合并时需要动态开点(否则前缀、后缀和就会出现问题),所以空间复杂度也是
Code:
#include<iostream>
#include<algorithm>
#include<vector>
typedef long long LL;
const int md=998244353,N=3e5+5,iv1e4=796898467;
struct edge{
int to,nxt;
}e[N];
int n,fa[N],head[N],cnt,p[N],rt[N];
std::vector<int>vec;
int ls[N<<6],rs[N<<6],g[N<<6],tag[N<<6],nd,ans;
inline void upd(int&a){a+=a>>31&md;}
inline int mod(int a){return(a>>31&md)+a;}
void pushdown(int o){
if(tag[o]!=1){
if(ls[o])
tag[ls[o]]=(LL)tag[ls[o]]*tag[o]%md,g[ls[o]]=(LL)g[ls[o]]*tag[o]%md;
if(rs[o])
tag[rs[o]]=(LL)tag[rs[o]]*tag[o]%md,g[rs[o]]=(LL)g[rs[o]]*tag[o]%md;
tag[o]=1;
}
}
void add(int&o,int l,int r,const int&pos){
if(!o)o=++nd,tag[o]=1;
upd(g[o]+=1-md);
if(l<r){
const int mid=l+r>>1;
if(pos<=mid)add(ls[o],l,mid,pos);else add(rs[o],mid+1,r,pos);
}
}
int merge(int ld,int rd,int pl,int pr,int sl,int sr,const int&P){
if(!ld&&!rd)return 0;
pushdown(ld),pushdown(rd);
int o=++nd;tag[o]=1;
if(!ld){
g[o]=((LL)P*pl+(md+1LL-P)*sl)%md*g[rd]%md;
tag[o]=((LL)P*pl+(md+1LL-P)*sl)%md*tag[rd]%md;
ls[o]=ls[rd],rs[o]=rs[rd];
return o;
}
if(!rd){
g[o]=((LL)P*pr+(md+1LL-P)*sr)%md*g[ld]%md;
tag[o]=((LL)P*pr+(md+1LL-P)*sr)%md*tag[ld]%md;
ls[o]=ls[ld],rs[o]=rs[ld];
return o;
}
ls[o]=merge(ls[ld],ls[rd],pl,pr,mod(sl+g[rs[ld]]-md),mod(sr+g[rs[rd]]-md),P);
rs[o]=merge(rs[ld],rs[rd],mod(pl+g[ls[ld]]-md),mod(pr+g[ls[rd]]-md),sl,sr,P);
upd(g[o]=g[ls[o]]+g[rs[o]]-md);
return o;
}
void dfs(int now){
if(head[now])
for(int i=head[now];i;i=e[i].nxt){
dfs(e[i].to);
if(!rt[now])rt[now]=rt[e[i].to];else
rt[now]=merge(rt[now],rt[e[i].to],0,0,0,0,p[now]);
}
else
add(rt[now],1,n,p[now]);
}
void getans(int o,int l,int r){
if(!o)return;
if(l==r){
ans=(ans+(LL)l*vec[l]%md*g[o]%md*g[o])%md;
return;
}
pushdown(o);
const int mid=l+r>>1;
getans(ls[o],l,mid),getans(rs[o],mid+1,r);
}
int main(){
std::ios::sync_with_stdio(0),std::cin.tie(0);
std::cin>>n;
for(int i=1;i<=n;++i)
std::cin>>fa[i],e[++cnt]=(edge){i,head[fa[i]]},head[fa[i]]=cnt;
for(int i=1;i<=n;++i){
std::cin>>p[i];
if(head[i])p[i]=(LL)p[i]*iv1e4%md;else vec.push_back(p[i]);
}
vec.push_back(0);
std::sort(vec.begin(),vec.end());
vec.erase(std::unique(vec.begin(),vec.end()),vec.end());
for(int i=1;i<=n;++i)if(!head[i])p[i]=std::lower_bound(vec.begin(),vec.end(),p[i])-vec.begin();
dfs(1);
getans(rt[1],1,n);
std::cout<<ans<<std::endl;
return 0;
}