P14231 复读机 / repeat 题解

· · 题解

题意

一个长度为 n 的序列 a_i

处理 q 次查询,每次给出一个区间 [l,r]k,需要在在区间中选择一个长为 k 的子序列,最小化相邻两项的和的最大值。

思路

出题人用 3\times 10^5 卡 2log?不存在的。

和楼上一样,对于一个数 v,取出最长的相邻两项和 v 的子序列,如果其长度 \ge k 那么答案就 \le v

定义 2a_i \le va_i 为小数,反之为大数,那么所有小数都可选,最优情况为小数全选,然后对于相邻的两个小数,它们之间最小的大数可选就选上。

二分答案,考虑能取出多少个数。如果从小到大变化所求的最大值 v,会不断有大数变成小数,并且出现新的大数,每个数的贡献在答案值域上看是一个后缀,我们可以求出这个后缀。

具体讲,对于数 a_k,找到它前面第一个小于等于它和它后面第一个小于它的数,即 a_i\le a_ka_k>a_j,那么 a_k 会在答案变化到 \max(a_i,a_j)+a_k 时第一次产生贡献(成为大数),可用单调栈求出,记为 x_k

因此对于每次查询,直接二分答案 mid,考虑怎么 check,此时区间 [l,r]x_i\le mid 的数都会有贡献,可以用主席树求,但对于靠近端点的两个小数 a_{pos1}a_{pos2},可能会导致区间 [l,pos1-1][pos2+1,r] 有新的大数,用 st 表查区间 min,单独判断一下是否满足即可。求这两个小数可以二分。

时间复杂度:二分答案 + 主席树和二分求位置,O(n\log^2n)

写完一交,最后一个子任务 TLE,卡 2log 是吧。

卡常技巧:

  1. 快读
  2. st 表两维交换
  3. 二分使用 while(L < R) 的写法
  4. 把询问离线下来,按左右端点从小到大排序后再依次处理(这个很玄学,但真的有用)

再稍微卡卡就过了,但并不保证用了这几个技巧就能过。

我的 AC 代码再交都不一定能过。

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;

// 快读

const int N = 3e5 + 10, LN = 30;

int n, q;
int a[N], mav;
int lx[N], rx[N];
int x_[N]; // 产生贡献最小值
int stk[N], top;
int lg[N];
int miv[LN][N], ln;
int rt[N];
int tree[N * 42];
int ls[N * 42], rs[N * 42], num;
int ans[N];

struct qry {
    int l, r, k, id;
    bool operator < (qry x) {
        if (l == x.l) return r < x.r;
        return l < x.l;
    }
} b[N];

int update(int v, int l, int r, int x) {
    int u = ++num;
    ls[u] = ls[v], rs[u] = rs[v], tree[u] = tree[v] + 1;
    if(l == r) {
        return u;
    }
    int mid = l + r >> 1;
    if(x <= mid) ls[u] = update(ls[v], l, mid, x);
    else rs[u] = update(rs[v], mid + 1, r, x);
    return u;
}

int query(int u, int v, int l, int r, int x) {
    if(!v)
        return 0;
    if(r <= x) {
        return tree[v] - tree[u];
    }
    if(l == r) return 0;
    int mid = l + r >> 1;
    if(mid <= x) return tree[ls[v]] - tree[ls[u]] + query(rs[u], rs[v], mid + 1, r, x);
    return query(ls[u], ls[v], l, mid, x);
}

int getmin(int l, int r) {
    if(l > r) return 1e9;
    int p = lg[r - l + 1];
    return min(miv[p][l], miv[p][r - (1 << p) + 1]);
}

int main() {
//  freopen("ex_repeat6.in", "r", stdin);
//  freopen("ex_repeat.out", "w", stdout);

    io.read(n), io.read(q);
    for(int i = 1; i <= n; i++) {
        io.read(a[i]);
        mav = max(mav, a[i]);
        miv[0][i] = a[i];
    }

    stk[top = 1] = 0;
    for(int i = 1; i <= n; i++) {
        while(top && a[stk[top]] > a[i]) top--;
        lx[i] = a[stk[top]];
        stk[++top] = i;
    }
    stk[top = 1] = n + 1;
    for(int i = n; i >= 1; i--) {
        while(top && a[stk[top]] >= a[i]) top--;
        rx[i] = a[stk[top]];
        stk[++top] = i;
    }

    for(int i = 1; i <= n; i++) {
        x_[i] = max(lx[i], rx[i]) + a[i];
    }

    for(int i = 2; i <= n; i++) {
        lg[i] = lg[i >> 1] + 1;
    }

    ln = lg[n];
    for(int j = 1; j <= ln; j++) {
        int i = 1;
        for(; i + (1 << j) - 1 <= n; i++) {
            miv[j][i] = min(miv[j - 1][i], miv[j - 1][i + (1 << j - 1)]);
        }
        for(; i <= n; i++) {
            miv[j][i] = miv[j][i - 1];
        }
    }

    rt[0] = ++num;
    for(int i = 1; i <= n; i++) {
        rt[i] = update(rt[i - 1], 1, mav << 1, x_[i]);
    }

    for (int i = 1; i <= q; i++) {
        io.read(b[i].l), io.read(b[i].r), io.read(b[i].k);
        b[i].id = i;
    }

    sort(b + 1, b + 1 + q);

    int l, r, k, L, R, res, mid, L_, R_, posl, posr, mid_;
    for (int i = 1; i <= q; i++) {
        l = b[i].l, r = b[i].r, k = b[i].k;
        L = 1, R = mav << 1, res = mav << 1;
        while(L < R) {
            mid = L + R >> 1;
            int rr = 0;
            L_ = l, R_ = r + 1, posl = l, posr = r;
            while(L_ < R_) {
                mid_ = L_ + R_ >> 1;
                if((getmin(l, mid_) << 1) <= mid) {
                    R_ = mid_;
                } else {
                    L_ = mid_ + 1;
                }
            }
            posl = L_;
            L_ = l, R_ = r;
            while(L_ < R_) {
                mid_ = L_ + R_ + 1 >> 1;
                if((getmin(mid_, r) << 1) <= mid) {
                    L_ = mid_;
                } else {
                    R_ = mid_ - 1;
                }
            }
            posr = L_;
            if(posl ^ (r + 1))
                rr = query(rt[posl - 1], rt[posr], 1, mav << 1, mid) + (getmin(l, posl - 1) + a[posl] <= mid) + (getmin(posr + 1, r) + a[posr] <= mid);
            if(rr >= k) {
                R = mid;
            } else {
                L = mid + 1;
            }
        }
        res = L;
        ans[b[i].id] = res;
    }

    for (int i = 1; i <= q; i++) {
        io.write(ans[i], '\n');
    }

    return 0;
}