题解 P6834 【[Cnoi2020]梦原】

· · 题解

这个题目可以被划分为两部分来做:树形态的概率和树形态确定后的贡献。

容易发现,当树的形态确定后,问题就是 积木大赛 / 铺设道路 放到树上的版本。

和上述两个问题同样,此时的答案为 \sum \max \{a_x - a_{fa},0\}, x \in V

这样,本题的第二部分就解决了。

下面考虑本题的第一部分,树形态的概率。

观察上面第二部分的答案式,发现答案只与所有的父子关系有关,所以我们只需要考虑每个结点的父亲的概率,并不需要关心整棵树的样子。

根据题意,对于结点 i, 其父节点是在 [\max \{1, i - k\},i - 1] 随机选择的点。

那么当结点 i 的父节点为 x 时,对答案期望的贡献E(i \oplus x)=\dfrac{\max \{a_x - a_{x},0\}}{i - \max \{1, i - k\}}

答案为 \sum\limits_{i=1}^n {\sum\limits_{x=\max \{1, i - k\}}^{i-1}}{E(i \oplus x)}

算法一

可以通过暴力枚举每个结点及其可能的父节点,暴力统计答案,时间复杂度 O(n^2),可以获得 \mathrm{40pts}

(代码中的注释部分)

算法二

对于每一个结点 x 可以贡献的期望答案,实际上为其所有可能父节点 f 中,f 权值比 x 大时,获得 a_x-a_f 的答案,再乘上选择这个父节点的概率。

发现算法一的时间复杂度瓶颈在于枚举所有父节点,考虑用树状数组来优化这一过程。

用两棵树状数组,一棵统计a_f \le a_x 的父节点数量,一棵统计 \forall f,满足 a_f \le a_x 的权值和。

使用这两棵树状数组维护的信息替代算法一中的枚举父节点,即可获得满分。

Code

#include<bits/stdc++.h>
using namespace std;

#define int long long

template < typename Tp >
void read(Tp &x) {
    x = 0; int fh = 1; char ch = 1;
    while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if(ch == '-') fh = -1, ch = getchar();
    while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    x *= fh;
}

const int mod = 998244353;
const int maxn = 1000000 + 7;

int n, k;
int a[maxn], b[maxn];
int s[maxn]; 

int lowbit(int x) {
    return x & (-x);
}

struct BIT { 
    int c[maxn], N;
    void build(int x) {N = x;}
    void add(int pos, int k) {
        for(int i = pos; i <= N; i += lowbit(i)) c[i] = ((c[i] + k) % mod + mod) % mod;
    }
    int query(int x) {
        int res = 0;
        while(x) {
            res = (res + c[x]) % mod;
            x -= lowbit(x);
        }
        return res;
    }
}T, T1;

void Init(void) {
    read(n); read(k);
    for(int i = 1; i <= n; i++) {
        read(a[i]);
        s[i] = s[i - 1] + a[i];
    }
}

int fpow(int x, int p){
    int res = 1;
    while(p) {
        if(p & 1) res = res * x % mod;
        x = x * x % mod; p >>= 1;
    }
    return res;
}

void Work(void) {
    int ans = 0;
    for(int i = 1; i <= n; i++) b[i] = a[i];
    sort(b + 1, b + n + 1);
    int M; M = unique(b + 1, b + n + 1) - b - 1;
    T.build(M), T1.build(M);
    for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + M + 1, a[i]) - b;
    for(int i = 1; i <= n; i++) {
        T.add(a[i], b[a[i]]); T1.add(a[i], 1);
        int st;
        if(i - k - 1 >= 1) {
            T.add(a[i - k - 1], -b[a[i - k - 1]]);
            T1.add(a[i - k - 1], -1);
        }
        st=max(1ll, i - k);
        int p = i - st;
        int sum = T.query(a[i]), sour = T1.query(a[i]);
        sum = ((-sum + sour * b[a[i]] % mod) % mod + mod) % mod;
        sum = sum * fpow(p, mod - 2) % mod;
        ans = (ans + sum) % mod;
//      for(int j = st; j < i; j++) {
//          ans = (ans + (max(0ll, a[i] - a[j]) * fpow(p, mod - 2) % mod) % mod) % mod;
//      }
    }
    ans = (ans + b[a[1]]) % mod;
    printf("%lld\n", ans);
}

signed main(void) {
    Init();
    Work();
    return 0;
}