篮球比赛题解

· · 题解

出题人题解。

正解

定义 dp_i 为:第 i 场比赛胜利的概率,则可得到如下代码(暴力):

int sum = 0;
for (int i = 1; i <= n; i++) {
    for (int j = 0; j <= k; j++)
        f[i] = (f[i] + aa[j] * qpow(i, j) % mod) % mod;
    sum = (sum + f[i]) % mod;
}
for (int i = 1; i <= n; i++)
    dp[i] = f[i] * qpow(sum, mod - 2) % mod;
for (int i = 1; i <= n; i++)
    for (int j = 1; j <= m; j++)
        dp[i + j] = (dp[i + j] + dp[i] * p[j] % mod) % mod;
sum = 0;
for (int i = 1; i <= n; i++)
    sum = (sum + dp[i]) % mod;
cout << sum;

其中 k(i) 即为 \dfrac{f(i)}{\sum_{j=1}^n f(j)}

最终期望胜利场数即为 \sum_{i=1}^ndp_i

由于这种转移时间复杂度很慢,所以考虑矩阵乘法。

首先,求出 \sum_{i=1}^n f(i)。定义 sum(i)=\sum_{j=1}^i f(j),从而构造初始矩阵:

\begin{pmatrix} sum_{i-1} & 1 & i & i^2 & \dots & i^k \\ \end{pmatrix}

接着根据二项式展开构造 base 矩阵:

\begin{pmatrix} 1 & 0 & 0 & 0 &\dots\\ a_0 & C_0^0 & C_1^0 & C_2^0 & \dots\\ a_1 & 0 & C_1^1 & C_2^1 & \dots\\ a_2 & 0 & 0 & C_2^2 & \dots\\ a_3 & 0 & 0 & 0 & \dots\\ \vdots\\ \end{pmatrix}

两个矩阵相乘,可得出:

\begin{pmatrix} sum_i & 1 & {i+1} & (i+1)^2 & \dots & (i+1)^k \\ \end{pmatrix}

满足矩阵乘法。由于要用到分数取模形式,所以设 inv\sum_{i=1}^n f(i) 的逆元。

接着求出 dp_i,设 s_i=\sum_{j=1}^{i-1}dp_j,由于 dp_i=\sum_{j=1}^m dp_{i-j}\times p_j(当 i-j 小于 0 时,dp_{i-j} 可视为 0),所以可以得出初始矩阵:

\begin{pmatrix} s_i & dp_i & dp_{i-1} & dp_{i-2} & \dots & dp_{i-m+1} & 1 & i & i^2 & \dots i^k\\ \end{pmatrix}

接着构造 base 矩阵:

\begin{pmatrix} 1 & 0 & 0 & 0 & 0 & \dots & 0 & 0 & 0 & 0 & 0 & \dots\\ 1 & p_1 & 1 & 0 & 0 & \dots & 0 & 0 & 0 & 0 & 0 & \dots\\ 0 & p_2 & 0 & 1 & 0 & \dots & 0 & 0 & 0 & 0 & 0 & \dots\\ 0 &p_3 & 0 & 0 & 1 & \dots & 0 & 0 & 0 & 0 & 0 & \dots\\ \vdots\\ 0 & p_{m-1} & 0 & 0 & 0 & \dots &1 & 0 & 0 & 0 & 0 & \dots\\ 0 & p_m & 0 & 0 & 0 & \dots & 0 & 0 & 0 & 0 & 0 & \dots\\ 0 & inv\times a_0 & 0 & 0 & 0 &\dots & 0 & C_0^0 & C_1^0 & C_2^0 & C_3^0 & \dots\\ 0 & inv\times a_1 & 0 & 0 & 0 &\dots & 0 & 0& C_1^1 & C_2^1 & C_3^1 & \dots\\ 0 & inv\times a_2 & 0 & 0 & 0 &\dots & 0 & 0 & 0 & C_2^2 & C_3^2 & \dots\\ 0 & inv\times a_3 & 0 & 0 & 0 &\dots & 0 & 0 & 0 & 0 & C_3^3 & \dots\\ \vdots\\ \end{pmatrix}

两个矩阵相乘得到:

\begin{pmatrix} s_{i+1} & dp_{i+1} & dp_i & dp_{i-1} & dp_{i-2} & \dots & dp_{i-m+2} & 1 & i+1 & (i+1)^2 & \dots & (i+1)^k\\ \end{pmatrix}

Code

by 2023gdgz01
#include <cstdio>
#include <cstring>

long long n, m, k, inv, maxmk, p[55], A[55], c[55][55];

struct matrix {
    long long c[105][105];
    inline matrix operator* (const matrix &r) {
        matrix temp;
        memset(temp.c, 0, sizeof(temp.c));
        for (register int i = 1; i <= maxmk; ++i)
            for (register int j = 1; j <= maxmk; ++j)
                for (register int y = 1; y <= maxmk; ++y)
                    temp.c[i][j] = (temp.c[i][j] + c[i][y] * r.c[y][j] % 998244353) % 998244353;
        return temp;
    }
};

matrix a, ans;

inline int max(int x, int y) {
    if (x > y)
        return x;
    return y;
}

inline void matrix_quick_power(matrix a, long long b) {
    while (b) {
        if (b & 1)
            ans = ans * a;
        a = a * a;
        b >>= 1;
    }
}

inline long long quick_power(long long a, long long b) {
    long long ans = 1;
    while (b) {
        if (b & 1)
            ans = ans * a % 998244353;
        a = a * a % 998244353;
        b >>= 1;
    }
    return ans % 998244353;
}

int main() {
    scanf("%lld%lld%lld", &n, &m, &k);
    for (register int i = 1; i <= m; ++i)
        scanf("%lld", p + i);
    for (register int i = 0; i <= k; ++i)
        scanf("%lld", A + i);
    maxmk = max(m, k) + 1 << 1;
    for (register int i = 0; i <= k; ++i) {
        c[i][0] = 1;
        for (register int j = 1; j <= i; ++j)
            c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % 998244353;
    }
    for (register int i = 2; i <= k + 2; ++i)
        ans.c[1][i] = 1;
    a.c[1][1] = 1;
    for (register int i = 2; i <= k + 2; ++i) {
        a.c[i][1] = A[i - 2];
        for (register int j = i; j <= k + 2; ++j)
            a.c[i][j] = c[j - 2][i - 2];
    }
    matrix_quick_power(a, n);
    inv = quick_power(ans.c[1][1], 998244351);
    memset(ans.c, 0, sizeof(ans.c));
    memset(a.c, 0, sizeof(a.c));
    for (register int i = m + 2; i <= m + k + 2; ++i)
        ans.c[1][i] = 1;
    a.c[1][1] = a.c[2][1] = 1;
    for (register int i = 2; i <= m + 1; ++i) {
        a.c[i][2] = p[i - 1];
        if (i != m + 1)
            a.c[i][i + 1] = 1;
    }
    for (register int i = m + 2; i <= m + k + 2; ++i) {
        a.c[i][2] = inv * A[i - m - 2] % 998244353;
        for (register int j = i; j <= m + k + 2; ++j)
            a.c[i][j] = c[j - m - 2][i - m - 2];
    }
    matrix_quick_power(a, n + 1);
    printf("%lld", ans.c[1][1]);
    return 0;
}