题解 P2389 【电脑班的裁员】

· · 题解

此题由于数据弱,可以用O(n^3)水过,但实际上有O(n)的算法。

数据加强版:U19030 电脑班的裁员(加强版)

分析

题目大意:在n个数中取不大于m段连续的数,使取的数总和最大。

这题有多种算法,难度根据算法复杂度分布在普及~省选之间。

算法1

裸DP,复杂度O(n^3)

f(i,j)表示前i个数取j段的最大价值。

i不选,则f(i,j)=f(i-1,j)

i选,枚举最后一段的起始位置k

f(i,j)=max\{f(k,j-1)+s_i-s_k\}(k<i)

其中s_i表示前i个数的前缀和。

根据方程,三重循环枚举i,j,k即可。

答案为f(n,m),复杂度O(n^3)

#include <bits/stdc++.h>
typedef long long LL;

const int N = 550;

int main() {
    std::ios::sync_with_stdio(false);
    int n, m;
    std::cin >> n >> m;
    LL s[N] = {0};
    for (int i = 1; i <= n; i++) {
        LL a;
        std::cin >> a;
        s[i] = s[i-1] + a;
    }

    static LL f[N][N] = {0};
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            f[i][j] = f[i-1][j];
            for (int k = 0; k < i; k++) {
                f[i][j] = std::max(f[i][j], f[k][j-1] + s[i] - s[k]);
            }
        }
    }
    std::cout << f[n][m] << std::endl;
    return 0;
}

算法2

优化算法1,时间O(n^2),空间O(n^2)

回到算法1的方程:

f(i,j)=max\{f(i-1,j),f(k,j)+s_i-s_k\}(k<i)

g(k,j)=f(k,j)-s_k,则

f(i,j)=max\{f(i-1,j),g(k,j)+s_i\}(k<i)

只要维护出g(k,j)的最大值,就不需要枚举k了。

具体地,先枚举j后枚举i,因为k<i,在枚举i的时候顺便维护出g(1到i-1,j)的最大值g_{max},转移的时候直接用g_{max}转移即可。

答案为f(n,m),复杂度O(n^2)

#include <bits/stdc++.h>
typedef long long LL;

const int N = 5e3+50;

int main() {
    std::ios::sync_with_stdio(false);
    int n, m;
    std::cin >> n >> m;
    LL s[N] = {0};
    for (int i = 1; i <= n; i++) {
        LL a;
        std::cin >> a;
        s[i] = s[i-1] + a;
    }

    static LL f[N][N] = {0};
    for (int j = 1; j <= m; j++) {
        LL mx = 0;
        for (int i = 1; i <= n; i++) {
            f[i][j] = std::max(f[i-1][j], mx + s[i]);
            mx = std::max(mx, f[i][j-1] - s[i]);
        }
    }
    std::cout << f[n][m] << std::endl;
    return 0;
}

算法3

另一种思路的DP,时间O(n^2),空间O(n)

f(i,j)表示前i个选j段(i不一定选)的最大价值。

g(i,j)表示前i个选j段(i一定要选)的最大价值。

对于g,讨论i-1选或不选。如果i-1选了,则i可以接上去,不用新增一段。

g(i,j)=max\{g(i-1,j),f(i-1,j-1)\}+a_i

对于f,讨论i选或不选。

f(i,j)=max\{g(i,j),f(i-1,j)\}

这样就可以做到O(n^2)的时间复杂度。因为第一维转移时只涉及到ii-1,所以可以把第一维省掉,空间复杂度O(n)

#include <bits/stdc++.h>
typedef long long LL;

const int N = 5e3+50;

int main() {
    std::ios::sync_with_stdio(false);
    int n, m;
    std::cin >> n >> m;
    LL a[N];
    for (int i = 1; i <= n; i++)
        std::cin >> a[i];

    LL f[N] = {0}, g[N] = {0};
    for (int i = 1; i <= n; i++) {
        for (int j = m; j >= 1; j--) {
            g[j] = std::max(g[j], f[j-1]) + a[i];
            f[j] = std::max(g[j], f[j]);
        }
    }
    std::cout << f[m] << std::endl;
    return 0;
}

算法4

贪心,复杂度O(n\log n)

