题解 P4767 【[IOI2000]邮局】

2018-12-19 19:45:36


一打眼就看出来是个DP优化的题!

然而我不会

40分:三方的DP不用解释了吧qwq

100分:四边形不等式(参见memset0巨佬的题解,蒟蒻我不会qwq)

101分:忘情水wqs二分

好的,我们来讲101分做法。

首先,一个很显然的结论:取的邮局越多,答案越优

其次,另一个比较显然的结论:设设置$k$个邮局的最优解为$f(k)$,则$f(k)-f(k-1)>f(k+1)-f(k)$。也就是说,该函数的大致图像是:

接下来,观察原方程$f[i][j]$表示$1..i$放了$j$个邮局,发现第一维几乎不可能被优化掉,那么只能从第二维下手了。

如果能恰好取到$k$个邮局的话。。。你做梦!

对于此题,忘情水二分就是干这个事情的。

我们先枚举一个值$C$,表示每放一个邮局需要额外花费$C$的代价。现在,函数图像变成了这样:

图中蓝线为原函数,橙线为现函数,绿色虚线长度依次是$0C,1C,2C,3C,4C...$

可以证明,这是一个单峰函数(然而我不会证,只能感性理解)

于是我们可以二分$C$,在转移时记录当前放置邮局的次数(记录次数,不设上限),然后根据当前次数调整$L$和$R$(具体见代码)

/*xxc 18/12/19       */
/*https://xcfubuki.cn*/
#include <cstdio>
#include <cstring>
#include <algorithm>
#define calc(x, y) (f[x] + w((x) + 1, y) + exc)
#define maxn 100005

using namespace std;
typedef long long LL;

int n, k, a[maxn], cnt[maxn], pre[maxn];
LL s[maxn], f[maxn], exc;

class jc
{
  public:
    int l, r, p;
} que[maxn];

inline int read()
{
    char ch = getchar();
    int ret = 0, f = 1;
    while (ch > '9' || ch < '0')
    {
        if (ch == '-')
            f = -f;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
        ret = ret * 10 + ch - '0', ch = getchar();
    return ret * f;
}

LL w(int l, int r)
{
    if (l >= r)
        return 0;
    int mid = (r - l >> 1) + l;
    LL res = a[mid] * (1ll * mid - l + 1) - (s[mid] - s[l - 1]);
    res += (s[r] - s[mid - 1]) - (1ll * r - mid + 1) * a[mid];
    return res;
}

int find(jc x, int s)
{
    int L = x.l, R = x.r;
    while (L <= R)
    {
        int mid = (R - L >> 1) + L;
        if (calc(x.p, mid) > calc(s, mid))
            R = mid - 1;
        else
            L = mid + 1;
    }
    return L;
}

int check()
{
    int hed = 1, til = 0;
    que[++til] = (jc){1, n, 0};
    for (int i = 1; i <= n; ++i)
    {
        f[i] = calc(que[hed].p, i);
        pre[i] = que[hed].p;
        cnt[i] = cnt[que[hed].p] + 1;
        int chs = -1;
        while (hed <= til)
        {
            if (calc(i, que[til].l) < calc(que[til].p, que[til].l))
                chs = que[til--].l;
            else
            {
                int st = find(que[til], i);
                if (st <= que[til].r)
                    chs = st, que[til].r = st - 1;
                break;
            }
        }
        if (chs != -1)
            que[++til] = (jc){chs, n, i};
        if (hed <= til)
        {
            que[hed].l++;
            if (que[hed].l > que[hed].r)
                hed++;
        }
    }
    return cnt[n];
}

void output(int i)
{
    if (0 == i)
        return;
    output(pre[i]);
    printf("%d ", i);
}

int main()
{
    n = read(), k = read();
    for (int i = 1; i <= n; ++i)
        a[i] = read();
    sort(a + 1, a + 1 + n);
    for (int i = 1; i <= n; ++i)
        s[i] = s[i - 1] + 1ll * a[i];
    LL L = 0, R = 1e6, res = 0;
    while (L <= R)
    {
        LL mid = (R - L >> 1) + L;
        exc = mid;
        if (check() <= k)
            res = mid, R = mid - 1;
        else
            L = mid + 1;
    }
    exc = res;
    check();
    printf("%lld\n", f[n] - k * res);
    return 0;
}