P8275 [USACO22OPEN] 262144 Revisited P

· · 题解

首先有一个显然的区间 DP:设 f_{i,j} 为区间 [i,j] 的权值,则 f_{i,j} = \min \limits_{i \leq k < j} \{ \max(f_{i,k}, f_{k+1,j}) + 1\}

进一步地,由于 f_{i,k} 关于 k 单调不降,f_{k+1,j} 关于 k 单调不升,因此我们只需要二分出最大的 k' 使得 f_{i,k'} \leq f_{k' + 1,j},然后在 k \in \{k',k'+1\} 中做决策即可,也可以利用决策单调性做到 \mathcal{O}(n^2)

但到此为止,直接 DP 似乎已经看不到前途了。我们可能需要找到一种更优的计算方式。

考虑一种无脑的合并方式:每次将区间划分成长度尽量相同的两部分,然后归并。这样我们可以得到答案的一个上界 \max a_i + \lceil \log n \rceil。既然答案的上界只有 \mathcal{O}(A),我们不妨转变一下思路,考虑依次枚举 v,然后计算权值 \geq v 的区间个数。

由于 f_{i,k} 关于 k 单调不降,我们考虑对于每个左端点 i 维护最大的右端点 r_i 使得 [i,r_i) 的权值不大于 v。维护的方式如下:初始时 r_i = i,当 v-1 \to v 时,我们依次枚举 i,令 r_i \gets r_{r_i},然后对于所有 a_i = v 的位置 i,令 r_i \gets i+1

我们接下来说明为什么上面的做法是对的。当 i = r_i 时正确性显然。否则考虑 [i,r_i)[r_i,r_{r_i}) 这两个区间,不妨设它们分别为区间 \mathcal{A},\mathcal{B}。当 \mathcal{A}\mathcal{B} 的权值都是 v-1 正确性显然。否则若 \mathcal{A} 的权值小于 v - 1,那么由 \mathcal{A} 的极大性,有 a_{r_i} \geq v-1。若 a_{r_i} \geq v,则此时 r_{r_i} = r_i,否则 a_{r_i} = v-1,两者的正确性都是显然的。当 \mathcal{B} 的权值小于 v-1 时同理。

于是我们得到了一个 \mathcal{O}(nA) 的做法。目前而言,算法的效率上似乎没有什么实质性的进展,我们还需要寻找加速的方法。

不妨再关注一下 f 的单调性。如果一个区间的 f\leq v,那么其所有子区间的 f 值都 \leq v

这句话听上去是一句废话,但我们可以做一个容斥,每次改为计算权值 < v 的区间个数。这时考虑一类特殊的区间:我们称一个区间是 [l,r] 是 “极大的”,当且仅当其向左或向右扩展都会使 f 值增大。

对于所有权值为 v-1 的极大区间,显然它们不会互相包含。于是我们要求的就变成了这样的一个问题:给定 L 个左右端点单调增的区间,求它们覆盖的区间数。这可以通过简单容斥做到 \mathcal{O}(L)

也就是说,如果我们直接在枚举 v 的过程中维护极大区间的集合,就能够避免每次 \mathcal{O}(n) 的枚举。但从直觉上看,极大区间的数量似乎还是可能很多,复杂度究竟能优化多少还有待商榷。

为此,我们需要证明如下引理:

引理:设 f(n) 表示长度为 n 的序列的极大区间的数量,则 f(n)=O(n \log n)

证明:我们可以证明,f(n) \leq f(\log n!) = f(\log 1 + \cdots + \log n) \leq f(n \log n)

考虑原序列的笛卡尔树,设其中一个最大元素的位置为 p,则有:

f(n) \leq f(p-1) + f(n - p) + \text{val}(p)

其中 \text{val}(p) 为原序列中包含位置 p 的极大区间数。不妨设 2p \leq n,我们可以证明,\text{val}(p) \leq \mathcal{O}(p \log \frac{n}{p})

具体来说,我们设 \text{val}_{k}(p) 表示包含 p,且权值为 a_p + k 的极大区间数,则有 \text{val}_k(p) \leq \min(p,2^k),其中 0 \leq k \leq \lceil \log_2 n \rceilp 的限制是因为权值相同的极大区间左端点位置一定不同,而 2^k 的限制是因为权值从 a_p + k - 1a_p + k 必然会经过一次扩展,每次扩展我们可以选择向左或是向右,而我们需要从 a_p 开始扩展 k 次。

考虑对所有 \text{val}_i(p) 求和。注意到,当 0 \leq k < \lceil \log_2 p \rceil 时,\min(p,2^k) = 2^k,这部分 \text{val}_i(p) 的和为 \mathcal{O}(p)。当 \lceil \log_2p \rceil \leq k \leq \lceil \log_2 n \rceil 时,\min(p,2^k) = p,这部分 \text{val}_i(p) 的和为 \mathcal{O}(p \log \frac{n}{p})。于是命题得证。

对于引理的证明,我们考虑归纳。由 \mathcal{O}(p \log \frac{n}{p}) \leq \mathcal{O}(\log\frac{n}{p} + \log\frac{n-1}{p-1} + \cdots + \log \frac{n-p+1}{1}) \leq \mathcal{O}(\log \binom{n}{p}) \leq \mathcal{O}(\log \frac{n!}{(p-1)!(n-p)!}),则有

\begin{aligned} f(n) &\leq f(p-1) + f(n-p) + \text{val}(p) \\ &\leq \mathcal{O}(\log(p-1)!) + \mathcal{O}(\log(n-p)!) + \mathcal{O}(\log n! - \log(p-1)! - \log (n-p)!) \\ &\leq \mathcal{O}(n!) \end{aligned}

则原命题得证。

接下来只需要考虑如何维护极大区间。对当前的权值 v,我们称一个区间 [l,r) 是一个段,当且仅当其中所有元素都不大于 v,且 \min(a_{l-1},a_r) > v。对于当前的所有极大区间,其要么权值为 v,要么是一个段。

我们将所有极大区间储存在其所属的段 [l,r) 的左端点的集合 \text{seg}_l 内。当 v-1 \to v 时,我们依次执行以下步骤:对所有需要更新的段,更新极大区间并重新计算它们的贡献。合并过后,段内所有极大区间的权值变为 v。接着在所有 a_i = v 的位置添加一个极大区间 [i,i+1),然后将相邻的段合并。

使用双向链表维护段的合并,总时间复杂度为 \mathcal{O}(A + n \log n)。具体细节见代码。

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
typedef vector <int> vi;
typedef pair <int, int> pi;
typedef long long LL;
const int N = 3e5 + 5, M = 1e6 + 35;
int n, a[N], pre[N], nex[N]; LL cur;
vi p[M];
list <pi> seg[N];
LL sum(int l, int r) {
    return 1LL * (l + r) * (r - l + 1) / 2;
}
LL calc(list <pi> L) {
    LL ret = 0;
    for (auto it = L.begin(); it != L.end(); it++) {
        auto [x, y] = *it;
        if (next(it) == L.end()) {
            ret += sum(1, y - x);
            break;
        } else {
            int nx = next(it) -> fi;
            ret += 1LL * (nx - x) * y - sum(x, nx - 1);
        }
    }
    return ret;
}
void upd(list <pi> &L) {
    if (L.size() <= 1) 
        return;
    cur += calc(L);
    int mx = -1;
    list <pi> nL;
    auto it = L.begin();
    for (auto [x, y] : L) {
        while (next(it) != L.end() && next(it) -> fi <= y)
            it++;
        if (it -> se > mx) {
            nL.push_back({x, mx = it -> se});
        }
    }
    swap(L, nL);
    cur -= calc(L);
}
int main() {
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        p[a[i]].push_back(i);
    }
    cur = 1LL * n * (n + 1) / 2;
    for (int i = 1; i <= n; i++) 
        pre[i] = nex[i] = i;
    LL ans = 0;
    vi q;
    for (int v = 1; cur > 0; v++) {
        ans += cur;
        vi nq;
        for (auto l : q) {
            upd(seg[l]);
            if (seg[l].size() > 1) nq.push_back(l);
        }
        for (auto i : p[v]) {
            int l = pre[i], r = nex[i + 1];
            bool add = seg[l].size() <= 1;
            nex[l] = r, pre[r] = l;
            seg[l].push_back({i, i + 1});
            cur--;
            seg[l].splice(seg[l].end(), seg[i + 1]);
            add &= seg[l].size() > 1;
            if (add) nq.push_back(l); 
        }
        swap(q, nq);
    }
    cout << ans << "\n";
    return 0;   
}