题解:AT_abc352_g [ABC352G] Socks 3

· · 题解

首先不难发现:

ANS = \sum_{i = 0}^{n - 1} (i + 1) \times P(i) = \sum_{i = 0}^{n - 1} P(x \geq i)

其中:

P(x \geq k) = \frac{B_1 \times B_2 \times \dots \times B_k \times k!}{S \times (S - 1) \times \dots (S - k + 1)}

这里,S=\sum A_i

分母很好理解,就是总共的方案数,然后 B 就对应了我们选的颜色的分别的袜子的数量。

所以这个式子我们要想计算,就需要计算出原数组这 N 个元素中选 k 个求积的和。考虑 DP,dp_{i,j} 表示前 i 个元素中选了 j 个,那转移很简单:

dp_{i,j} = dp_{i - 1, j - 1} \times A_i + dp_{i - 1, j}

其中 dp_{0,0} 初始化为 1

但这样时间复杂度就炸了。观察发现我们 DP 的过程,其实我们就是在每次对原本的这个式子乘上 (A_i \times x + 1) 这个多项式,所以我们 DP 的过程其实就等价于求:

\prod_{i = 1}^{N} A_i

这个式子的各项系数,直接使用分治 NTT 即可解决,时间复杂度 \mathrm{O}(N \log^2 N)

代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 998244353;
const int G = 3;
int fastpow(int a, int b, int mod)
{
    int res = 1;
    while (b) {
        if (b & 1) {
            res = (res * a) % mod;
        }
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}
int rev[1 << 22];
void change(vector<int> &a, int len)
{
    for (int i = 0; i < len; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
    }
    for (int i = 0; i < len; i++) {
        if (i < rev[i]) {
            swap(a[i], a[rev[i]]);
        }
    }
}
void ntt(vector<int> &a, int len, int x)
{
    change(a, len);
    for (int h = 2; h <= len; h <<= 1) {
        int omega = fastpow(G, (mod - 1) / h, mod);
        if (x == -1) {
            omega = fastpow(omega, mod - 2, mod);
        }
        for (int i = 0; i < len; i += h) {
            int w = 1;
            for (int j = i; j < i + h / 2; j++) {
                int u = a[j];
                int v = a[j + h / 2] * w % mod;
                a[j] = (u + v) % mod;
                a[j + h / 2] = (u - v + mod) % mod;
                w = (w * omega) % mod;
            }
        }
    }
    if (x == -1) {
        int inv = fastpow(len, mod - 2, mod);
        for (int i = 0; i < len; i++) {  
            a[i] = (a[i] * inv) % mod;
        }
    }
}
vector<int> convo(vector<int> a, vector<int> b)
{
    if (a.empty() || b.empty()) {
        return {0};
    }
    int m = 1;
    while (m < a.size() + b.size() - 1) {  
        m <<= 1;
    }
    a.resize(m);
    b.resize(m);
    ntt(a, m, 1);
    ntt(b, m, 1);
    for (int i = 0; i < m; i++) {
        a[i] = (a[i] * b[i]) % mod;
    }
    ntt(a, m, -1);
    return a;
}
vector<int> solve(vector<int>& a, int l, int r)
{
    if (l > r) {
        return {1};  
    }
    if (l == r) {
        return {1, a[l]}; 
    }
    int mid = (l + r) / 2;
    return convo(solve(a, l, mid), solve(a, mid + 1, r));
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin >> n;
    vector<int> a(n);
    int cnt = 0;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        cnt += a[i];
    }
    vector<int> res = solve(a, 0, n - 1);
    int ans = 1, s1 = 1, s2 = 1;
    for (int i = 1; i <= n && i < res.size(); i++) {  
        s1 = (s1 * i) % mod;
        s2 = s2 * (cnt - i + 1) % mod;
        ans = (ans + (res[i] * s1) % mod * fastpow(s2, mod - 2, mod)) % mod;
    }
    cout << (ans + mod) % mod << '\n';
    return 0;
}

AC记录