题解 CF1172C2 【Nauuo and Pictures (hard version)】

· · 题解

裸dp

先只看一个“被喜欢的”位置,这个位置的初始值是 w

SA 为“被喜欢的”数之和,SB 为“不被喜欢的”数之和。

f_w[i][j][k] 表示:现在 SA=jSB=k,一个值为 w 、“被喜欢的”位置经过 i 次操作后的期望值。

边界情况:f_w[0][j][k]=w

转移:

  1. 下一次操作选到了当前这个位置。概率:\frac w{j+k}。转移到:f_{w+1}[i-1][j+1][k]
  2. 下一次操作选到了另一个“被喜欢的”位置。概率:\frac{j-w}{j+k}。转移到:f_w[i-1][j+1][k]
  3. 下一次操作选到了一个“不被喜欢的”位置。概率:\frac k{j+k}。转移到:f_w[i-1][j][k-1]

所以,f_w[i][j][k]=\frac w{j+k}f_{w+1}[i-1][j+1][k]+\frac{j-w}{j+k}f_w[i-1][j+1][k]+\frac k{j+k}f_w[i-1][j][k-1]

g_w[i][j][k] 表示“不被喜欢的”的对应状态,计算方式类似。

这样大约能过简单版。

优化

有两个优化:

  1. f_w[i][j][k]=wf_1[i][j][k]

    证明:

    假设已经证明了 $f_w[i-1][j][k]=wf_1[i-1][j][k]$,就可以归纳地证明 $f_w[i][j][k]=wf_1[i][j][k]$: $\begin{aligned}f_1[i][j][k]&=\frac 1{j+k}f_2[i-1][j+1][k]+\frac{j-1}{j+k}f_1[i-1][j+1][k]+\frac k{j+k}f_1[i-1][j][k-1]\\&=\frac2{j+k}f_1[i-1][j+1][k]+\frac{j-1}{j+k}f_1[i-1][j+1][k]+\frac k{j+k}f_1[i-1][j][k-1]\\&=\frac{j+1}{j+k}f_1[i-1][j+1][k]+\frac k{j+k}f_1[i-1][j][k-1]\end{aligned} \begin{aligned}f_w[i][j][k]&=\frac w{j+k}f_{w+1}[i-1][j+1][k]+\frac{j-w}{j+k}f_w[i-1][j+1][k]+\frac k{j+k}f_w[i-1][j][k-1]\\&=\frac{w(w+1)}{j+k}f_1[i-1][j+1][k]+\frac{w(j-w)}{j+k}f_1[i-1][j+1][k]+\frac {wk}{j+k}f_1[i-1][j][k-1]\\&=\frac{w(j+1)}{j+k}f_1[i-1][j+1][k]+\frac {wk}{j+k}f_1[i-1][j][k-1]\\&=wf_1[i][j][k]\end{aligned}

    还有一个比较简单但不那么严谨的理解方式:每一步期望的增量都与期望成正比。(这里被 _rqy 喷了,出题人就是菜,这个证明写不严谨。)

    这样的话就只用计算 f_1[i][j][k] 了。

  2. 注意到 i,\,j,\,k,\,m 有一些联系。实际上可以令 f'_w[i][j] 表示 f_w[m-i-j][SA+i][SB-j](这里的 SASB 都是未操作时的初始值)。

g'_1[i][j] 表示 g_w[m-i-j][SA+i][SB-j],计算方式类似。

总结

f'_1[i][j]=1\ (i+j=m) f'_1[i][j]=\frac{SA+i+1}{SA+SB+i-j}f'_1[i+1][j]+\frac{SB-j}{SA+SB+i-j}f'_1[i][j+1]\ (i+j<m) g'_1[i][j]=1\ (i+j=m) g'_1[i][j]=\frac{SA+i}{SA+SB+i-j}g'_1[i+1][j]+\frac{SB-j-1}{SA+SB+i-j}g'_1[i][j+1]\ (i+j<m)

“被喜欢的”位置答案是 w_if'_1[0][0],“不被喜欢的”位置答案是 w_ig'_1[0][0]

如果每次去算逆元就是 \mathcal O(n+m^2\log p),预处理出来就是 \mathcal O(n+m^2+m\log p)

#include <cstdio>
#include <algorithm>

using namespace std;

typedef long long ll;

const int N = 200010;
const int M = 3010;
const int mod = 998244353;

int qpow(int x, int y) //calculate the modular multiplicative inverse
{
    int out = 1;
    while (y)
    {
        if (y & 1) out = (ll) out * x % mod;
        x = (ll) x * x % mod;
        y >>= 1;
    }
    return out;
}

int n, m, a[N], w[N], f[M][M], g[M][M], inv[M << 1], sum[3];

int main()
{
    int i,j;

    scanf("%d%d", &n, &m);

    for (i = 1; i <= n; ++i) scanf("%d", a + i);

    for (i = 1; i <= n; ++i)
    {
        scanf("%d", w + i);
        sum[a[i]] += w[i];
        sum[2] += w[i];
    }

    for (i = max(0, m - sum[0]); i <= 2 * m; ++i) inv[i] = qpow(sum[2] + i - m, mod - 2);

    for (i = m; i >= 0; --i)
    {
        f[i][m - i] = g[i][m - i] = 1;
        for (j = min(m - i - 1, sum[0]); j >= 0; --j)
        {
            f[i][j] = ((ll) (sum[1] + i + 1) * f[i + 1][j] + (ll) (sum[0] - j) * f[i][j + 1]) % mod * inv[i - j + m] % mod;
            g[i][j] = ((ll) (sum[1] + i) * g[i + 1][j] + (ll) (sum[0] - j - 1) * g[i][j + 1]) % mod * inv[i - j + m] % mod;
        }
    }

    for (i = 1; i <= n; ++i) printf("%d\n", int((ll) w[i] * (a[i] ? f[0][0] : g[0][0]) % mod));

    return 0;
}