题解:P6669 [清华集训 2016] 组合数问题

· · 题解

感觉自己很傻。

Solution

直接考虑 Lucas 定理,\binom{n}{m} \equiv \binom{\lfloor\frac{n}{p}\rfloor}{\lfloor\frac{m}{p}\rfloor}\binom{n\bmod p}{m \bmod p} \pmod p,那么我们就可以将 nm 转为 p 进制,然后数位 dp 即可。

套路的,我们设 f_{i, 0/1, 0/1, k} 表示前 i 位,n' 是否卡着限制,m' 是否卡着限制,\binom{n'}{m'}\bmod pk 的方案数。

转移是简单的,时间复杂度 O(Tk^3\log_k{n})

考虑一个很呆的优化:因为 p 为质数,所以只需要判断是否存在 \binom{n'}{m'} = 0 即可。所以我们直接将最后一维改为 0/1 表示是否 \text{}\bmod p0 即可。

时间复杂度 O(Tk^2\log_k{n})

:::success[AC Code]

#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define mp(Tx, Ty) make_pair(Tx, Ty)
#define For(Ti, Ta, Tb) for(auto Ti = (Ta); Ti <= (Tb); Ti++)
#define Dec(Ti, Ta, Tb) for(auto Ti = (Ta); Ti >= (Tb); Ti--)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define range(Tx) begin(Tx),end(Tx)
typedef unsigned long long ull;
const int N = 65, M = 105, mod = 1e9 + 7;
int k;
long long n, m;
ull f[2][2][2][2];
int C[M][M];
void Add(ull &x, ull y) {
    x = (x + y) % mod;
}
ull quickpow(ull a, ull b, int mod) {
    ull res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
ull work(long long n, long long m) {
    ull ans = 0;
    ull A = m % mod, B = max(0ll, m - n) % mod, N = (A - B + 1 + mod) % mod;
    ans = mod - (A + B) % mod * N % mod * quickpow(2, mod - 2, mod) % mod;
    vector<int> numn, numm;
    while (n) numn.push_back(n % k), n /= k;
    while (m) numm.push_back(m % k), m /= k;
    int len = max(numn.size(), numm.size());
    while (numn.size() < len) numn.push_back(0);
    while (numm.size() < len) numm.push_back(0);
    reverse(range(numn)), reverse(range(numm));
    int flag = 0;
    memset(f, 0, sizeof(f));
    f[0][1][1][0] = 1;
    For(i, 0, len - 1) {
        memset(f[flag ^ 1], 0, sizeof(f[flag ^ 1]));
        For(l, 0, 1) {
            For(a, 0, k - 1) {
                For(b, 0, k - 1) {
                    int V = (l || C[a][b] == 0);
                    if (a == numn[i] && b == numm[i]) Add(f[flag ^ 1][1][1][V], f[flag][1][1][l]);
                    if (a == numn[i]) Add(f[flag ^ 1][1][0][V], f[flag][1][0][l]);
                    if (a == numn[i] && b < numm[i]) Add(f[flag ^ 1][1][0][V], f[flag][1][1][l]);
                    if (b == numm[i]) Add(f[flag ^ 1][0][1][V], f[flag][0][1][l]);
                    if (b == numm[i] && a < numn[i]) Add(f[flag ^ 1][0][1][V], f[flag][1][1][l]);
                    if (a < numn[i] && b < numm[i]) Add(f[flag ^ 1][0][0][V], f[flag][1][1][l]);
                    if (a < numn[i]) Add(f[flag ^ 1][0][0][V], f[flag][1][0][l]);
                    if (b < numm[i]) Add(f[flag ^ 1][0][0][V], f[flag][0][1][l]);
                    Add(f[flag ^ 1][0][0][V], f[flag][0][0][l]);
                }
            }
        }
        flag ^= 1;
    }
    For(i, 0, 1) For(j, 0, 1) Add(ans, f[flag][i][j][1]);
    return ans;
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int T = 1;
    cin >> T >> k; 
    while (T--) {
        cin >> n >> m;
        For(i, 0, M - 1) {
            C[i][0] = 1;
            For(j, 1, i) C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % k;
        }
        cout << work(n, m) << '\n';
    }
    return 0;
}

:::