题解 P5298 【[PKUWC2018]Minimax】
线段树合并。
这类线段树合并优化树形 DP 的题一般都是先考虑一个 $O(n^2)$ 的暴力 DP,再用线段树合并去优化它。
不难发现对于每一个节点,其子树内所有的权值都可以取到。
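故设 $f_{u,i}$ 表示节点 $u$ 的权值恰好是(离散化后)第 $i$ 小的权值的概率。
设 $l$、$r$ 分别为 $u$ 的左右儿子,$p_u$ 为 $u$ 取较大值的概率,则有转移
$$f_{u,i}=f_{l,i}\left(p_u\sum_{j<i}f_{r,j}+(1-p_u)\sum_{j>i}f_{r,j}\right)+f_{r,i}\left(p_u\sum_{j<i}f_{l,j}+(1-p_u)\sum_{j>i}f_{l,j}\right)$$
若 $u$ 只有一个儿子 $v$,则 $f_{u,i}=f_{v,i}$。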
初始状态就是叶子取到自己的概率为 $1$。
前缀和优化可以做到 $O(n^2)$。
由于叶子上只有 $1$ 个状态有值,所以可以用动态开点线段树维护每个节点的 $f_u$,并用线段树合并来完成转移。
在合并到空节点的时候,可以知道这一段的所有状态都是乘上同一个系数(另一棵树在这一段上的 $f$ 全为 $0$,所以它的前缀和与后缀和在这一段内不变),因此直接给非空的那个节点打上乘法标记即可。合并过程中顺便维护两棵树各自的前缀和与后缀和。
于是这题就做完了~
时空复杂度均为 $O(n\log n)$。
#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// Buffered replacement for the standard getchar(): when the read cursor p1
// has caught up with the end marker p2, refill buf from stdin with a single
// fread of up to 1 << 21 bytes. Evaluates to EOF when that fread delivers
// nothing, otherwise to the next buffered byte (advancing p1).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
// buf: input buffer; p1: next unread byte; p2: one past the last valid byte.
char buf[1 << 21], *p1 = buf, *p2 = buf;
// Reads the next decimal integer from the input via getchar().
//
// Every character before the first digit is skipped; if a '-' is seen among
// the skipped characters the result is negated. Digits are then consumed
// until the first non-digit character.
//
// Returns the parsed value, or 0 if the input ends before any digit is found
// (previously that case looped forever). No overflow checking is performed.
inline int qread() {
    // Keep the character in an int so EOF stays distinguishable from data.
    int c = getchar();
    int sign = 1;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}
// Absolute value of x: x itself when positive, otherwise its negation.
inline int Abs(const int& x) {
    if (x > 0) return x;
    return -x;
}
// Larger of the two arguments (y when they are equal).
inline int Max(const int& x, const int& y) {
    if (x > y) return x;
    return y;
}
// Smaller of the two arguments (y when they are equal).
inline int Min(const int& x, const int& y) {
    if (x < y) return x;
    return y;
}
// Size bound used to dimension the global arrays of this file.
const int N = 300005;
// Modulus for all arithmetic on probabilities; Power(x, mod - 2) is used as
// the modular inverse of x.
const long long mod = 998244353;
// One entry of a linked adjacency list.
struct Edge {
    int to;   // target vertex of this edge
    int nxt;  // index of the next edge in the same list, -1 terminates the list

    // A fresh edge is not linked to anything; `to` is left untouched.
    Edge() : nxt(-1) {}
};
// n: number of vertices read from the input.
// hd[u]: index in e of the most recently added edge leaving u (-1 when u has none).
// pnt: number of edges stored in e so far (edges occupy e[1..pnt]).
// bcnt: number of entries in b[1..bcnt].
// scnt[u]: number of edges added with source u, i.e. u's child count.
int n, hd[N], pnt, bcnt, scnt[N];
// p[u]: for a vertex without children, its input value and, after Prefix(),
//       the 1-based rank of that value in b; for any other vertex, the input
//       value multiplied by the modular inverse of 10000.
// b[1..bcnt]: values of the childless vertices (sorted and deduplicated by Prefix()).
long long p[N], b[N];
// Storage for the adjacency lists built by AddEdge().
Edge e[N << 1];
// Sum of x and y with a single conditional subtraction of mod; the result is
// reduced only when both inputs already lie in [0, mod).
inline long long Add(long long x, long long y) {
    long long total = x + y;
    if (total >= mod) total -= mod;
    return total;
}
// Difference x - y, shifted up by mod once when it would be negative; the
// result is reduced only when both inputs already lie in [0, mod).
inline long long Subt(long long x, long long y) {
    long long diff = x - y;
    if (diff < 0) diff += mod;
    return diff;
}
// Binary exponentiation: returns x raised to the power y, modulo mod.
// The bits of y are consumed from least to most significant while the base
// is squared at every step.
inline long long Power(long long x, long long y) {
    long long result = 1;
    for (long long base = x, expo = y; expo; expo >>= 1) {
        if (expo & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}
// Null-safe accessor: the `sum` field of the node p points to, or 0 when p is
// a null pointer. The argument is parenthesised so that any pointer
// expression (e.g. `base + 1`) expands correctly; it is evaluated twice, so
// it must be free of side effects.
#define Sum(p) ((p) ? (p)->sum : 0)
// Node of a dynamically built segment tree.
struct Node {
    long long sum;  // sum of the values stored in this node's interval
    long long mul;  // multiplier not yet pushed to the children (1 = nothing pending)
    Node *l, *r;    // left / right child, NULL when that half has no node

    // A new node has no children, sum 0 and no pending multiplier.
    Node() : sum(0), mul(1), l(NULL), r(NULL) {}

    // Recomputes sum from the two children; a missing child counts as 0.
    inline void Update() {
        const long long leftSum = Sum(l);
        const long long rightSum = Sum(r);
        sum = Add(leftSum, rightSum);
    }
};
// Statically allocated pool from which segment-tree nodes are handed out.
Node nd[N * 40];
// Index of the most recently handed-out pool slot; slots are taken with
// ++top, so nd[0] is never used.
int top;
struct Segtree {
Node *_root[N];
inline void Pushdown(Node *p, int pl, int pr) {
if (p->l) {
p->l->sum = p->l->sum * p->mul % mod;
p->l->mul = p->l->mul * p->mul % mod;
}
if (p->r) {
p->r->sum = p->r->sum * p->mul % mod;
p->r->mul = p->r->mul * p->mul % mod;
}
p->mul = 1;
}
inline void Modify(Node *&p, int pl, int pr, int idx, int v) {
if (!p) p = &nd[++top];
if (pl == idx && pr == idx) {
p->sum = v;
return;
}
Pushdown(p, pl, pr);
int mid = pl + pr >> 1;
if (idx <= mid) Modify(p->l, pl, mid, idx, v);
else Modify(p->r, mid + 1, pr, idx, v);
p->Update();
}
inline long long Query(Node *p, int pl, int pr, int l, int r) {
if (!p) return 0;
if (pl == l && pr == r) return Sum(p);
Pushdown(p, pl, pr);
int mid = pl + pr >> 1;
if (mid >= r) return Query(p->l, pl, mid, l, r);
else if (mid + 1 <= l) return Query(p->r, mid + 1, pr, l, r);
else return Add(Query(p->l, pl, mid, l, mid), Query(p->r, mid + 1, pr, mid + 1, r));
}
inline Node* Merge(Node *p1, Node *p2, int pl, int pr, long long pu, long long ps1, long long ss1, long long ps2, long long ss2) {
long long lmul = Add(pu * ps2 % mod, Subt(1, pu) * ss2 % mod);
long long rmul = Add(pu * ps1 % mod, Subt(1, pu) * ss1 % mod);
if (!p1 && !p2) return NULL;
if (!p2) {
//printf("p1 %d %d %lld\n", pl, pr, lmul);
p1->mul = p1->mul * lmul % mod;
p1->sum = p1->sum * lmul % mod;
return p1;
}
if (!p1) {
//printf("p2 %d %d %lld\n", pl, pr, rmul);
p2->mul = p2->mul * rmul % mod;
p2->sum = p2->sum * rmul % mod;
return p2;
}
Pushdown(p1, pl, pr); Pushdown(p2, pl, pr);
int mid = pl + pr >> 1;
long long s1l = Sum(p1->l), s1r = Sum(p1->r), s2l = Sum(p2->l), s2r = Sum(p2->r);
p1->l = Merge(p1->l, p2->l, pl, mid, pu, ps1, Add(ss1, s1r), ps2, Add(ss2, s2r));
p1->r = Merge(p1->r, p2->r, mid + 1, pr, pu, Add(ps1, s1l), ss1, Add(ps2, s2l), ss2);
p1->Update();
return p1;
}
};
// The single Segtree instance; Dfs() builds the per-vertex trees in it and
// main() queries the tree of vertex 1.
Segtree sgt;
inline void AddEdge(int u, int v) {
e[++pnt].to = v;
e[pnt].nxt = hd[u];
hd[u] = pnt;
scnt[u]++;
}
// Reads the whole input: n, then the parent of each vertex 1..n (an edge
// parent -> vertex is added for each), then one value per vertex.
// For a vertex without children the value is stored in p[] and appended to
// b[]; for every other vertex p[] receives value / 10000 modulo mod.
inline void Read() {
    n = qread();
    for (int v = 1; v <= n; ++v) {
        AddEdge(qread(), v);
    }

    const long long inv1e4 = Power(10000, mod - 2);
    for (int v = 1; v <= n; ++v) {
        const int x = qread();
        const bool childless = (hd[v] == -1);
        if (childless) {
            b[++bcnt] = x;
            p[v] = x;
        } else {
            p[v] = x * inv1e4 % mod;
        }
    }
}
// Coordinate compression: sorts and deduplicates b[1..bcnt], then replaces
// p[v] of every childless vertex by the 1-based position of its value in b.
inline void Prefix() {
    long long *first = b + 1;
    sort(first, first + bcnt);
    long long *last = unique(first, first + bcnt);
    bcnt = last - first;

    for (int v = 1; v <= n; ++v) {
        if (hd[v] != -1) continue;  // only vertices without children hold a value
        p[v] = lower_bound(first, last, p[v]) - first + 1;
    }
}
// Builds sgt._root[u] recursively.
//  - no children : a tree holding the value 1 at position p[u];
//  - one child   : the child's tree is reused unchanged;
//  - otherwise   : the trees of the first two children in u's adjacency list
//                  are merged with weight p[u].
inline void Dfs(int u) {
    const int firstEdge = hd[u];
    if (firstEdge == -1) {
        sgt.Modify(sgt._root[u], 1, bcnt, p[u], 1);
        return;
    }

    const int ls = e[firstEdge].to;
    Dfs(ls);
    if (scnt[u] == 1) {
        sgt._root[u] = sgt._root[ls];
        return;
    }

    const int rs = e[e[firstEdge].nxt].to;
    Dfs(rs);
    // All prefix/suffix accumulators start at 0 at the root of the merge.
    sgt._root[u] = sgt.Merge(sgt._root[ls], sgt._root[rs], 1, bcnt, p[u], 0, 0, 0, 0);
}
// Entry point: reads the tree, compresses the leaf values, runs the DP from
// vertex 1 and prints sum over i of i * b[i] * val_i^2 (mod `mod`), where
// val_i is the value stored at position i of vertex 1's tree.
int main() {
    // -1 in every byte makes every hd[] entry -1, i.e. "no outgoing edge".
    memset(hd, -1, sizeof(hd));
    Read();
    Prefix();
    Dfs(1);
    long long ans = 0;
    for (int i = 1; i <= bcnt; i++) {
        const long long val = sgt.Query(sgt._root[1], 1, bcnt, i, i);
        ans = (ans + 1ll * i * b[i] % mod * val % mod * val % mod) % mod;
    }
    printf("%lld", ans);
    // The former `while (1);` guarded by `#ifndef ONLINE_JUDGE` was removed:
    // a side-effect-free infinite loop is undefined behaviour and made every
    // local run hang instead of terminating.
    return 0;
}