题解:P11924 [PA 2025] 贪婪大盗 / Piracka Chciwość

· · 题解

首先 \mathcal{O}(n^2\log n) 是简单的,直接从后往前按题意模拟即可。

注意到 \max\{a_i\} 只有 64,我们能看出来什么呢?首先我以为是根据小差距去暴力维护某些东西,但我还是见的太少了。

我们发现每次只会选其中一半数给它投票,那么说明一次结果只有一般的人给了钱。那么说明给的钱最多也只有 64,因为下一个人完全可以选择另一半!这样我们就可以加速模拟过程了,我们需要每个数当前的钱数,然后用 a_i 在分一次类,即设 S_{i,j} 表示所有满足 b_u=i,a_u=ju 组成的集合。

每次从小到大枚举 k,即给的钱数,然后将要给这么多钱的人更新,但是这题非常恶心,要求找到编号更大的,考虑使用动态开点线段树,那么我们在第一次总数超过所需数的 k 通过二分的方式找到 mid 使得当前所有编号大于等于 mid 的数刚好凑够所需要的点数。

维护时需要线段树合并和线段树分裂,时间复杂度 \mathcal{O}(na^2\log n),实际跑的挺快。

代码:

#include <bits/stdc++.h>
#define rep(i, l, r) for (int i (l); i <= (r); ++ i)
#define rrp(i, l, r) for (int i (r); i >= (l); -- i)
#define eb emplace_back
using namespace std;
#define pii pair <int, int>
#define inf 1000000000
#define ls (p << 1)
#define rs (ls | 1)
constexpr int N = 5e4 + 5, M = 1e6;
typedef long long ll;
typedef unsigned long long ull;
inline int rd () {
  int x = 0, f = 1;
  char ch = getchar ();
  while (! isdigit (ch)) {
    if (ch == '-') f = -1;
    ch = getchar ();
  }
  while (isdigit (ch)) {
    x = (x << 1) + (x << 3) + (ch ^ 48);
    ch = getchar ();
  }
  return x * f;
}
int n, m;
int a[N], b[N];
// vector <int> vc[65][65], die;
int rt[65][65], die[65];
int sum[N << 7], lc[N << 7], rc[N << 7], tot;
void psu (int p) {
  sum[p] = sum[lc[p]] + sum[rc[p]];
}
void upd (int &p, int l, int r, int x) {
  if (! p) p = ++ tot;
  if (l == r) return ++ sum[p], void ();
  int mid (l + r >> 1);
  if (x <= mid) upd (lc[p], l, mid, x);
  else upd (rc[p], mid + 1, r, x);
  psu (p);
}
void merge (int &x, int y) {
  if (! x || ! y) return void (x |= y);
  merge (lc[x], lc[y]);
  merge (rc[x], rc[y]);
  psu (x);
}
void split (int p1, int &p2, int l, int r, int x) {
  if (! p1) return ;
  if (l == r) return ;
  if (! p2) p2 = ++ tot;
  int mid (l + r >> 1);
  if (x <= mid) {
    swap (rc[p1], rc[p2]);
    split (lc[p1], lc[p2], l, mid, x);
  } else {
    split (rc[p1], rc[p2], mid + 1, r, x);
  }
  psu (p1), psu (p2);
}
int R[65];
int solve (int l, int r, int nd) {
  if (l == r) return l;
  int mid (l + r >> 1);
  int s (0);
  rep (i, 0, 64) s += sum[rc[R[i]]];
  if (s >= nd) {
    rep (i, 0, 64) R[i] = rc[R[i]];
    return solve (mid + 1, r, nd);
  }
  else {
    rep (i, 0, 64) R[i] = lc[R[i]];
    return solve (l, mid, nd - s);
  }
}
void dfs (int p, int l, int r, int k) {
  if (l == r && sum[p]) {
    return b[l] = k, void ();
  }
  int mid (l + r >> 1);
  if (sum[lc[p]]) dfs (lc[p], l, mid, k);
  if (sum[rc[p]]) dfs (rc[p], mid + 1, r, k); 
}
int32_t main () {
  // freopen ("1.in", "r", stdin);
  // freopen ("1.out", "w", stdout);
  n = rd (), m = rd ();
  rep (i, 1, n) a[i] = rd ();
  int nxt (n);
  rrp (i, 1, n - 1) {
    int len (n - i + 1);
    int cnt = 1;
    rep (i, 0, 64) cnt += sum[die[i]];
    if (cnt * 2 >= len) {
      rep (i, 1, 64) rep (j, 0, 64) {
        merge (rt[0][j], rt[i][j]); rt[i][j] = 0;
      } 
      upd (rt[0][a[nxt]], 1, n, nxt);
      rep (j, 1, 64) merge (rt[0][j], die[j]), die[j] = 0;
      b[i] = m; nxt = i;
    } else {
      int s1 (0);
      rep (k, 0, 64) {
        int now (0);
        if (b[nxt] + a[nxt] == k) ++ now;
        rep (j, 0, k) {
          now += sum[rt[j][k - j]];
        } 
        if ((cnt + now) * 2 < len) s1 += now * k, cnt += now;
        else {
          int ln = now;
          now = (len - cnt * 2 + 1) / 2;
          s1 += now * k; 
          cnt += now;
          if (s1 > m) break;
          int ret;
          if ((cnt + ln - now) * 2 <= len + 1) ret = nxt;
          else {
            memset (R, 0, sizeof R);
            rep (j, 0, k) R[j] = rt[j][k - j];
            ret = solve (1, n, now);
          }
          int tmp[65]; memset (tmp, 0, sizeof tmp);
          rep (i, 0, 64) {
            rep (j, 0, 64) {
              if (i + j > k) {
                merge (tmp[j], rt[i][j]); rt[i][j] = 0;
              } else {
                if (i + j == k) {
                  int lef (rt[i][j]);
                  int all = sum[lef];
                  rt[i][j] = 0;
                  split (lef, rt[i][j], 1, n, ret - 1);
                  assert (sum[lef] + sum[rt[i][j]] == all);
                  all = sum[tmp[j]] + sum[lef];
                  merge (tmp[j], lef);
                  assert (all == sum[tmp[j]]);
                }
              }              
            }
          }
          rrp (i, 0, 64) {
            rrp (j, 0, 64) {
              merge (rt[i + j][j], rt[i][j]);
              rt[i][j] = 0;
            }
          }
          rep (i, 0, 64) merge (rt[0][i], tmp[i]);
          if (a[nxt] + b[nxt] < k) upd (rt[a[nxt] + b[nxt]][a[nxt]], 1, n, nxt);
          else if (a[nxt] + b[nxt] == k && ret == nxt) upd (rt[a[nxt] + b[nxt]][a[nxt]], 1, n, nxt);
          else {
            upd (rt[0][a[nxt]], 1, n, nxt);
          }
          b[i] = m - s1;
          nxt = i;
          rep (i, 0, 64) merge (rt[0][i], die[i]), die[i] = 0;
          break;
        }
      }
      if (s1 > m) upd (die[a[i]], 1, n, i);
    }
  }
  rep (i, 1, 64) dfs (die[i], 1, n, -1);
  rep (i, 0, 64) rep (j, 0, 64) dfs (rt[i][j], 1, n, i);
  rep (i, 1, n) printf ("%d ", b[i]);
}