题解:CF850F Rainbow Balls

· · 题解

题面

现有 n 种颜色,第 i 种颜色有 a_i 种,每次按顺序选出 2 种颜色,将第 2 种颜色变成第 1 种颜色。当所有的颜色相同时,结束操作。求出所需操作的期望值。

推理

选定一个颜色 x 为最后的颜色,设 f_i 表示有 i 个该颜色的球时到达全为 x 颜色的状态所需要的期望步数。则有:

p 为现在颜色为 x 的球有 i 个时,选出两个球,一个球为 x 而另二个球不是 x 的概率,设 m 为总球数,易得:

\begin{equation}p = \frac{i \times (m-i)}{m \times (m - 1)}\end{equation}

那么,f_i 的转移只有三种情况:

除此之外,还需加上此步对答案的贡献,即 在当前局面下 x 成为留到最后颜色的概率,此处设为 k

可得转移式:

\begin{equation} f_i = p \times f_{i - 1} + p \times f_{i + 1} + (1 - 2p) \times f_i + k \end{equation}

k_i 为现在有 i 个颜色为 x 的球时,x 成为最后唯一颜色的概率, 则有:

\begin{aligned} k_i = p \times k_{i - 1} + p \times k_{i + 1} + (1 - 2p) \times k_i \end{aligned}

分析方式与 f_i 同理,继续整理可得:

\begin{equation}k_i - k_{i - 1} = k_{i + 1} - k_i\end{equation}

同时,k_0 = 0k_m = 1(因为在颜色数为 0 时不可能继续变大,颜色数为 m 时已经达成状态)。

\begin{aligned} \therefore k_m - k_0 = \sum_{i=1}^{m}k_i - k_{i - 1} \end{aligned}

代入 (3) 得:

k_m - k_0 &= m \times (k_m - k_{m - 1}) \\ m \times k_{m - 1} &= (m - 1) \times k_m + k_0 \\ k_{m - 1} &= \frac{(m - 1) \times k_m + k_0}{m} \\ \end{aligned}

代入 \because k_m = 1k_0 = 0 得:

\begin{aligned} k_{m - 1} = \frac{m - 1}{m} \end{aligned}

代入 (3) 依次得k_{m - 2} = \frac{m - 2}{m}k_{m - 2} = \frac{m - 3}{m}

也就是:

\begin{aligned} k_i = \frac{i}{m} \end{aligned}

代入 (2) 得:

f_i &= p \times f_{i - 1} + p \times f_{i + 1} + (1 - 2p) \times f_i + \frac{i}{m} \\ p \times f_i - p \times f_{i - 1} &= p \times f_{i + 1} - p \times f_i + \frac{i}{m} \\ \end{aligned}

代入 (1),并消元:

f_i - f_{i - 1} = f_{i + 1} - f_i + \frac{m - 1}{m - i} \end{equation}

又可得:f_0 不存在,f_m = 0。 所以,f_2 = 2f_1 - 1,则:

f_1 &= f_1 - f_s \\ f_1 &= \sum_{i = 1}^{s - 1} f_{i + 1} - f_{i} \\ f_1 &= (s - 1) \times (f_1 - f_2) + (s - 2) \times (s - 1) \\ f_1 &= (s - 1) \times (1 - f_1) + (s - 2) \times (s - 1) \\ s \times f_1 &= (s - 1) ^ 2 \\ f_1 &= \frac{(s - 1)^2}{s} \\ \end{aligned}

接着由 f_2 =2f_1 - 1 即可推出 f_2,再由 (4) 即可推出任意的 f_i

最终的答案就是 \sum_{i = 1}^{n}f_{a_i}(由于 有限个随机变量之和的数学期望等于每个随机变量的数学期望之和,答案即每种颜色成为最终颜色的期望步数之和)。

code

#include <bits/stdc++.h>
#define PII pair <int, int>
#define int long long
#define ST string
#define DB double

#define fr(x, y, z) for(int x = y; x <= z; x ++ )
#define dfr(x, y, z) for(int x = y; x >= z; x -- )

using namespace std;

const int N = 100010, MOD = 1e9 + 7;
int n, s, mx, a[N], f[N];

int qp(int x, int y)
{
    int res = 1;
    while(y)
    {
        if(y & 1) res = res * x % MOD;
        x = x * x % MOD; y >>= 1;
    }
    return res;
}

int add(int x, int y)
{ return (x % MOD + y % MOD) % MOD; }

int mul(int x, int y)
{ return x % MOD * y % MOD; }

int ovr(int x, int y)
{ return x % MOD * qp(y, MOD - 2) % MOD; }

int sub(int x, int y)
{ return (x % MOD + MOD - y % MOD) % MOD; }

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    fr(i, 1, n) cin >> a[i];
    fr(i, 1, n) s = (s + a[i]) % MOD;
    fr(i, 1, n) mx = max(mx, a[i]);

    f[1] = ovr(mul(s - 1, s - 1), s);
    f[2] = sub(f[1] * 2, 1);
    fr(i, 2, mx) f[i + 1] = sub(f[i] * 2, add(f[i - 1], ovr(s - 1, s - i)));

    int res = 0;
    fr(i, 1, n) res = add(res, f[a[i]]);
    cout << res << '\n';

    return 0;
}

/*
   2f[1] = f[2] + 1
   f[1] = (s - 1)(f[1] - f[2]) + (s - 2)(s - 1)
-> f[1] = (s - 1) ^ 2 / s

*/