题解 P5298 【[PKUWC2018]Minimax】

· · 题解

线段树合并。

这类线段树合并优化树形 DP 的题一般都是先考虑一个 O(n^2) 的 DP。

不难发现对于每一个节点,其子树内所有的权值都可以取到。

故设 f_{i,j}i 的权值取到所有叶子中第 j 小的权值。

lu 的左子,ru 的右子,则有如下转移:

f_{u,i}=p_u\left(f_{l,i}\cdot\sum_{j=1}^{i-1}f_{r,j}+f_{r,i}\cdot\sum_{j=1}^{i-1}f_{l,j}\right)+(1-p_u)\left(f_{l,i}\cdot\sum_{j=i+1}^{n}f_{r,j}+f_{r,i}\cdot\sum_{j=i+1}^{n}f_{l,j}\right)

初始状态就是叶子取到自己的概率为 1

前缀和优化可以做到 O(n^2)

由于叶子上只有 1 个状态,所以考虑线段树合并。

在合并到空节点的时候,可以知道这一段的所有状态都是 0,所以就直接在 \operatorname{merge}(l,r) 过程中保存 [1,l-1][r+1,n] 的两个和值就可以了。

于是这题就做完了~

时空复杂度均为 O(n\log n)

#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
    register char c = getchar();
    register int x = 0, f = 1;
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return x * f;
}

inline int Abs(const int& x) {return (x > 0 ? x : -x);}
inline int Max(const int& x, const int& y) {return (x > y ? x : y);}
inline int Min(const int& x, const int& y) {return (x < y ? x : y);}

const int N = 300005;
const long long mod = 998244353;
struct Edge {
    int to, nxt;
    Edge() {
        nxt = -1;
    }
};
int n, hd[N], pnt, bcnt, scnt[N];
long long p[N], b[N];
Edge e[N << 1];

inline long long Add(long long x, long long y) {return (x + y >= mod ? x + y - mod : x + y);}
inline long long Subt(long long x, long long y) {return (x < y ? x - y + mod : x - y);}

