P9375题解

· · 题解

关于多维 dp,它还活着。

观前提示:

本题解主要是关于 50 分部分的多维 dp 卡过本题的乱搞做法(理论上可以算作合法解,毕竟实际的状态数其实一点也不多),同时运用一些技巧降低代码复杂度。想学习正解的小伙伴可以去看看其他的题解。

由于个人编码的习惯,在题解前文中的变量在阐明定义后会采用相对简单的命名,在代码中会采用与不一定前文一致,但足够清晰的变量名,望各位海涵。

正文:

如果你不是很清楚如何对这道题进行多维 dp,那你可以先去看看乌龟棋这道题,本题的状态设计与状态转移和上面那道题完全一致

当然在这里也会浅浅说一下:

首先有一个较为显然的结论就是不同长度的子段数不会超过 15 个(因为 \sum_{1}^{15} = 120),也就是说我们 dp 的维数不会超过 15

根据这个结论也能发现读入的 c_i 中应该会有很多 0,所以还需要把 c 重映射,留下具有实际意义的作为 dp 的维数。具体的操作为对于每一个 c_i \neq 0,我们令 len[++ idx] = i, cnt[idx] = c[i]

同时也需要 O(n^3) 预处理出所有区间的权值,其中 v_{l, r} 表示 a_la_r 的权值。

有了这些前置条件,我们就可以愉快地设 dp 方程了,先以 4 维的举个例:

f_{a,b,c,d} 重映射后为用了 a 个一号子段,b 个二号子段,c 个三号子段,d 个四号子段的操作集合中得到的最大权值。

那么有转移方程:

t = a \times len_1 + b \times len_2 + c \times len_3 + d \times len_4 f_{a,b,c,d} = \max \begin{cases} f_{a - 1, b, c, d} + v_{t - len_1 + 1, t} \\f_{a,b-1,c,d} + v_{t-len_2+1, t} \\f_{a, b, c - 1, d} + v_{t-len_3+1, t} \\f_{a,b,c,d-1} +v_{t-len_4+1, t} \end{cases}

仿照上例即可推出 15 维的方程。

当然仅仅这样是不够的,单是空间就够 MLE 好几次了,但如果我们稍加思考就会发现空间实际上严重不满,我们要做的只是动态开空间,在这里先使用 struct 自定义节点并用 std::map 作为 f 数组。

于是可以得到部分极具观赏性的代码(这里只有十一层循环的原因是面向数据)。

struct node {
  int b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11;
  bool operator <(const node&t)const {
    if (b1 != t.b1) return b1 < t.b1;
    if (b2 != t.b2) return b2 < t.b2;
    if (b3 != t.b3) return b3 < t.b3;
    if (b4 != t.b4) return b4 < t.b4;
    if (b5 != t.b5) return b5 < t.b5;
    if (b6 != t.b6) return b6 < t.b6;
    if (b7 != t.b7) return b7 < t.b7;
    if (b8 != t.b8) return b8 < t.b8;
    if (b9 != t.b9) return b9 < t.b9;
    if (b10 != t.b10) return b10 < t.b10;
    return b11 < t.b11;
  }
}; // node 节点定义及重载预算符

