S O S D P

· · 题解

前置芝士:高维前缀和(SOS DP)。

如果不会的话左传题解区,包教包会 唔这里会讲一讲的(

简单来说,如何求一个一维序列的前缀和?

废话当然是直接累加啊……

那二维呢?(这里前缀和定义为每个维度都小于等于给定数的所有下标对应的值的和)

正常的写法是每个数用三个来推(也就是

sum[i][j]+=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1];

这种的,做的是原地前缀和,也就是说 sum 数组原来存的是原值),但我们发现这样的话一个式子里出现了 4 项,如果把它推广到 x 维,且不论代码挺难写,而且一个式子里会有 2^x 项,如果每一维的取值都是小于等于 2(事实上一般都是这个情况),那复杂度并不优秀,是 O(3^x) 的(因为事实上这个情况下只需要枚举对应下标的子集,可以比朴素的 O(4^x) 好一点,复杂度分析大概就是每一位在两个集合里的出现方案只有三种,但总之都很难看)。

所以我们需要一个针对它的优化。

我们考虑我们二维前缀和实际上是要干一个什么事:求出一个前缀子矩阵里所有数的和。那我们不如先求出子矩阵每一列的和然后再把它们全部累起来(或者先行再列也无所谓),然后很容易能发现这两步都可以一维前缀和优化。具体来说,代码会长成这个样子:

首先对所有元素执行:sum[i][j]+=sum[i-1][j];

然后再对所有元素执行:sum[i][j]+=sum[i][j-1];

是不是很简洁,并且显然很好推广到高维?

这样的话,比如我们要对一个 x 维的东西累前缀和,复杂度就降低为 O(x\times2^x),比较舒适。

好了回到本题。

首先如果我们把每种多米诺都给枚举出来并算出权值,建图,那么就转换成了要求:一个顶点的子集,大小为 k 给定,其中没有两个点之间有边,点权和最大是多少(最后拿所有输入的和减去这玩意就是答案)。

这样的话点数大概会是 8\times10^6 级别,由于目测不会有什么神仙线性算法,显然大概率难以接受。考虑怎么优化。

观察到这个图的一个特性:一个点的度数至多为 6(十分显然)。这就给了我们一些启示,事实上,我们可以证明一定存在一种最优方案,所有取出的点都在前 7(k-1)+1 个点里。证明大概就是反证法,不妨最后一个点不在,然后前 k-1 个点每个点至多使得六个点不能取,算上自己一共 7 个,这样最多使得 7(k-1) 个点不可达,这样把最后一个点调整到前 7(k-1)+1 个点一定可行且更优。

然后本题中 k\le8 所以实际上只需要保留前 50 个点。如果不考虑复杂度问题,我们已经有了一个挺不错的算法:弄一个 50 维的数组表示每个点取或者不取(状压一下就变成一个一维数组,大小为 2^{50}),然后初始对于每一维,把只有这一位是 1 的值填进去,然后对于两个相邻的点,把只有这两位的值填成 +\infty(也就是非法),然后累个前缀和就能算了。当然实际上也可以只统计每个状态是否合法,值后面再算,不过这个无所谓,复杂度一致。

我们发现这个复杂度显然要炸,但是这个 2 指数级的复杂度加上 50 的奇妙范围不难想到 meet-in-the-middle。大概就是,把 50 个点分成两个部分分别按照上述方法计算,然后之间计算一下贡献。至于之间的贡献怎么算,我们可以对于前一半的每个状态(集合)算出其在后半部分中最大能匹配到的集合总权,这个显然可以反向前缀和一下干掉,实际上就是要算后半部分中每个集合最大能匹配到前半部分中的哪个集合,这个也可以累前缀和但代码里由于一些历史遗留问题选择了把每个单个的最大对应状态预处理出来然后暴力枚举每一位求集合交。

然后这个复杂度实际上还是有点危,因为反向累前缀和那一步实际上是要分后半部分那个集合的大小来算的,于是我们发现实际上只需要保留大小 \le k 的部分(给后半部分类前缀和的时候也是如此),好消息。

这样的话就可以了。肉眼观察一下前半部分放 22 个点,后半部分放 28 个点的时候是差不多可以的。那就可以过了。注意一下这个数据范围会爆 int,但 unsigned 也够用了。