首先,可以发现,对于一段连续的正数或负数,要么全部选,要么全部不选。

所以,我们可以把连续的一段正数或负数缩成一个数。那么序列就变成了正负交替的。以下说明都针对缩完以后的序列。

设序列中正数的个数为cnt,则对于m\ge cnt的情况,最优解肯定是取所有正数。

考虑m=cnt-1的情况,此时我们需要从m=cnt的情况减少一段。

有两种方法:一种是舍弃一个正数,另一种是取一个负数,使两边的正数合并成一段。

怎么取最优?若舍弃正数a,会损失a的价值。若取负数a,会损失-a的价值。

统一起来,就是若舍弃/取走数字a,会损失\mid a\mid的价值。这样,我们只要找绝对值最小的数舍弃/取走即可。

舍弃/取走一个数后,序列会变成什么样呢?

事实上,舍弃/取走一个数a_i,相当于与两边的数合并,合并完的值是a_{i-1}+a_i+a_{i+1}

例如1,-2,3,-4,5,若舍弃3,则序列变成1,-3,5;若取走-4,则序列变成1,-2,4

可以发现,若取绝对值最小的数,合并完以后的序列还是正负交替。于是我们可以用刚才的方法继续获得m=cnt-2,cnt-3,\dots的答案,直至m达到题目的要求。

取绝对值最小的数,可以用优先队列做。合并节点可以使用链表。答案为最后合并完的序列的所有正数之和。

复杂度:合并O(n)次,每次O(\log n),总复杂度O(n\log n)

此算法在加强版中期望得分90,常数卡得好可能可以AC。

#include <bits/stdc++.h>
typedef long long LL;

const int N = 1e6+50;

struct Data {
    LL val;
    int pos, tim;

    bool operator < (const Data &t) const {
        return val > t.val;
    }
};

int pl[N], pr[N];
int tim[N] = {0};

void del(int u) {
    if (u == 0) return;
    pr[pl[u]] = pr[u];
    pl[pr[u]] = pl[u];
    tim[u] = -1;
}

int main() {
    std::ios::sync_with_stdio(false);
    int n, m;
    std::cin >> n >> m;

    static LL a[N] = {0};
    int top = 0;
    for (int i = 1; i <= n; i++) {
        LL r; std::cin >> r;
        if (r == 0) continue;
        if (top == 0) {
            if (r > 0) a[++top] = r;
            continue;
        }
        if (a[top] > 0 == r > 0) a[top] += r;
        else a[++top] = r;
    }
    if (top > 0 && a[top] < 0) top--;

    for (int i = 0; i <= top; i++) {
        pl[i] = i == 0 ? top : i - 1;
        pr[i] = i == top ? 0 : i + 1;
    }

    std::priority_queue<Data> q;
    for (int i = 1; i <= top; i++) {
        q.push({ std::abs(a[i]), i, 0 });
    }

    for (int cnt = top + 1 >> 1; cnt > m; cnt--) {
        Data d = q.top(); q.pop();
        while (tim[d.pos] != d.tim) {
            d = q.top(); q.pop();
        }
        int u = d.pos, l = pl[u], r = pr[u];
        a[u] += a[l] + a[r];
        if (l != 0 && r != 0)
            q.push({ std::abs(a[u]), u, ++tim[u] });
        else del(u);
        del(l); del(r);
    }

    LL ans = 0;
    for (int i = pr[0]; i != 0; i = pr[i])
        if (a[i] > 0) ans += a[i];
    std::cout << ans << std::endl;
    return 0;
}

算法5

优化算法4,时间复杂度O(n)

算法步骤:

假设当前还需要合并k个数。

先找出剩下的数中绝对值第k小的数和第3k小的数(若多个数绝对值相同则比较编号),记为mid1mid2。把\mid a\mid \le mid1的数全部合并,把\mid a\mid>mid2的数全部忽略(不再参与合并)。

重复以上操作,直到剩下m个正数为止。

结论1:每轮合并的个数一定\le k证明如下:

设当前合并a_{i-1},a_i,a_{i+1},则合并完的数b=a_{i-1}+a_i+a_{i+1}