inline long long Power(long long x, long long y) {
    long long ans = 1;
    while (y) {
        if (y & 1) ans = ans * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ans;
}

#define Sum(p) (p ? p->sum : 0)
struct Node {
    long long sum, mul;
    Node *l, *r;
    Node() {
        l = r = NULL;
        sum = 0;
        mul = 1;
    }
    inline void Update() {
        sum = Add(Sum(l), Sum(r));
    }
};
Node nd[N * 40];
int top;
struct Segtree {
    Node *_root[N];
    inline void Pushdown(Node *p, int pl, int pr) {
        if (p->l) {
            p->l->sum = p->l->sum * p->mul % mod;
            p->l->mul = p->l->mul * p->mul % mod;
        }
        if (p->r) {
            p->r->sum = p->r->sum * p->mul % mod;
            p->r->mul = p->r->mul * p->mul % mod;
        }
        p->mul = 1;
    }
    inline void Modify(Node *&p, int pl, int pr, int idx, int v) {
        if (!p) p = &nd[++top];
        if (pl == idx && pr == idx) {
            p->sum = v;
            return;
        }
        Pushdown(p, pl, pr);
        int mid = pl + pr >> 1;
        if (idx <= mid) Modify(p->l, pl, mid, idx, v);
        else Modify(p->r, mid + 1, pr, idx, v);
        p->Update();
    }
    inline long long Query(Node *p, int pl, int pr, int l, int r) {
        if (!p) return 0;
        if (pl == l && pr == r) return Sum(p);
        Pushdown(p, pl, pr);
        int mid = pl + pr >> 1;
        if (mid >= r) return Query(p->l, pl, mid, l, r);
        else if (mid + 1 <= l) return Query(p->r, mid + 1, pr, l, r);
        else return Add(Query(p->l, pl, mid, l, mid), Query(p->r, mid + 1, pr, mid + 1, r));
    }
    inline Node* Merge(Node *p1, Node *p2, int pl, int pr, long long pu, long long ps1, long long ss1, long long ps2, long long ss2) {
        long long lmul = Add(pu * ps2 % mod, Subt(1, pu) * ss2 % mod);
        long long rmul = Add(pu * ps1 % mod, Subt(1, pu) * ss1 % mod);
        if (!p1 && !p2) return NULL;
        if (!p2) {
            //printf("p1 %d %d %lld\n", pl, pr, lmul);
            p1->mul = p1->mul * lmul % mod;
            p1->sum = p1->sum * lmul % mod;
            return p1;
        }
        if (!p1) {
            //printf("p2 %d %d %lld\n", pl, pr, rmul);
            p2->mul = p2->mul * rmul % mod;
            p2->sum = p2->sum * rmul % mod;
            return p2;
        }
        Pushdown(p1, pl, pr); Pushdown(p2, pl, pr);
        int mid = pl + pr >> 1;
        long long s1l = Sum(p1->l), s1r = Sum(p1->r), s2l = Sum(p2->l), s2r = Sum(p2->r);
        p1->l = Merge(p1->l, p2->l, pl, mid, pu, ps1, Add(ss1, s1r), ps2, Add(ss2, s2r));
        p1->r = Merge(p1->r, p2->r, mid + 1, pr, pu, Add(ps1, s1l), ss1, Add(ps2, s2l), ss2);
        p1->Update();
        return p1;
    }
};
Segtree sgt;

inline void AddEdge(int u, int v) {
    e[++pnt].to = v;
    e[pnt].nxt = hd[u];
    hd[u] = pnt;
    scnt[u]++;
}

inline void Read() {
    n = qread();
    for (int i = 1;i <= n;i++) {
        int fa = qread();
        AddEdge(fa, i);
    }
    long long inv1e4 = Power(10000, mod - 2);
    for (int i = 1;i <= n;i++) {
        int x = qread();
        if (!~hd[i]) p[i] = b[++bcnt] = x;
        else p[i] = x * inv1e4 % mod;
    }
}

inline void Prefix() {
    sort(b + 1, b + bcnt + 1);
    bcnt = unique(b + 1, b + bcnt + 1) - b - 1;
    for (int i = 1;i <= n;i++) {
        if (!~hd[i]) p[i] = lower_bound(b + 1, b + bcnt + 1, p[i]) - b;
    }
}

inline void Dfs(int u) {
    if (!~hd[u]) {
        sgt.Modify(sgt._root[u], 1, bcnt, p[u], 1);
        //dp[u][p[u]] = 1;
        return;
    }
    if (scnt[u] == 1) {
        Dfs(e[hd[u]].to);
        //memcpy(dp[e[hd[u]].to], dp[u], sizeof(dp[u]));
        sgt._root[u] = sgt._root[e[hd[u]].to];
        return;
    }
    int ls = e[hd[u]].to;
    int rs = e[e[hd[u]].nxt].to;
    Dfs(ls);
    Dfs(rs);
    sgt._root[u] = sgt._root[ls];
    //sgt._root[u]->mul = sgt._root[u]->mul * p[u] % mod;
    //sgt._root[rs]->mul = sgt._root[rs]->mul * Subt(1, p[u]) % mod;
    sgt._root[u] = sgt.Merge(sgt._root[u], sgt._root[rs], 1, bcnt, p[u], 0, 0, 0, 0);
}

int main() {
    memset(hd, -1, sizeof(hd));
    Read();
    Prefix();
    Dfs(1);
    long long ans = 0;
    for (int i = 1;i <= bcnt;i++) {
        long long val = sgt.Query(sgt._root[1], 1, bcnt, i, i);
        //printf("%lld ", val);
        ans = (ans + 1ll * i * b[i] % mod * val % mod * val % mod) % mod;
    }
    //puts("");
    printf("%lld", ans);
    #ifndef ONLINE_JUDGE
    while (1);
    #endif
    return 0;
}