题解 P6836 【[IOI2020]装饼干】

· · 题解

这题我的方法比较奇怪。

题意:

k种物品,第i个物品有a_i个,权值为2^i

求有多少个y,使得可以选出x组物品,每组的和都为y

先考虑如何判定一个y是否可行:

从最高位开始,依次求出第i位需要的数目b_i。若y的第i位为1,则b\leftarrow b+x

如果b_i \leq a_i,那么说明a_i够用,进入下一位。

如果b_i>a_i,则选上所有a_i后,还剩b_i-a_i个。那么这些2^i只能用两倍的2^{i-1}来凑。因此把b_{i-1}加上2 \times (b_i-a_i)

这样,只要b_0<a_0,则说明这个y可行。

可以发现,若b_i \leq a_i,那么对于所有j<ij不会受之前的影响。

因此,可以把b_i \leq a_i作为分界点,进行DP。

dp_i表示使得b_i \leq a_i的方案数目。只考虑大于等于i的位。那么,dp_0就是答案。

枚举j>i作为前一个分界点。那么,对于所有i<c<j,都要求b_c>a_c

再枚举所有的c,那么b_c就可以很容易地用这个y[c,j-1]这些数位上的值来表示。

于是,对于每个c,可以得出一条不等式。把这些不等式联立,就能得到y[i,j)这些数位上的范围d_{i,j}

因此,dp_i\leftarrow dp_i+d_{i,j}\times dp_j

这样做的复杂度是O(k^3q)的,可以过。

代码:

#include <stdio.h>
#include <vector>
#include "biscuits.h"
using namespace std;
#define ll long long
ll dp[70];
ll count_tastiness(ll x, vector<ll> sz) {
    int k = sz.size();
    for (int i = 0; i < 62 - k; i++) sz.push_back(0);
    k = 62;
    for (int i = k; i >= 0; i--) {
        if (i == k) {
            dp[i] = 1;
            continue;
        }
        dp[i] = 0;
        for (int j = i + 1; j <= k; j++) {
            ll zx = 0, zd = (1ll << (j - i)) - 1, h = 0;
            for (int a = j - 1; a >= i; a--) {
                h = h * 2 + sz[a];
                ll z = (h / x + 1) << (a - i);
                if (a > i) {
                    if (z > zx)
                        zx = z;
                } else {
                    if (z - 1 < zd)
                        zd = z - 1;
                }
            }
            if (zx <= zd)
                dp[i] += dp[j] * (zd - zx + 1);
        }
    }
    return dp[0];
}

不难发现,求d_{i,j}的过程可以优化。提前预处理出d_{i,j},就可以O(k^2)了。

#include <stdio.h>
#include <vector>
#include "biscuits.h"
using namespace std;
#define ll long long
ll dp[70], zz[70][70], dd[70][70];
ll count_tastiness(ll x, vector<ll> sz) {
    int k = sz.size();
    for (int i = 0; i < 62 - k; i++) sz.push_back(0);
    for (int j = 1; j <= 62; j++) {
        ll zx = 0, h = 0;
        for (int a = j - 1; a >= 0; a--) {
            zz[a][j] = zx;
            h = h * 2 + sz[a];
            ll z = (h / x + 1) << a;
            if (z > zx)
                zx = z;
            dd[a][j] = z;
        }
    }
    for (int i = 62; i >= 0; i--) {
        if (i == 62) {
            dp[i] = 1;
            continue;
        }
        dp[i] = 0;
        for (int j = i + 1; j <= 62; j++) {
            ll zx = (zz[i][j] >> i), zd = (1ll << (j - i));
            if ((dd[i][j] >> i) < zd)
                zd = (dd[i][j] >> i);
            if (zx <= zd)
                dp[i] += dp[j] * (zd - zx);
        }
    }
    return dp[0];
}