AT_arc150_f [ARC150F] Constant Sum Subsequence

· · 题解

让所有 S 的正整数拆分都包含在最短前缀中。考虑到 S 比较小。那么定义 f_i 代表包含 i 的所有正整数拆分的最短前缀。

如何转移 f_i?我们枚举正整数拆分的最后一个数 j,那么有 f_i=\max\{f_i,nxt(f_{i-j},j)\}。这样有 S^2i\to j 的转移,我们可以优化掉许多没用的。

我们直接分治转移,具体模仿 cdq 分治,先递归 [l,mid]f 计算出来,再让 f_l,f_{l+1},\dots,f_{mid} 去转移 f_{mid+1},f_{mid+2},\dots,f_r。最后递归 [mid+1,r]

现在考虑快速让 f_l,f_{l+1},\dots,f_{mid} 去转移 f_{mid+1},f_{mid+2},\dots,f_r。我们枚举 j 范围从 1r-l,然后把 f_l,f_{l+1},\dots,f_{mid} 放在长度为 n\times n 的 A 序列上。若数字 j 存在于 f_kf_{k+1} 之间,那么 nxt(f_k,j)\leq f_{k+1}\leq f_{mid}\leq f_{l\dots r},也就是 f_l,f_{l+1},\dots,f_k 是没有必要转移的。那么我们找到最后一次出现的 j,若其在 f_{k'}f_{k'+1} 之间,那么我们从 f_{k'+1} 开始去转移。又发现 f_{k'+1} 之后的 nxt(f,j) 均相同等于 nxt(f_{mid},j),那么可以直接将 f_{k'+1+j}f_{mid+k} 这一段全部对 nxt(f_{mid},j)\max 即可。

我们枚举 j 的总复杂度为 O(S\log S),故复杂度为两只 log。

::::info[代码]

#include <bits/stdc++.h>
using namespace std;

#define PII pair<int, int>
#define _for(i, a, b) for (int i = (a); i <= (b); i++)
#define _pfor(i, a, b) for (int i = (a); i >= (b); i--)
#define int long long

const int N = 1e6 + 5e5 + 5;

int n, S, a[N], f[N]; 
vector<int> pos[N];

int nxt(int x, int y) {
  int t = x % n, cnt = x / n * n;
  if (t >= pos[y].back()) return cnt + pos[y].front() + n;
  return cnt + *upper_bound(pos[y].begin(), pos[y].end(), t);
}

int pre(int x, int y) {
  int t = x % n, cnt = x / n * n;
  if (t <= pos[y].front()) return cnt - n + pos[y].back();
  return cnt + *(--lower_bound(pos[y].begin(), pos[y].end(), t));
}

struct edge {
  struct tt {
    int l, r, maxn;
  }tree[N * 4];
  void build(int p, int l, int r) {
    tree[p].l = l, tree[p].r = r;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
  }
  void modify(int p, int l, int r, int v) {
    if (l <= tree[p].l && tree[p].r <= r) {
      tree[p].maxn = max(tree[p].maxn, v);
      return;
    }
    int mid = (tree[p].l + tree[p].r) >> 1;
    if (l <= mid) modify(p << 1, l, r, v);
    if (r > mid) modify(p << 1 | 1, l, r, v);
  } 
  int query(int p, int x) {
    if (tree[p].l == tree[p].r) return tree[p].maxn;
    int mid = (tree[p].l + tree[p].r) >> 1, res = tree[p].maxn;
    if (x <= mid) res = max(res, query(p << 1, x));
    else res = max(res, query(p << 1 | 1, x));
    return res;
  }
}tr;

void solve(int l, int r) {
  if (l >= r) return;
  int mid = (l + r) >> 1;
  solve(l, mid);
  _for(k, 1, r - l) {
    int R = min(mid, r - k);
    int L = lower_bound(f + l, f + mid + 1, pre(f[R] + 1, k)) - f;
    if (f[L] < pre(f[R] + 1, k)) continue;
    L = max(mid + 1, L + k); R += k;
    tr.modify(1, L, R, nxt(f[R - k], k));
  } 
  _for(i, mid + 1, r) f[i] = max(f[i], tr.query(1, i));
  solve(mid + 1, r);
}

signed main() {
  cin >> n >> S;
  tr.build(1, 1, S);
  _for(i, 1, n) cin >> a[i], pos[a[i]].push_back(i);
  solve(0, S);
  cout << max(f[S], 1ll) << endl;
}

::::