P12251 [科大国创杯初中组 2025] 抽卡

· · 题解

花了 40mins 过掉这个题。这个题给初中生做真的大丈夫???

首先瞪一下对一个序列怎么求答案,你发现就是维护一个小根堆,每次把一个数插入两次,然后弹掉堆顶。最后堆里所有元素的和就是答案。

我们先把 a_i 做个前缀和,这样比较方便描述。然后我们算方案数,最后除掉 \prod a_i 即可。

那么套路地,枚举一个阈值 w,转 01,设 f(i,j) 代表已经考虑前 i 次解锁,目前堆里有 j1 的方案数。这样就做到了 O(n^2m)。不难观察到这是一个以 a_i 为分段点的 O(n) 次的分段多项式,容易拉插做到 O(n^4)

但是然后怎么办?这个做法看起来并没有什么优化前途。我们不得不转换视角:我们发现,对于 w > a_i 的部分,我们的转移非常简单:就是对应项乘上 a_i。而满足 w > a_i 的部分是一个前缀——这给了我们一个启示:把 DP 倒过来。

把 DP 的过程看成 DAG 上的游走后,我们当然可以把整个 DAG 的边全部反过来,这样,我们就会先走若干 w \le a_i 的边,然后走若干 w > a_i 的边。前面这一部分不再是一个分段多项式:它整个就是个多项式。而后面这部分的系数我们可以在 DP 外乘上去。这样,我们可以写出如下暴力算法:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int P = 1e9 + 7;
int n, m, a[505];
ll ans, f[505][505];
void F(int w) {
    memset(f, 0, sizeof(f));
    for (int i = 0; i <= n; i++) f[n][i] = i;
    for (int i = n; i; i--) {
        for (int j = 0; j <= i; j++) f[i][j] %= P;
        for (int j = 0; j < i; j++) {
            f[i - 1][j] += f[i][j] * (w - 1);
            f[i - 1][j] += f[i][j + 2 - (j == i - 1)] * (a[i] - w + 1);
        }
    }
}
ll QPow(ll a, ll b) {
    ll res = 1;
    for (; b; b >>= 1, a = a * a % P) if (b & 1) res = res * a % P;
    return res;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", a + i), a[i] += a[i - 1];
    for (int i = 1; i <= n; i++) {
        for (int j = a[i - 1] + 1; j <= a[i]; j++) {
            F(j);
            ll t = f[i - 1][0];
            for (int k = 1; k < i; k++) t = t * a[k] % P;
            ans += t;
        }
    }
    ans %= P;
    for (int i = 1; i <= n; i++) ans = ans * QPow(a[i], P - 2) % P;
    printf("%lld\n", ans);
    return 0;
}

接下来的部分就顺理成章了。我们对这个 DP 的前缀和进行插值即可。注意前缀和是 (n+1) 次而不是 n 次。复杂度 O(n^3)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 501, P = 1e9 + 7;
int n, m, a[505];
ll ans, f[505][505], g[505][505];
void F(int w) {
    memset(f, 0, sizeof(f));
    for (int i = 0; i <= n; i++) f[n][i] = i;
    for (int i = n; i; i--) {
        for (int j = 0; j <= i; j++) f[i][j] %= P;
        for (int j = 0; j < i; j++) {
            f[i - 1][j] += f[i][j] * (w - 1);
            f[i - 1][j] += f[i][j + 2 - (j == i - 1)] * (a[i] - w + 1);
        }
    }
    for (int i = 0; i <= n; i++) g[i][w] = f[i][0] % P;
}
ll QPow(ll a, ll b) {
    ll res = 1;
    for (; b; b >>= 1, a = a * a % P) if (b & 1) res = res * a % P;
    return res;
}
ll Calc(int u, int x) {
    ll res = 0;
    for (int i = 0; i <= N; i++) {
        ll a = 1, b = 1;
        for (int j = 0; j <= N; j++) {
            if (i == j) continue;
            a = a * (x - j) % P;
            b = b * (i - j) % P;
        }
        res += g[u][i] * a % P * QPow(b, P - 2) % P;
    }
    return (res % P + P) % P;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", a + i), a[i] += a[i - 1];
    for (int i = 0; i <= N; i++) F(i);
    for (int i = 0; i <= n; i++) {
        for (int j = 1; j <= N; j++) (g[i][j] += g[i][j - 1]) %= P;
    }
    for (int i = 1; i <= n; i++) {
        ll t = Calc(i - 1, a[i]) - Calc(i - 1, a[i - 1]) + P;
        for (int k = 1; k < i; k++) t = t * a[k] % P;
        ans += t;
    }
    ans %= P;
    for (int i = 1; i <= n; i++) ans = ans * QPow(a[i], P - 2) % P;
    printf("%lld\n", ans);
    return 0;
}