代码实际上并不难写挺难写的,所以还是贴一下吧(

#include <bits/stdc++.h>
#define fi first
#define se second
#define INF INT_MAX
#define cnt __builtin_popcount
#define int unsigned
using namespace std;
constexpr int N = 4e6 + 9, M = 50, W = 21, C = 9e6 + 9;
int n, k, tot, d, val[N], ans, sum;
pair<int, int> dm[N << 1];
int id(int x, int y) { return x * n + y; }
int getval(pair<int, int> x) { return val[x.fi] + val[x.se]; }
bool issim(pair<int, int> x, pair<int, int> y) {
  return x.fi == y.fi || x.fi == y.se || x.se == y.fi || x.se == y.se;
}
int iadd(int x, int y) { return x == INF || y == INF ? INF : x + y; }
int f[1 << W], g[C], gst[M], gval[1 << W][9];
int num[C], ntot;
void dfs(int hbt, int rem, int x) {
  if (!hbt)
    num[ntot++] = x;
  else {
    dfs(hbt - 1, rem, x);
    if (rem) dfs(hbt - 1, rem - 1, x | (1 << (hbt - 1)));
  }
}
int pos(int x) { return lower_bound(num, num + ntot, x) - num; }
signed main() {
  ios::sync_with_stdio(false), cin.tie(nullptr);
  dfs(M - W, 8, 0), cin >> n >> k;
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < n; ++j) {
      cin >> val[id(i, j)], sum += val[id(i, j)];
      if (i) dm[tot++] = {id(i - 1, j), id(i, j)};
      if (j) dm[tot++] = {id(i, j - 1), id(i, j)};
    }
  if (tot > M)
    nth_element(dm, dm + M, dm + tot,
                [](auto x, auto y) { return getval(x) > getval(y); }),
        tot = M;
  d = min(tot >> 1, W);
  for (int i = 0; i < d; ++i) {
    f[1 << i] = getval(dm[i]);
    for (int j = 0; j < i; ++j)
      if (issim(dm[i], dm[j])) f[(1 << i) | (1 << j)] = INF;
  }
  for (int i = 0; i < d; ++i)
    for (int s = 0; s < (1u << d); ++s)
      if ((s & (1 << i)) && cnt(s) <= k) f[s] = iadd(f[s], f[s ^ (1 << i)]);
  for (int i = 0; i < tot - d; ++i) {
    g[pos(1 << i)] = getval(dm[i + d]);
    for (int j = 0; j < i; ++j)
      if (issim(dm[i + d], dm[j + d])) g[pos((1 << i) | (1 << j))] = INF;
  }
  for (int i = 0; i < tot - d; ++i) {
    gst[i] = (1 << d) - 1;
    for (int j = 0; j < d; ++j)
      if (issim(dm[i + d], dm[j])) gst[i] ^= 1 << j;
  }
  for (int i = 0; i < tot - d; ++i)
    for (int p = 0; p < ntot && (num[p] < (1u << (tot - d))); ++p)
      if (int s = num[p]; (s & (1 << i)) && cnt(s) <= k)
        g[p] = iadd(g[p], g[pos(s ^ (1 << i))]);
  for (int p = 0; p < ntot && (num[p] < (1u << (tot - d))); ++p)
    if (int s = num[p]; g[p] != INF && cnt(s) <= k) {
      int gsts = (1 << d) - 1;
      for (int i = 0; i < (tot - d); ++i)
        if (s & (1 << i)) gsts &= gst[i];
      gval[gsts][cnt(s)] = max(gval[gsts][cnt(s)], g[p]);
    }
  for (int i = 0; i < d; ++i)
    for (int s = 0; s < (1u << d); ++s)
      if (s & (1 << i))
        for (int j = 0; j <= k; ++j) {
          gval[s ^ (1 << i)][j] = max(gval[s ^ (1 << i)][j], gval[s][j]);
        }

  for (int s = 0; s < (1u << d); ++s)
    if (f[s] != INF && cnt(s) <= k) ans = max(ans, gval[s][k - cnt(s)] + f[s]);
  cout << sum - ans << endl;
  return 0;
}

以上。