题解:P4927 [1007] 梦美与线段树

· · 题解

形式化题面:

维护一个长度为 n 的序列,维护区间加,每次询问:从线段树根节点出发,按给定的权重概率游走,求经过所有节点权值和的期望。

先推一下式子。

E_u 表示从节点 u 出发开始,走过所有点的权值和的期望。

从点 u 出发,首先肯定会经过点 u 自己,再以 \frac{sum_{ls}}{sum_{u}} 的概率进入左子,\frac{sum_{rs}}{sum_{u}} 的概率进入右子。

得出递推公式:E_u = sum_u + \frac{sum_{ls}}{sum_{u}} E_{ls} + \frac{sum_{rs}}{sum_{u}} E_{rs}

对于没有儿子的叶子节点,直接就停了:E_u = sum_u

将递推公式两边同时乘上 sum_usum_u E_u = sum_u ^ 2 + sum_{ls} E_{ls} + sum_{rs} E_{rs}

为了简化这个式子,设 F_u = sum_u E_u

带入得到 F_u = sum_u ^ 2 + F_{ls} + F_{rs}

用线段树维护一下就可以了。

就这样一层层递归,得到根节点 F_{root} = \sum sum_u^2

将这个值设为 sumSq = F_{root},则我们要的 E_{root} = \frac{sumSq}{sum_{root}}

注意到其实我们只用维护 sumSqF 其实不用每一个都要,这应该就是这道题巧妙的地方吧。

最后再求一个逆元,搞一个最大公因数就行了。

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

#define int __int128
const int MAXN = 1e5 + 5;
const int MOD = 998244353;

int n, m;
int a[MAXN];
int S; //根节点权值(原序列总和)
int sum_sq; //线段树所有节点的平方和

struct Node {
    int val;//节点区间和
    int tag;//懒标记
    int sq;//节点区间长度的平方
    int sv; //len * val
    int len;//节点区间长度
} tr[MAXN << 2];

inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

inline void print(int x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}

int qpow(int a, int b) {//快速幂求逆元
    int res = 1;
    a %= MOD;
    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

int gcd(int a, int b) {
    return b ? gcd(b, a % b) : a;
}

void pushup(int u, int l, int r) { //合并
    int ls = u << 1, rs = u << 1 | 1;
    sum_sq = (sum_sq - tr[u].val * tr[u].val);

    tr[u].val = (tr[ls].val + tr[rs].val);
    tr[u].sv = (tr[ls].sv + tr[rs].sv + (r-l+1) * tr[u].val);
    tr[u].sq = (tr[ls].sq + tr[rs].sq + (r-l+1)*(r-l+1));
    tr[u].len = (tr[ls].len + tr[rs].len);

    sum_sq = (sum_sq + tr[u].val * tr[u].val);//加上新的值
}

void put_tag(int u, int l, int r, int k, bool op) { //下放懒标记
// op=1:需要更新全局平方和;op=0:仅更新节点(下传时)
    if (k == 0) return;

    if(op) {
        int delta = 2 * k * tr[u].sv;
        int add = k * k * tr[u].sq;
        sum_sq = (sum_sq + delta + add);
    }

    tr[u].tag = (tr[u].tag + k);
    tr[u].val = (tr[u].val + (r-l+1) * k);
    tr[u].sv = (tr[u].sv + tr[u].sq * k);
}

void pushdown(int u, int l, int mid, int r) {
    if (tr[u].tag == 0) return;
    put_tag(u<<1, l, mid, tr[u].tag, 0);
    put_tag(u<<1|1, mid+1, r, tr[u].tag, 0);
    tr[u].tag = 0;
}

void build(int u, int l, int r) {
    tr[u].len = r - l + 1;
    tr[u].sq = tr[u].len * tr[u].len;
    tr[u].tag = 0;
    if (l == r) {
        tr[u].val = a[l];
        tr[u].sv = tr[u].len * tr[u].val;
        sum_sq = (sum_sq + tr[u].val * tr[u].val);
        S = (S + tr[u].val);
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u, l, r);
}

void change(int u, int l, int r, int ul, int ur, int k) {
    if (ul <= l && r <= ur) {
        put_tag(u, l, r, k, 1);
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(u, l, mid, r);
    if (ul <= mid) change(u << 1, l, mid, ul, ur, k);
    if (ur > mid) change(u << 1 | 1, mid + 1, r, ul, ur, k);
    pushup(u, l, r);
}

signed main() {
    n = read(); m = read();
    S = 0;
    sum_sq = 0;
    for (int i = 1; i <= n; i++) a[i] = read();
    build(1, 1, n);

    while (m--) {
        int op = read();
        if (op == 2) {
            int p = sum_sq;
            int q = S;
            int g = gcd(p, q);
            p /= g; q /= g;
            //答案:sum_sq / S
            int inv_q = qpow(q, MOD - 2);
            int ans = p % MOD * inv_q % MOD;
            ans = (ans + MOD) % MOD;
            print(ans);
            putchar('\n');
        } else {
            int l = read(), r = read(), v = read();
            change(1, 1, n, l, r, v);
            S = (S + v * (r - l + 1)); //更新根节点的权值
        }
    }
    return 0;
}