P6934 [ICPC2017 WF] Posterize 题解

· · 题解

可以发现 k, r, d \leq 256p 很大,所以可以考虑本题的时间复杂度仅和 k, r, d 有关且不劣于 \mathrm{O}(n^3), n = 256.

考虑 dp.

dp[i][cnt] 表示 v[1] < v[2] < \cdots < v[i] = cnt, v[i + 1] = v[i + 2] = \cdots = v[k] = +\infty 时的最小误差。(可以理解为我们目前只有 cntv[i]

dp[i][cnt] = \min_{j < i}\{dp[j][cnt - 1] + \sum_{r[k] > \frac{i + j}{2}} (r[k] - i)^2 - (r[k] - j)^2\}

因为 \left | r[k] - i \right | < \left | r[k] - j \right | \Leftrightarrow r[k] > \dfrac{i + j}{2}

由于枚举 i, j, cnt 已经是 \mathrm{O}(n^3) 时间复杂度的了,故需要在 \mathrm{O}(1) 的时间复杂度内计算出 \displaystyle \sum_{r[k] > \frac{i + j}{2}} (r[k] - i)^2 - (r[k] - j)^2.

由于

\sum_{r[k] > \frac{i + j}{2}} (r[k] - i)^2 - (r[k] - j)^2 = (2j - 2i) sum[\frac{i + j}{2}][0] + (i^2 - j^2) sum[\frac{i + j}{2}][1]

其中

sum[x][0] = \sum_{r[k] > x} r[k] sum[x][1] = \#\{r[k] > x \ | \ k\}

故可以考虑前缀和,先记录出前缀和的差分数组,后进行累加得到前缀和数组即可。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdlib>
using namespace std;

const long long MAXN = 300;
const long long INF = 1e15;

inline long long read() {
    long long x = 0, f = 1;
    char ch = getchar();
    while (!isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

struct node {
    long long r, p;
    bool operator < (const node &a) const {
        return r < a.r;
    }
} a[MAXN];

long long k, d;
long long dp[MAXN][MAXN]; 
long long DD[MAXN][2], sum[MAXN][2];

int main() {
    d = read(); k = read();
    for (long long i = 1; i <= d; i++) {
        a[i].r = read();
        a[i].p = read();
        DD[a[i].r][0] = a[i].r * a[i].p;
        DD[a[i].r][1] = a[i].p;
    }
    sort(a + 1, a + 1 + d);
    sum[0][0] = DD[0][0]; 
    sum[0][1] = DD[0][1];
    for (long long i = 1; i <= 255; i++) {
        for (long long j = 0; j < 2; j++)
            sum[i][j] = sum[i - 1][j] + DD[i][j];
    }
    for (long long j = 0; j <= 255; j++) {
        for (long long kk = 1; kk <= 256; kk++)
            dp[j][kk] = INF;
    }
    for (long long i = 0; i <= 255; i++) {
        dp[i][1] = 0;
        for (long long j = 1; j <= d; j++) {
            dp[i][1] += a[j].p * (a[j].r - i) * (a[j].r - i);
        }
    }
    for (long long cnt = 2; cnt <= k; cnt++) {
        for (long long i = 0; i <= 255; i++) {
            for (long long j = 0; j < i; j++) {
                long long A = (2 * j - 2 * i) * (sum[255][0] - sum[(i + j) / 2][0]) + (i * i - j * j) * (sum[255][1] - sum[(i + j) / 2][1]);
                dp[i][cnt] = min(dp[i][cnt], dp[j][cnt - 1] + A); 
            }
        }
    }
    long long minn = INF;
    for (long long i = 0; i <= 255; i++) {
        minn = min(minn, dp[i][k]);
    }
    printf("%lld\n", minn); 
    return 0;
}