P11233 [CSP-S 2024] 染色

· · 题解

给一个考场 20min 想到的思路。

思路:

考虑朴素的 dp,令 dp_{i,j} 表示 i 左侧第一个同色为第 j 个数字,先令:

\operatorname{get}(i,j) = \begin{cases} a_i & a_i = a_j \\ 0 & a_i \ne a_j \end{cases}

h_i = \operatorname{get}(i - 1, i) 则状态转移方程为:

\begin{cases} dp_{i,0} = dp_{i - 1, 0} + h_{i - 1}\\ dp_{i, i - 1} = h_i + \max dp_{i - 1, k}(k < i - 1) \\ dp_{i, j} = \operatorname{get}(i, j) + \max (dp_{j + 1, l} + \sum\limits_{t = j + 2}^{i - 1} h_i) (l < j) \end{cases}

朴素实现是 O(n^4) 的,可以拿到 20pts,考虑优化。

考虑令:

s_i = \max\limits_{j = 0}^{i - 1} dp_{i,j}, t_i = \max\limits_{j = 0}^{i - 2} dp_{i,j} sum_i = \sum_{i = 1}^i h_i

则状态转移方程可以优化为:

\begin{cases} dp_{i,0} = dp_{i - 1, 0} + h_{i - 1} \\ dp_{i, i - 1} = h_i + s_{i -1} \\ dp_{i, j} = \operatorname{get}(i, j) + \max(dp_{j + 1, l} + sum_{i - 1} - sum_{j + 1})(l < j) \end{cases}

主要是下面那个式子不好做,整理下得:

\begin{aligned} dp_{i,j} &= \operatorname{get}(i, j) + sum_{i - 1} - sum_{j+1} + \max(dp_{j + 1, l}) (l <j) \\ &= \operatorname{get}(i, j) + sum_{i - 1} - sum_{j + 1} + t_{j + 1} \end{aligned}

时间复杂度优化为 O(n^2)

然后注意到答案是 s_n,考虑如何快速求出 s_i

\begin{aligned} s_i &= \max\limits_{j = 0}^{i - 2} dp_{i,j} \\ &= sum_{i - 1} + \max\limits_{j = 0}^{i - 2} \operatorname{get}(i, j) - sum_{j + 1} + t_{j + 1} \\ & = \begin{cases} sum_{i - 1} + a_i + \max\limits_{j = 1}^{i - 1} (t_j - sum_j) & a_i = a_j \\ sum_{i - 1} + \max\limits_{j = 1}^{i - 1} (t_j - sum_j) & a_i \ne a_j \end{cases} \end{aligned}

这样就可以快速求出 s_i,开一个桶 Max_i 表示所有使得 a_j = it_j - sum_j 的最大值即可,此时时间复杂度为 O(nw),其中 w 是值域。

注意当 a_i \ne a_j 时,是求一个 j \in [1, a_i - 1] \cup [a_i +1, w]Max_j 最大值,即两个区间的最大值,需要支持单点修改,树状数组维护即可。

时间复杂度为 O(n \log w)

先给个赛后补的 Code。

完整代码:

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define Add(x, y) (x + y >= mod) ? (x + y - mod) : (x + y)
#define lowbit(x) x & (-x)
#define pi pair<ll, ll>
#define pii pair<ll, pair<ll, ll>>
#define iip pair<pair<ll, ll>, ll>
#define ppii pair<pair<ll, ll>, pair<ll, ll>>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define fi first
#define se second
#define full(l, r, x) for(auto it = l; it != r; ++it) (*it) = x
#define Full(a) memset(a, 0, sizeof(a))
#define open(s1, s2) freopen(s1, "r", stdin), freopen(s2, "w", stdout);
#define For(i, l, r) for(register int i = l; i <= r; ++i)
#define _For(i, l, r) for(register int i = r; i >= l; --i)
using namespace std;
using namespace __gnu_pbds;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const int N = 2e5 + 10, M = 1e6 + 10;
inline ll read(){
    ll x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-')
          f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
inline void write(ll x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9)
      write(x / 10);
    putchar(x % 10 + '0');
}
int T, n;
ll pre;
ll a[N], s[N], dp[N], t[N], Max[M];
namespace Tree{
    ll a[M], b[M];
    void init(){
        for(int i = 0; i < M; ++i)
          a[i] = b[i] = -1e18;
    }
    void add(int x, ll v){
        for(int i = x; i < M; i += lowbit(i))
          a[i] = max(a[i], v);
        x = M - x - 1;
        for(int i = x; i < M; i += lowbit(i))
          b[i] = max(b[i], v);
    }
    ll query(int x){
        ll ans = -1e18;
        for(int i = x - 1; i > 0; i -= lowbit(i))
          ans = max(ans, a[i]);
        x = M - x - 2;
        for(int i = x; i > 0; i -= lowbit(i))
          ans = max(ans, b[i]);
        return ans;
    }
};
inline ll get(int x, int y){
    if(!x || !y)
      return 0;
    if(a[x] == a[y])
      return a[x];
    return 0;
}
void solve(){
    memset(dp, -0x7f, sizeof(dp));
    memset(t, -0x7f, sizeof(t));
    memset(Max, -0x7f, sizeof(Max));
    n = read();
    for(int i = 1; i <= n; ++i){
        a[i] = read();
        s[i] = s[i - 1] + get(i - 1, i);
    }
    Tree::init();
    dp[1] = pre = 0;
    for(int i = 2; i <= n; ++i){
        t[i] = dp[i] = pre + get(i - 2, i - 1);
        pre = dp[i];
        dp[i] = max(dp[i], s[i - 1] + a[i] + Max[a[i]]);
//      cerr << Max[a[i]] << '\n';
        dp[i] = max(dp[i], s[i - 1] + Tree::query(a[i]));
        Max[a[i - 1]] = max(Max[a[i - 1]], dp[i] - s[i]);
        Tree::add(a[i - 1], dp[i] - s[i]);
        dp[i] = max(dp[i], get(i - 1, i) + dp[i - 1]);
    }
    write(dp[n]);
    putchar('\n');
}
bool End;
int main(){
//  open("A.in", "A.out");
    T = read();
    while(T--)
      solve();
    cerr << '\n' << abs(&Begin - &End) / 1048576 << "MB";
    return 0;
}