若合并次数>k,则肯定会出现“三个数中只有一个需要合并的数,合并完又产生了一个需要合并的数”这种情况。也就是|a_i|\le mid1|a_{i-1}|,|a_{i+1}|>mid1,且|b|<=mid1的情况。

然而这种情况不可能出现。由数学方法可以推知,|b|=|a_{i-1}|+|a_{i+1}|-|a_i|>mid1

这样就保证了算法的正确性,不会合并过头。

结论2:每轮合并的个数一定\ge\frac{1}{3}k证明如下:

合并次数最少的情况,就是a_{i-1},a_i,a_{i+1}都需要合并,却只合并了一个a_i。这样的话,每三个需要合并的数中,至少可以成功合并一个,即合并个数\ge\frac{1}{3}k

结论3:|a|>mid2的数可以忽略。证明如下:

因为根据结论2,在最小的3k个数中至少可以合并k个数。所以第3k个数以后的数都不会被合并。

复杂度证明:

设当前还需要合并k个数。

根据结论3,我们只留下3k个数,所以每一轮合并的复杂度为O(k)

根据结论2,每一轮合并,都会使k缩小到\frac{2}{3}k

那么,总复杂度为n+\frac{2}{3}n+(\frac{2}{3})^2n+\dots,根据等比数列公式,复杂度为O(n),期望得分100。

#include <bits/stdc++.h>
typedef long long LL;
typedef std::pair<LL, int> Data;

const int N = 1e6+50;

int n, m; LL a[N];
int L[N], R[N]; bool del[N];
int L2[N], R2[N]; bool del2[N];
int rem;
Data t[N], mx;
std::queue<LL> q; bool inQue[N];

void Init() {
    std::cin >> n >> m;
    int top = 0;
    for (int i = 1; i <= n; i++) {
        LL r; std::cin >> r;
        if (!r) continue;
        if (!top && r < 0) continue;
        if (top && a[top] > 0 == r > 0) a[top] += r;
        else a[++top] = r;
    }
    if (top && a[top] < 0) top--;
    for (int i = 0; i <= top; i++) {
        R[i] = R2[i] = (i + 1) % (top + 1);
        L[i] = L2[i] = (i + top) % (top + 1);
    }
    rem = top;
}

void AddQue(int u) {
    if (u && !inQue[u] && Data(std::abs(a[u]), u) <= mx)
        q.push(u), inQue[u] = true;
}

void Del2(int u) {
    if (!u || del2[u]) return;
    int l = L2[u], r = R2[u];
    R2[l] = r; L2[r] = l;
    del2[u] = true;
}

void Del(int u) {
    if (!u || del[u]) return;
    rem--;
    int l = L[u], r = R[u];
    R[l] = r; L[r] = l;
    del[u] = true;
    Del2(u);
}

void Merge(int u) {
    if (del[u]) return;
    int l = L[u], r = R[u];
    if (l && std::abs(a[l]) < std::abs(a[u])) return;
    if (r && std::abs(a[r]) < std::abs(a[u])) return;
    Del(l); Del(r); a[u] += a[l] + a[r];
    l && r ? AddQue(u) : Del(u);
    AddQue(L[u]); AddQue(R[u]);
}

void Work() {
    while (true) {
        n = 0;
        for (int i = R2[0]; i; i = R2[i])
            if (!del2[i]) t[++n] = Data(std::abs(a[i]), i);
        if (rem <= m * 2 - 1) break;
        int mid = (rem - (m * 2 - 1)) / 2;
        nth_element(t + 1, t + std::min(mid, n), t + n + 1);
        mx = t[std::min(mid, n)];
        nth_element(t + 1, t + std::min(mid*3, n), t + n + 1);
        Data mx2 = t[std::min(mid*3, n)];
        for (int i = R2[0]; i; i = R2[i]) {
            Data cur(std::abs(a[i]), i);
            if (cur > mx2) Del2(i);
            else AddQue(i);
        }
        while (!q.empty()){
            int u = q.front(); q.pop(); inQue[u] = false;
            Merge(u);
        }
    }
    LL ans = 0;
    for (int i = R[0]; i; i = R[i])
        if (a[i] > 0) ans += a[i];
    std::cout << ans << std::endl;
}

int main() {
    std::ios::sync_with_stdio(false);
    Init(); Work();
    return 0;
}