题解:AT_abc406_e [ABC406E] Popcount Sum 3

· · 题解

题解:AT_abc406_e [ABC406E] Popcount Sum 3

题意:求 \le N 中满足其恰好有 K 二进制位为 1 的所有数之和。

首先考虑一个特殊情况,若没有 \le N 的限制的答案是多少呢?我们记 F(i, j)i 个二进制位中选 j 个位填 1 的数的和,不难发现:

F(i, j) = (2 ^ i - 1) \times \dbinom{i - 1}{j - 1}

现实意义很简单,对于每个位 k,其如果选定为 1,则方案数为在 i - 1 位中选定 j - 1 位为 1 的方案数,其他情况不会产生贡献。因而单个位贡献为 2 ^ k \times \dbinom{i - 1}{j - 1}。因而提公因式后等比数列求和可得:

F(i, j) = \sum _ {k = 0} ^ {i - 1} {2 ^ k \times \dbinom{i - 1}{j - 1}} = (2 ^ i - 1) \times \dbinom{i - 1}{j - 1}

现在回到题目。加上了 \le N 的限制,其实本可以像数位 DP 那样做,但是似乎这种思路会好想一些?

我们记cnt(n, k)sum(n, k) 分别为 \le n 中满足 k 个二进制位为 1 的方案数、数之和。答案即为 sum(N, K)

考虑如何用已知的 F(n, k) 求出 sum(n, k)

首先,如果 n 的最高位 i 选了 1,后面的答案可以拆成第 i 位的贡献和第 1 位到第 i - 1 位的贡献,即 sum(n - 2 ^ i, k - 1) + 2^i \times cnt(n - 2 ^ i, k - 1),否则若不选 1,后面 i 位随便选,答案为 F(i, k)

cnt(n, k) 的方法同理可得 cnt(n, k) = cnt(n - 2 ^ i, k - 1) + \dbinom{i}{k}

实现方法有递归和递推,我觉得递归会好实现一些,cntsum 可以存一块。此外,由于不明原因,此代码过程量不开 __int128 无法 AC,欢迎 dalao 们指处错误。

#include <bits/stdc++.h>
using namespace std;
#define int unsigned long long int
const int mod = 998244353;
int qpow(int x, int y, const int mod) {
    int res = 1;
    for (; y; y >>= 1) {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
    } return res;
}
int fac[65], invfac[65];
void init() {
    fac[0] = 1; for (int i = 1; i <= 63; i++) fac[i] = fac[i - 1] * i % mod;
    for (int i = 0; i <= 63; i++) invfac[i] = qpow(fac[i], mod - 2, mod);
}
int comb(int n, int m) { return m > n ? 0 : (__int128)fac[n] * invfac[m] % mod * invfac[n - m] % mod; }
int chose(int i, int j) { return (__int128)((1ULL << i) - 1) * comb(i - 1, j - 1) % mod; }
struct node { int sum, cnt; };
node solve(int n, int k, int t) { // t 记录的是当前位数
    if (k > t + 1) return {0, 0};
    else if (n == 0) return {0, !k};
    if (n >> t) {
        node x = solve(n ^ (1ULL << t), k - 1, t - 1);
        return {(int)((chose(t, k) + x.sum) % mod + (__int128)x.cnt * (1ULL << t) % mod) % mod,
                (x.cnt + comb(t, k)) % mod};
    } else return solve(n, k, t - 1); // 还不是最高位
}
main() {
    ios::sync_with_stdio(false), cin.tie(0);
    init();
    int t, n, k;
    for (cin >> t; t; t--) {
        cin >> n >> k;
        int w = 63; while (!(n >> w & 1)) w--; // 求最高位
        cout << solve(n, k, w).sum % mod << '\n';
    }
    return 0;
}