P7804 [JOI Open 2021] Financial Report

· · 题解

首先,题目描述过于抽象,本质是求 A 的一个子序列,使得这个子序列包含 A_n,并且子序列中相邻的元素位置距离 \le D,并且从前往后破纪录元素尽量多,求破纪录元素个数。

只有 A 相对顺序有用,所以先将 A 离散化。

考虑 dp,其中 dp[i]i 为破纪录元素时能包含几个破纪录元素。i 肯定为破纪录元素,所以我们需要枚举在 i 前面的第一个破纪录元素 j。则 A_j<A_i。而为了保证存在一个子序列包含 ij,还满足相邻元素距离 \le D,必须存在 j<p_1<p_2<\dots<p_x<i 使得

则可以取所有满足条件的 j 并且求 dp[i]=\max(dp[j])+1。直接 dp 可以做到 O(n^2)。而发现由于 i:1\to nA_i 的大量变化,维护相邻元素距离约束相对困难,于是考虑对 A_i 扫描线。

按照 A_i 从小到大更新 dp[i],则计算某个 dp[i] 的时候肯定只有已经被更新的 dp[j] 能影响 dp[i]

我们维护一个 set 包含所有已经被更新的位置:这些位置上的值均小于 A_i。维护一个并查集,其中 ij 连通当且仅当 ij 被更新并且 j-i\le D,则可以从并查集的连通块信息得出任何 i 的最左端 j 使得相邻元素距离约束成立。

然后用并查集得到 j 之后用线段树维护 dp 区间最大值,总共时间复杂度 O(n\log n)

#include <bits/stdc++.h>
using namespace std;

#define all(a) a.begin(), a.end()
#define fi first
#define se second
#define pb push_back
#define mp make_pair

using ll = long long;
using pii = pair<int, int>;
//#define int ll
const int MOD = 1000000007;

int las[300005];

int ufs(int u) {
  if (las[u] == -1)
    return u;
  return las[u] = ufs(las[u]);
}
void conn(int u, int v) {
  u = ufs(u), v = ufs(v);
  if (u != v)
    las[max(u, v)] = min(u, v);
}

int seg[300005 << 2];
void upd(int idx, int l, int r, int u, int v) {
  if (!(l <= u && u < r))
    return;
  if (r - l == 1) {
    seg[idx] = v;
    return;
  }
  upd(idx * 2, l, (l + r) / 2, u, v);
  upd(idx * 2 + 1, (l + r) / 2, r, u, v);
  seg[idx] = max(seg[idx * 2], seg[idx * 2 + 1]);
}
int qry(int idx, int l, int r, int L, int R) {
  if (L <= l && r <= R)
    return seg[idx];
  if (R <= l || r <= L)
    return 0;
  return max(qry(idx * 2, l, (l + r) / 2, L, R),
             qry(idx * 2 + 1, (l + r) / 2, r, L, R));
}
signed main() {
  ios_base::sync_with_stdio(false);
  cin.tie(0);
  memset(las, -1, sizeof las);
  int n, d;
  cin >> n >> d;
  vector<pii> x(n);
  for (int i = 0; i < n; i++)
    cin >> x[i].fi, x[i].se = -i;
  sort(all(x));
  set<int> on;
  int ans = 0;
  for (auto [_, i] : x) {
    i = -i;
    auto pos = on.insert(i).fi;
    if (pos != on.begin() && *pos - *prev(pos) <= d)
      conn(*pos, *prev(pos));
    if (next(pos) != on.end() && *next(pos) - *pos <= d)
      conn(*pos, *next(pos));
    int dp = qry(1, 0, n, ufs(i), i) + 1;
    ans = max(ans, dp);
    upd(1, 0, n, i, dp);
  }
  cout << ans << endl;
}