// dp 主体
for (int a1 = 0; a1 <= cnt[1]; ++ a1)
  for (int a2 = 0; a2 <= cnt[2]; ++ a2)
    for (int a3 = 0; a3 <= cnt[3]; ++ a3)
      for (int a4 = 0; a4 <= cnt[4]; ++ a4)
        for (int a5 = 0; a5 <= cnt[5]; ++ a5)
          for (int a6 = 0; a6 <= cnt[6]; ++ a6)
            for (int a7 = 0; a7 <= cnt[7]; ++ a7)
              for (int a8 = 0; a8 <= cnt[8]; ++ a8)
                for (int a9 = 0; a9 <= cnt[9]; ++ a9)
                  for (int a10 = 0; a10 <= cnt[10]; ++ a10)
                    for (int a11 = 0; a11 <= cnt[11]; ++ a11) {
                      int& op = mp[ {a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11}];
                      int opt = a1 * len[1] + a2 * len[2] + a3 * len[3] + a4 * len[4] + a5 * len[5] + a6 * len[6] + a7 * len[7] + a8 * len[8] + a9 * len[9] + a10 * len[10] + a11 * len[11];
                      if (a1) op = max(op, mp[ {a1 - 1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11}] + v[opt - len[1] + 1][opt]);
                      if (a2) op = max(op, mp[ {a1, a2 - 1, a3, a4, a5, a6, a7, a8, a9, a10, a11}] + v[opt - len[2] + 1][opt]);
                      if (a3) op = max(op, mp[ {a1, a2, a3 - 1, a4, a5, a6, a7, a8, a9, a10, a11}] + v[opt - len[3] + 1][opt]);
                      if (a4) op = max(op, mp[ {a1, a2, a3, a4 - 1, a5, a6, a7, a8, a9, a10, a11}] + v[opt - len[4] + 1][opt]);
                      if (a5) op = max(op, mp[ {a1, a2, a3, a4, a5 - 1, a6, a7, a8, a9, a10, a11}] + v[opt - len[5] + 1][opt]);
                      if (a6) op = max(op, mp[ {a1, a2, a3, a4, a5, a6 - 1, a7, a8, a9, a10, a11}] + v[opt - len[6] + 1][opt]);
                      if (a7) op = max(op, mp[ {a1, a2, a3, a4, a5, a6, a7 - 1, a8, a9, a10, a11}] + v[opt - len[7] + 1][opt]);
                      if (a8) op = max(op, mp[ {a1, a2, a3, a4, a5, a6, a7, a8 - 1, a9, a10, a11}] + v[opt - len[8] + 1][opt]);
                      if (a9) op = max(op, mp[ {a1, a2, a3, a4, a5, a6, a7, a8, a9 - 1, a10, a11}] + v[opt - len[9] + 1][opt]);
                      if (a10) op = max(op, mp[ {a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 - 1, a11}] + v[opt - len[10] + 1][opt]);
                      if (a11) op = max(op, mp[ {a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11 - 1}] + v[opt - len[11] + 1][opt]);
                    }

十分甚至有九分震撼,相信应该不会有人选择上面的代码的(但确实能过)。

自此本题解的算法环节结束,接下开始优化代码复杂度:

  1. node 的定义太麻烦了,可不可以不写那么多条件分支。

    当然可以,使用 std::vector 可以并使用内置的小于号即可完成比较。

  2. 没有问题,可以使用递归控制循环层数,同时赋值操作相似度很高,考虑使用一个 `std::vector` 存入所有的循环变量,稍加修改即可。

值得一提的是把上面的循环改成递归后常数会大很多,因为上面的代码相当于使用了循环展开卡常。

细节方面就是在代码实现中记得判无解情况。

代码:

#include <bits/stdc++.h>

using namespace std;

int const N = 5e3;

int n, idx;
int a[N], score[N][N]; // 对应 a, v 数组
map<vector<int>, int> dp; // 对应 f 数组
vector<int> loopVariables(15, 0); 
int segmentLength[N], segmentCount[N]; // 对应 len, cnt 数组

void calculateMaxScore(int step) { // dp 主体
    if (step == idx + 1) {
        int &opt = dp[loopVariables];
        int totalLength = 0;
        for (int i = 1; i <= idx; ++ i)
            totalLength += loopVariables[i] * segmentLength[i];
        for (int i = 1; i <= idx; ++ i) {
            if (loopVariables[i] == 0) continue;
            loopVariables[i] -= 1;
            opt = max(opt, dp[loopVariables] + score[totalLength - segmentLength[i] + 1][totalLength]);
            loopVariables[i] += 1;
        }
        return;
    }
    for (int i = 0; i <= segmentCount[step]; ++ i) {
        loopVariables[step] = i;
        calculateMaxScore(step + 1);
    }
}

main() {
    cin >> n;
    for (int i = 1; i <= n; ++ i) cin >> a[i];

    int totalLength = 0;
    for (int i = 1, x; i <= n; ++ i) {
        cin >> x;
        if (x != 0) {
            segmentLength[++ idx] = i;
            segmentCount[idx] = x;
        }
        totalLength += x * i;
    }

    if (totalLength != n) { // 无解判定
        puts("-1");
        return 0;
    }

    function<bool(int)> isPerfectSquare = [](int num) -> bool { // 定义一个局部函数
        int root = sqrt(num);
        return root * root == num;
    };

    for (int i = 1; i <= n; ++ i) { // 预处理
        for (int j = i; j <= n; ++ j) {
            int tempScore1 = 0, tempScore2 = 0;
            for (int k = i; k <= j; ++ k) {
                tempScore1 += isPerfectSquare(abs(a[k] - a[i]));
                tempScore2 += isPerfectSquare(abs(a[k] - a[j]));
            }
            score[i][j] = tempScore1 * tempScore2;
        }
    }

    calculateMaxScore(1);

    cout << dp[loopVariables] << endl;
    return 0;
}