题解 P5298 【[PKUWC2018]Minimax】
command_block · · 题解
题目Link
题意复杂,难以概括……
肯定首先离散化。
题目中写着这玩意是个二叉树,那么可能是考虑左右子树间贡献什么的。
题目里要求的是这个:
直接维护这个东西肯定是布星的(无法转移),我们考虑对于每个点维护
采用线段树合并来维护,正好合适
对于某个叶节点,就是在
然后,如果某个点只有一个儿子,那么直接继承整个
对于某个点
首先,我们设左儿子的概率数组为
那么,对于左儿子权为
-
最小值:
D_{i}+=DL_i*(1-p_u)\sum\limits_{i=x+1}^nDR_i -
最大值:
D_{i}+=DL_i*p_u\sum\limits_{i=1}^{x-1}DR_i
由于各个叶节点权值不同,左右权相同的情况是不存在的。
总的来说是
对于右儿子同理。
这么看,需要对
线段树合并时,如果两方都有节点,那么分下去,pushup解决。
如果只有一方有节点的话(假设是左方):
贡献关系式是:
只有一方有节点,表示这个区间内再没有右方节点了。
那么括号里面那个式子是定值,打区间乘法tag即可。
这是一道好题啊!完美的利用了线段树合并的性质!Orz出题人!
Code:
假如思路清晰,那么代码并不难写。
莫名其妙跑到了
#include<algorithm>
#include<cstdio>
#define MaxN 300500
#define mod 998244353
#define ll long long
using namespace std;
inline int read()
{
register int X=0;
register char ch=0;
while(ch<48||ch>57)ch=getchar();
while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
return X;
}
ll powM(ll a,ll t=mod-2)
{
ll ans=1;
while(t){
if(t&1)ans=ans*a%mod;
a=a*a%mod;
t>>=1;
}return ans;
}
const ll inv10=powM(10000);
int n,xx[MaxN],tot;
struct TreeNode
{int l,r,x;}b[MaxN];
struct Node
{
int l,r;ll x,tag;
inline void ladd(ll c)
{tag=tag*c%mod;x=x*c%mod;}
}a[MaxN<<6];
int tn;
inline int create()
{a[++tn].tag=1;return tn;}
inline void up(int num)
{a[num].x=(a[a[num].l].x+a[a[num].r].x)%mod;}
int to;
void change(int l,int r,int &num)
{
a[num=create()].x=1;
if (l==r)return ;
int mid=(l+r)>>1;
if (to<=mid)change(l,mid,a[num].l);
else change(mid+1,r,a[num].r);
}
inline void ladd(int num)
{
if (a[num].tag==1)return ;
if (a[num].l)a[a[num].l].ladd(a[num].tag);
if (a[num].r)a[a[num].r].ladd(a[num].tag);
a[num].tag=1;
}
long long lc,rc;
int marge(int x,int y,ll xl,ll xr,ll yl,ll yr)
{
if (!x&&!y)return 0;
if (x&&y){
ladd(x);ladd(y);
ll sav1=a[a[x].l].x,sav2=a[a[y].l].x;
a[x].l=marge(a[x].l,a[y].l,xl,xr+a[a[x].r].x,yl,yr+a[a[y].r].x);
a[x].r=marge(a[x].r,a[y].r,xl+sav1,xr,yl+sav2,yr);
up(x);
}else {
if (!x){swap(x,y);yl=xl;yr=xr;}
yr%=mod;yl%=mod;
a[x].ladd((lc*yr+rc*yl)%mod);
}return x;
}
int rt[MaxN];
void dfs(int num)
{
if (!b[num].l){
to=b[num].x;
change(1,tot,rt[num]);
}else if (!b[num].r){
dfs(b[num].l);
rt[num]=rt[b[num].l];
}else {
dfs(b[num].l);dfs(b[num].r);
lc=mod+1-b[num].x;rc=b[num].x;
rt[num]=marge(rt[b[num].l],rt[b[num].r],0,0,0,0);
}
}
ll ans;
void getans(int l,int r,int num)
{
ladd(num);
if (l==r){
ans=(ans+1ll*l*xx[l]%mod*a[num].x%mod*a[num].x)%mod;
return ;
}int mid=(l+r)>>1;
getans(l,mid,a[num].l);
getans(mid+1,r,a[num].r);
}
int main()
{
n=read();
for (int i=1,fa;i<=n;i++){
fa=read();
if (b[fa].l)b[fa].r=i;
else b[fa].l=i;
}for (int i=1;i<=n;i++){
b[i].x=read();
if (!b[i].l)xx[++tot]=b[i].x;
else b[i].x=b[i].x*inv10%mod;
}sort(xx+1,xx+tot+1);
for (int i=1;i<=n;i++)
if (!b[i].l)
b[i].x=lower_bound(xx+1,xx+tot+1,b[i].x)-xx;
dfs(1);
getans(1,tot,rt[1]);
printf("%lld",ans);
return 0;
}