题解:P15574 [USACO26FEB] Milk Buckets S

· · 题解

看了官方题解,感觉题解没有很好的解释每一步是什么来的,也可能是我太菜了,所以我来写一篇详细推导的题解。

设一个桶的容量为 a

第一个桶:每秒接受 1 加仑牛奶,需要 a_1 秒填满。a_1 + 1 时开始反转,a_1 + 2 恢复接收,开始接下一轮。

一轮的时间是 b_1 = a_1 + 1 秒。每隔 b_1 秒,它就会向下倒出 a_1 的牛奶。

第二个桶每次接受 a_1 加仑牛奶,容量为 a_2,需要 \lceil \frac{a_2}{a_1} \rceil 次才可以接满,一共需要 b_2 = \lceil \frac{a_2}{a_1} \rceil \cdot b_1 秒。所以一轮的时间是 b_2 = \lceil \frac{a_2}{a_1} \rceil \cdot b_1

以此类推。

i 个桶只能等第 i-1 个桶倒水。

它需要等待的次数是 r_i = \lceil \frac{a_i}{a_{i-1}} \rceil

所以第 i 个桶的翻转的一轮时间为:b_i = \lceil \frac{a_i}{a_{i-1}} \rceil \cdot b_{i-1}

1 个桶在第 b_1 秒倒水。第 2 个桶在第 b_2 秒刚好接到最后一次倒水,此时它刚好装满。题目规定装满后的下一秒开始翻转。所以第 2 个桶是在第 b_2 + 1 秒完成倒水。这就产生了一秒钟的延迟。同理,第 3 个桶会在第 b_3 + 2 秒刚好装满,在第 b_3 + 2 秒完成倒水。又多了一秒延迟。

以此类推,到了最后一个桶(第 n 个桶),它的延迟是 n - 1,它完成第一次倒水(把奶倒入池子)的时间点是 b_n + (n - 1) 秒。

所以第 n 个桶在第 k \cdot b_n + (n - 1) 时刻发生第 k 次翻转。

为了计算池子中的牛奶数,我们应该计算在总时间 t 内,第 n 个桶翻转了多少次。

我们需要找到满足条件的最大翻转次数 k,即:

k \cdot b_n + (n - 1) \le t

移项:

k \cdot b_n \le t - (n - 1) k = \lfloor \frac{t - (n - 1)}{b_n} \rfloor

如果 t < n - 1k0,因为没有发生任何翻转。

每次,第 n 个桶往池子中倒入 a_n 加仑牛奶。

最终池子有 a_n \cdot k = a_n \cdot \lfloor \frac{t - (n - 1)}{b_n} \rfloor 加仑牛奶。

但是直接 for 肯定会 TLE,不 TLE 也会爆 long long,因为一轮的时间 b_i 是乘法级增长的。

我们观察 r_i = \lceil \frac{a_i}{a_{i-1}} \rceil,发现当 a_i \le a_{i - 1} 时一轮翻转的时间 b_i = \lceil \frac{a_i}{a_{i-1}} \rceil \cdot b_{i-1} = 1 \cdot b_{i-1} = b_{i - 1},一轮翻转的时间根本没有增长。如果 a_i > a_{i - 1} 呢?那么 r_i 至少是 \ge 2 的!

回到题目,题目给定的 t 最大是 10^{18}。这意味着如果某一个桶的一轮翻转时间 b_i > 10^{18},那没有任何牛奶倒到水池中,此时答案是 0

那一轮的时间最多翻转 \log_2(10^{18}) \le 60 次,就会 \ge 10^{18}

所以,从上往下的推导中,真正会让某一个桶的翻转时间变大的,即 a_i > a_{i-1} 的桶,最多只会有 60 个。如果大于 60 个,那么答案为 0

对于 a_i \le a_{i-1} 的桶,对一轮翻转的时间没有影响,我们可以跳过它们。

我们可以使用 std::set 维护 a_i > a_{i-1} 的桶,每次修改第 x 个桶都只会影响 xx-1 的关系和 x + 1x 的关系。

对于每一次查询,我们只需要遍历 set 即可。

代码:

#include <iostream>
#include <set>
#include <vector>

int main() {
    int n;
    std::cin >> n;
    std::vector<long long> a(n + 2, 0);
    for (int i = 1; i <= n; i++) { std::cin >> a[i]; }
    std::set<long long> s;
    for (int i = 2; i <= n; i++) {
        if (a[i] > a[i - 1]) s.insert(i);
        else s.erase(i);
    }

    int q;
    std::cin >> q;
    while (q--) {
        int i;
        long long v, t;
        std::cin >> i >> v >> t;
        a[i] = v;
        if (i >= 2) {
            if (a[i] > a[i - 1]) s.insert(i);
            else s.erase(i);
        }
        if (i + 1 >= 2 && i + 1 <= n) {
            if (a[i + 1] > a[i]) s.insert(i + 1);
            else s.erase(i + 1);
        }

        long long ans = std::max(0ll, t - (n - 1));
        ans /= a[1] + 1;
        for (int i : s) {
            if (ans == 0) break;
            long long r = (a[i] + a[i - 1] - 1) / a[i - 1];
            ans /= r;
        }
        ans *= a[n];
        std::cout << ans << std::endl;
    }
}