P10861 [HBCPC2024] MACARON Likes Happy Endings

· · 题解

或许更好的阅读体验。

思路:

设 s 为前缀异或,那么 c(i, j) = [s_j \oplus s_{i - 1} = d] \ge 0,于是 w(l, r) = \sum_{l \le i \le j \le r} c(i, j) 满足四边形不等式,于是 dp 时:

dp_{i, j} = \min_{k \le i} dp_{k - 1, j - 1} + w(k, i)

然后分治去做,算 w(k, i) 的时候可以简单地移动左右指针,分别维护 s_i 和 s_{i - 1} 的桶来计算,于是时间复杂度是 O(nk \log n) 的。

证明:

c(i, j) \ge 0,则 w(l, r) = \sum_{l \le i \le j \le r} c(i, j) 满足四边形不等式。

考虑证明 w(x, y) + w(x + 1, y + 1) \le w(x, y + 1) + w(x + 1, y)。考虑拆贡献:满足 x + 1 \le i \le j \le y 的 (i, j) 对两边的贡献是相同的(两边各被计算两次)。去掉这部分后,左边剩下 \sum_{i = x, j \in [x, y]} c(i, j) + \sum_{i \in [x + 1, y + 1], j = y + 1} c(i, j);右边 w(x + 1, y) 没有剩余,w(x, y + 1) 剩下 \sum_{i = x, j \in [x, y + 1]} c(i, j) + \sum_{i \in [x + 1, y + 1], j = y + 1} c(i, j)。于是右边恰好比左边多考虑了 c(x, y + 1) \ge 0,得证。

完整代码:

