CF1342F Make It Ascending

· · 题解

*VIII. CF1342F Make It Ascending.

题目相当于求将序列 a 划分为若干集合 S_1, S_2, \cdots S_c,集合之间有序,满足 S_i 之和小于 S_{i + 1} 之和(单调递增),且存在一个单调上升序列 p 满足 p_i \in S_i,即我们最终会将 S_i 所有其它元素累和到原序列位置为 p_i​ 的元素上。求最多划分出多少集合。

朴素的想法是将所有限制全部装进 DP 里面,即设 f_{i, j, k} 表示当前考虑了 mask 为 i 的数集,最后一个划分出的集合对应的 p 的最小值为 j,且总和为 k 时,最多划分出多少集合。转移根据实际意义显然。

这样复杂度太大了,因为我们把值域装了进去。考虑如何把值域去掉。由于若 k > k'f_{i, j, k} \leq f_{i, j, k'},则前者显然无用。这说明 对于相同的 f_{i, j, k},仅有最小的 k 有用。根据这一性质,我们尝试 忽略无用状态,交换 DP 的某一维度和值域。这意味着,我们设 f_{i, j, k} 表示当处于上述意义下的 i, j 以及 集合个数k 时,最后一个集合的和的最小值。转移即在 i 这一维枚举子集,同时枚举 j, k

转移条件有三个:

  1. 存在 p\in S 使得 p > j。设 p_{\min} 为最小的这样的 p

f_{i, j, k} 可以转移到 f_{i \cup S, p_{\min}, k + 1},并令后者对 \sum\limits_{p\in S} a_p 取最小值。容易发现时间复杂度为 \mathcal{O}(3 ^ n n ^ 2)。由于不合法状态较多,跑不满(若 f_{i, j, k}\infty 则不用枚举 i 的补集的子集),时限非常大,故可以通过。

#include <bits/stdc++.h>
using namespace std;

const int N = 16;
struct dat {int msk, num, pos;} tr[1 << N][N][N];
int T, n, a[N], ban[N], s[1 << N], f[1 << N][N][N];
vector <pair <int, int>> ans;
int find(int x) {
    int rnk = 1;
    for(int i = 1; i < x; i++) if(!ban[i]) rnk++;
    return rnk;
}
void calc(int msk, int num, int pos) {
    if(!msk) return;
    dat t = tr[msk][num][pos];
    calc(t.msk, t.num, t.pos);
    for(int i = 0; i < n; i++)
        if(msk - t.msk >> i & 1 && i != pos - 1)
            ans.push_back({i + 1, pos});
}
void solve() {
    cin >> n, ans.clear();
    for(int i = 0; i < n; i++) cin >> a[i];
    for(int i = 0; i < 1 << n; i++)
        for(int j = 0; j <= n; j++)
            for(int k = 0; k <= n; k++)
                f[i][j][k] = 1e9;
    f[0][0][0] = 0;
    for(int i = 1; i < 1 << n; i++)
        for(int j = 0; j < n; j++)
            if(i >> j & 1)
                s[i] = s[i - (1 << j)] + a[j];
    for(int i = 0; i < 1 << n; i++)
        for(int j = 0; j < n; j++)
            for(int k = 0; k < n; k++)
                if(f[i][j][k] < 1e9) {
                    int C = (1 << n) - 1 - i;
                    for(int S = C; S; S = (S - 1) & C)
                        if(s[S] > f[i][j][k] && S >> k) {
                            int nk = __builtin_ctz(S >> k << k) + 1;
                            if(s[S] >= f[i + S][j + 1][nk]) continue;
                            f[i + S][j + 1][nk] = s[S];
                            tr[i + S][j + 1][nk] = {i, j, k};
                        }
                }
    for(int num = n; num; num--)
        for(int pos = 0; pos <= n; pos++)
            if(f[(1 << n) - 1][num][pos] != 1e9) {
                calc((1 << n) - 1, num, pos);
                cout << ans.size() << endl;
                memset(ban, 0, sizeof(ban));
                for(auto it : ans) {
                    cout << find(it.first) << " " << find(it.second) << endl;
                    ban[it.first] = 1;
                }
                return;
            }
}
int main() {
    cin >> T;
    while(T--) solve();
    return 0;
}

/*
stupid mistakes:
    pos - 1 -> pos
    注意 pos 的定义 =.= 
*/