CF1175G

· · 题解

CDQ 分治 + 斜率优化,\mathrm{O}(nk\log n)(应该是最好写的做法?不知道是不是错的)

按段数 dp,有 f_i=\min\limits_{0\le j<i}\{g_j+(i-j)\max\limits_{i\le k\le j}a_k\}

\max$ 很难处理,考虑 CDQ 分治,设当前处理 $[l,mid]$ 向 $[mid+1,r]$ 的转移。记 $mx_i=\begin{cases}\max\limits_{i<k\le mid}a_k&i\le mid\\\max\limits_{mid<k\le i}a_k&i>mid\end{cases}$,则 $f_i=\min\{g_j+(i-j)\max\{mx_j,mx_i\}\}

mx_j\le mx_if_i=\min\{mx_i\times j +g_j\}+mx_i\times i

从小到大枚举 i,同时不断加入若干一次函数 y=jx+g_j。注意到斜率递减,且查询的横坐标不增,可以用单调栈维护。

具体地,加入直线 C 时,考虑栈中最后两条直线 A,B,若 A,B 的交点横坐标 \ge B,C 的交点横坐标,则弹出 B

查询时,考虑栈中最后两条直线 A,B,若 A,B 的交点横坐标 \ge x,则弹出 B

mx_j> mx_if_i=\min\{mx_j\times i +g_j-mx_j\times j\}

从大到小枚举 i,其余同上。

#include <stdio.h>
#include <bits/stdc++.h>

using namespace std;

using db = double;
template <class T>
inline void chmin(T &x, const T &y) { if(x > y) x = y; }
#define rep(i, l, r) for(int i = l, i##end = r; i <= i##end; ++ i)
#define per(i, r, l) for(int i = r, i##end = l; i >= i##end; -- i)
char inputbuf[1 << 23], *p1 = inputbuf, *p2 = inputbuf;
#define getchar() (p1 == p2 && (p2 = (p1 = inputbuf) + fread(inputbuf, 1, 1 << 23, stdin), p1 == p2) ? EOF : *p1++)
inline int read() {
    int res = 0; char ch = getchar(); bool f = true;
    for(; ch < '0' || ch > '9'; ch = getchar())
        f &= ch != '-';
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        res = res * 10 + (ch ^ 48);
    return f ? res : -res;
}
const int N = 2e4 + 15;

int g[N], f[N];
int a[N], mx[N];
int n, m, ql, qr;

struct seg {
    int k, b;
    int operator () (int x) {
        return k * x + b;
    }
    seg(int _k = 0, int _b = 0) {
        k = _k; b = _b;
    }
} Q[N];

db crx(seg a, seg b) {
    return 1.0 * (a.b - b.b) / (b.k - a.k);
}

void ins(seg x) {
    if(qr && Q[qr].k == x.k) {
        if(Q[qr].b > x.b) -- qr;
        else return ;
    }
    while(qr > 1 && crx(Q[qr - 1], Q[qr]) >= crx(Q[qr], x)) -- qr;
    Q[++ qr] = x;
}

int qry(int x) {
    while(qr > 1 && crx(Q[qr - 1], Q[qr]) >= x) -- qr;
    return qr ? Q[qr](x) : 1e9;
}

void solve(int l, int r) {
    if(l == r) return ;
    int mid = l + r >> 1, j;
    mx[mid] = 0; mx[mid + 1] = a[mid + 1]; 
    per(i, mid - 1, l) mx[i] = max(mx[i + 1], a[i + 1]);
    rep(i, mid + 2, r) mx[i] = max(mx[i - 1], a[i]);
    qr = 0; j = mid;
    rep(i, mid + 1, r) {
        for(; j >= l && mx[j] <= mx[i]; -- j)
            ins(seg(j, g[j]));
        chmin(f[i], qry(-mx[i]) + i * mx[i]);
    }
    qr = 0; j = l;
    per(i, r, mid + 1) {
        for(; j <= mid && mx[j] > mx[i]; ++ j)
            ins(seg(mx[j], g[j] - mx[j] * j));
        chmin(f[i], qry(i));
    }
    solve(l, mid); solve(mid + 1, r);
}

signed main() {
    n = read(); m = read();
    rep(i, 1, n) a[i] = read();
    rep(i, 1, n) f[i] = max(f[i - 1], a[i]);
    rep(i, 1, n) f[i] *= i;
    rep(k, 2, m) {
        memcpy(g, f, sizeof(f));
        rep(i, 0, n) f[i] = 1e9;
        solve(0, n);
    }
    printf("%d\n", f[n]);
    return 0;
}