#include<bits/stdc++.h>
// Helper macros. The expression-like ones are fully parenthesised (arguments
// and result) so they keep their meaning inside larger expressions, e.g.
// `lowbit(x) == 4` or `ls(k) + 1`.
#define ls(k) ((k) << 1)          // 2 * k
#define rs(k) ((k) << 1 | 1)      // 2 * k + 1
#define lowbit(x) ((x) & (-(x)))  // lowest set bit of x
#define fi first
#define se second
#define popcnt(x) __builtin_popcount(x)
// Redirects stdin to file s1 and stdout to file s2.
// NOTE: the expansion already ends with ';'.
#define open(s1, s2) freopen(s1, "r", stdin), freopen(s2, "w", stdout);
using namespace std;
// Short aliases for wide / floating / integer types.
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
// Not referenced anywhere in this file.
// NOTE(review): presumably a leftover marker for measuring static memory — confirm before removing.
bool Begin;
// N: capacity of the per-position arrays (a, and the second dimension of dp).
// M: capacity of the first dimension of dp (layers indexed by segment count).
// W: capacity of the value-indexed bucket arrays cnt / cnt2.
const int N = 1e5 + 10, M = 21, W = 4e6 + 10;
// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated. The character right after
// the last digit is consumed as well.
//
// Returns the parsed value, or 0 when end of input is reached before any digit
// (the previous version looped forever in that case because EOF was never
// tested and the character was stored in a `char`).
inline long long read(){
    long long value = 0;
    bool negative = false;
    // Keep the character in an int so EOF stays distinguishable from data bytes.
    int ch = getchar();
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-')
            negative = true;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -value : value;
}
// Prints x to stdout in decimal, with a leading '-' for negative values and no
// trailing newline.
//
// The digits are produced from the unsigned magnitude, so the most negative
// value is handled too (negating it as a signed number would overflow).
inline void write(long long x){
    unsigned long long magnitude = static_cast<unsigned long long>(x);
    if(x < 0){
        putchar('-');
        magnitude = 0ULL - magnitude;  // two's-complement negate, well defined on unsigned
    }
    char digits[20];  // a 64-bit magnitude has at most 20 decimal digits
    int len = 0;
    do{
        digits[len++] = static_cast<char>('0' + magnitude % 10);
        magnitude /= 10;
    }while(magnitude != 0);
    // Digits were collected least-significant first; emit them in reverse.
    while(len > 0)
        putchar(digits[--len]);
}
// Running value maintained by addr/delr/addl/dell for the current window [L, R].
ll now;
// Input parameters, in the order main reads them.
int n;
int k;
int d;
// Current window borders; L = 1, R = 0 is the empty window.
int L = 1;
int R = 0;
// a[i]: xor of the first i input values (a[0] = 0), filled in main.
int a[N];
// cnt[v]:  how many x in [L, R] have a[x - 1] == v.
int cnt[W];
// cnt2[v]: how many x in [L, R] have a[x] == v.
int cnt2[W];
// dp[x][i]: value computed for prefix length i in layer x (see solve / main).
ll dp[M][N];
// Appends position x at the right end of the window.
// Both buckets are updated first, so the lookup below also sees x itself as a
// possible left partner; then `now` grows by the number of window positions i
// with a[i - 1] == a[x] ^ d.
inline void addr(int x){
    const int before = a[x - 1];
    const int after = a[x];
    cnt[before] += 1;
    cnt2[after] += 1;
    now += cnt[after ^ d];
}
// Removes position x from the right end of the window (inverse of addr).
// The contribution is subtracted while x is still counted in the buckets,
// mirroring the order used when it was added.
inline void delr(int x){
    const int before = a[x - 1];
    const int after = a[x];
    now -= cnt[after ^ d];
    cnt[before] -= 1;
    cnt2[after] -= 1;
}
// Prepends position x at the left end of the window.
// Both buckets are updated first, so the lookup below also sees x itself as a
// possible right partner; then `now` grows by the number of window positions j
// with a[j] == a[x - 1] ^ d.
inline void addl(int x){
    const int before = a[x - 1];
    const int after = a[x];
    cnt[before] += 1;
    cnt2[after] += 1;
    now += cnt2[before ^ d];
}
// Removes position x from the left end of the window (inverse of addl).
// The contribution is subtracted while x is still counted in the buckets.
inline void dell(int x){
    const int before = a[x - 1];
    const int after = a[x];
    now -= cnt2[before ^ d];
    cnt2[after] -= 1;
    cnt[before] -= 1;
}
// Moves the global window [L, R] to [l, r] one position at a time and returns
// the maintained value `now` for the new window.
// Order of adjustments: grow right, shrink left, grow left, shrink right.
inline ll getw(int l, int r){
    while(R < r){
        ++R;
        addr(R);
    }
    while(L < l){
        dell(L);
        ++L;
    }
    while(L > l){
        --L;
        addl(L);
    }
    while(R > r){
        delr(R);
        --R;
    }
    return now;
}
inline void solve(int l, int r, int kl, int kr, int x){
    if(l > r || kl > kr)
        return ;
//  cerr << l << ' ' << r << ' ' << x << '\n';
    int mid = (l + r) >> 1, kmid = -1;
    for(int i = kl; i <= min(mid - 1, kr); ++i){
        if(kmid == -1 || dp[x - 1][i] + getw(i + 1, mid) < dp[x][mid]){
            dp[x][mid] = dp[x - 1][i] + getw(i + 1, mid);
            kmid = i;
        }
    }
    assert(kmid != -1);
//  if(!kmid){
//      cerr << mid << ' ' << kl << ' ' << kr << ' ' << x << '\n';
//  }
    solve(l, mid - 1, kl, kmid, x);
    solve(mid + 1, r, kmid, kr, x);
}
int main(){
    L = 1, R = 0;
    n = read(), k = read(), d = read();
    for(int i = 1; i <= n; ++i)
        a[i] = a[i - 1] ^ read();
    for(int i = 1; i <= n; ++i)
        dp[1][i] = getw(1, i);
    for(int x = 2; x <= k; ++x)
        solve(1, n, 0, n, x);
    ll ans = 1e18;
    for(int x = 1; x <= k; ++x)
      ans = min(ans, dp[x][n]);
    write(ans);
    return 0;
}