题解 P5298 【[PKUWC2018]Minimax】

· · 题解

题目Link

题意复杂,难以概括……

肯定首先离散化。

题目中写着这玩意是个二叉树,那么可能是考虑左右子树间贡献什么的。

题目里要求的是这个:\sum_{i=1}^{m}i\cdot V_i\cdot D_i^2

直接维护这个东西肯定是布星的(无法转移),我们考虑对于每个点维护D_{[1...n]}

采用线段树合并来维护,正好合适3*10^5的范围。

对于某个叶节点,就是在D_{?}处是1,其他都是0;

然后,如果某个点只有一个儿子,那么直接继承整个D数组就好。

对于某个点u有两个儿子:

首先,我们设左儿子的概率数组为DL,相应地右儿子为DR

那么,对于左儿子权为x(是离散化之后的)的情况,要以继承到u的概率是:

由于各个叶节点权值不同,左右权相同的情况是不存在的。

总的来说是D_i+=DL_i*\left((1-p_u)\sum\limits_{i=x+1}^nDR_i+p_u\sum\limits_{i=1}^{x-1}DR_i\right)

对于右儿子同理。

这么看,需要对D维护一个前缀和和后缀和,动态维护好像很麻烦,可以边合并边算(利用区间和),用完就丢掉。

线段树合并时,如果两方都有节点,那么分下去,pushup解决。

如果只有一方有节点的话(假设是左方):

贡献关系式是:

D_i=DL_i*\left((1-p_u)\sum\limits_{i=x+1}^nDR_i+p_u\sum\limits_{i=1}^{x-1}DR_i\right)

只有一方有节点,表示这个区间内再没有右方节点了。

那么括号里面那个式子是定值,打区间乘法tag即可。

这是一道好题啊!完美的利用了线段树合并的性质!Orz出题人!

Code:

假如思路清晰,那么代码并不难写。

莫名其妙跑到了rk1?不过应该很快就会被超过去。

#include<algorithm>
#include<cstdio>
#define MaxN 300500
#define mod 998244353
#define ll long long
using namespace std;
inline int read()
{
  register int X=0;
  register char ch=0;
  while(ch<48||ch>57)ch=getchar();
  while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
  return X;
}
ll powM(ll a,ll t=mod-2)
{
  ll ans=1;
  while(t){
    if(t&1)ans=ans*a%mod;
    a=a*a%mod;
    t>>=1;
  }return ans;
}
const ll inv10=powM(10000);
int n,xx[MaxN],tot;
struct TreeNode
{int l,r,x;}b[MaxN];
struct Node
{
  int l,r;ll x,tag;
  inline void ladd(ll c)
  {tag=tag*c%mod;x=x*c%mod;}
}a[MaxN<<6];
int tn;
inline int create()
{a[++tn].tag=1;return tn;}
inline void up(int num)
{a[num].x=(a[a[num].l].x+a[a[num].r].x)%mod;}
int to;
void change(int l,int r,int &num)
{
  a[num=create()].x=1;
  if (l==r)return ;
  int mid=(l+r)>>1;
  if (to<=mid)change(l,mid,a[num].l);
  else change(mid+1,r,a[num].r);
}
inline void ladd(int num)
{
  if (a[num].tag==1)return ;
  if (a[num].l)a[a[num].l].ladd(a[num].tag);
  if (a[num].r)a[a[num].r].ladd(a[num].tag);
  a[num].tag=1;
}
long long lc,rc;
int marge(int x,int y,ll xl,ll xr,ll yl,ll yr)
{
  if (!x&&!y)return 0;
  if (x&&y){
    ladd(x);ladd(y);
    ll sav1=a[a[x].l].x,sav2=a[a[y].l].x;
    a[x].l=marge(a[x].l,a[y].l,xl,xr+a[a[x].r].x,yl,yr+a[a[y].r].x);
    a[x].r=marge(a[x].r,a[y].r,xl+sav1,xr,yl+sav2,yr);
    up(x);
  }else {
    if (!x){swap(x,y);yl=xl;yr=xr;}
    yr%=mod;yl%=mod;
    a[x].ladd((lc*yr+rc*yl)%mod);
  }return x;
}
int rt[MaxN];
void dfs(int num)
{
  if (!b[num].l){
    to=b[num].x;
    change(1,tot,rt[num]);
  }else if (!b[num].r){
    dfs(b[num].l);
    rt[num]=rt[b[num].l];
  }else {
    dfs(b[num].l);dfs(b[num].r);
    lc=mod+1-b[num].x;rc=b[num].x;
    rt[num]=marge(rt[b[num].l],rt[b[num].r],0,0,0,0);
  }
}
ll ans;
void getans(int l,int r,int num)
{
  ladd(num);
  if (l==r){
    ans=(ans+1ll*l*xx[l]%mod*a[num].x%mod*a[num].x)%mod;
    return ;
  }int mid=(l+r)>>1;
  getans(l,mid,a[num].l);
  getans(mid+1,r,a[num].r);
}
int main()
{
  n=read();
  for (int i=1,fa;i<=n;i++){
    fa=read();
    if (b[fa].l)b[fa].r=i;
    else b[fa].l=i;
  }for (int i=1;i<=n;i++){
    b[i].x=read();
    if (!b[i].l)xx[++tot]=b[i].x;
    else b[i].x=b[i].x*inv10%mod;
  }sort(xx+1,xx+tot+1);
  for (int i=1;i<=n;i++)
    if (!b[i].l)
      b[i].x=lower_bound(xx+1,xx+tot+1,b[i].x)-xx;
  dfs(1);
  getans(1,tot,rt[1]);
  printf("%lld",ans);
  return 0;
}