题解 P4370 【[Code+#4]组合数问题2】

· · 题解

题解 P4370 【[Code+#4]组合数问题2】

P4370 题面

首先我们结合杨辉三角:

                      1
                    1    1
                 1     2    1
              1     3    3    1
           1     4     6    4    1
                     ......

我们发现了一个规律(大家应该都是一眼发现吧): n \choose \lfloor \frac {n} {2} \rfloor 总是在一行中最大的。然后向两边递减。

我们想到了贪心!那就是最开始先加入所有 n \choose \lfloor \frac {n} {2} \rfloor 到一个优先队列中,然后当取出一个值的时候,向左(或向右)扩展一个数,加入进去。

不断重复 k 次,就是 \max (Ans) 了。

做完了。

做完了?

为什么呢?大家应该都有这样的体验:超级大的数对 mod = 10^9 + 7 取模之后就 比较不了大小了 。所以呢,这道题也一样。

其实这道题的关键就是这个“大数比较”的方法: 取对数

有: \log ab = \log a + \log b

还有诸多对数运算就不一一提及了,这样之后,我们就成功地将大数运算转换成了取对数意义下地较小的数的运算。

这个底数理论上是取任何有意义的数都行的,但是为了保持精度,还是老老实实取一些小一点的数为底数吧。

看看代码?

code

#include <bits/stdc++.h>
using namespace std;
const int N = 1000010;
const long long mod = 1e9 + 7;
const long double INF = 1e18;
struct node {
    int a, b; // $a choose b$
    int fx; // -1 lr 1 l 2 r
    long double num;
    bool operator < (const node & tmp) const {
        return num < tmp.num;
    }
};
priority_queue <node> pq;
int n, k;
long long Ans = 0;
long long fac[N + 10], ifac[N + 10];
long double logfac[N + 10];
long long qpow (long long a, long long b) {
    long long ans = 1LL;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
long long C (int a, int b) {
    if (a < 0 || b < 0 || b > a) return 0;
    else return fac[a] * ifac[a - b] % mod * ifac[b] % mod;
}
long double logC (int a, int b) {
    return logfac[a] - logfac[a - b] - logfac[b];
}
void debug (int x) {
    for (int i=0;i<=x;i++) {
        for (int j=0;j<=i;j++) cout << C (i, j) << " ";
        cout << endl;
    }
    cout << endl;
}
int main () {
    scanf ("%d %d", &n, &k);
    fac[0] = 1LL; logfac[0] = 0; for (int i=1;i<=n+3;i++) {
        fac[i] = fac[i - 1] * i * 1LL % mod;
        logfac[i] = logfac[i - 1] + log (i * 1.0);
    }
    for (int i=0;i<=n+3;i++) {
        ifac[i] = qpow (fac[i], mod - 2);
    }
    for (int i=1;i<=n;i++) {
        pq.push ((node) {i, i / 2, -1, logC (i, i / 2)});
    }
//  debug (n);
    while (k --) {
        node now = pq.top (); pq.pop ();
        Ans = (Ans + C (now.a, now.b)) % mod;
        if (now.b - 1 >= 0 && now.b - 1 <= n && (now.fx == -1 || now.fx == 1)) pq.push ((node) {now.a, now.b - 1, 1, logC (now.a, now.b - 1)});
        if (now.b + 1 >= 0 && now.b + 1 <= n && (now.fx == -1 || now.fx == 2)) pq.push ((node) {now.a, now.b + 1, 2, logC (now.a, now.b + 1)});
    }
    cout << Ans << endl;
}