GMOI R2 T2 猫耳小 官方题解

· · 题解

首先特判 k=0 的情况,此时的答案为非 0 数的个数,改法是将它们全改成 0

再特判 k 较大的情况,此时的答案为 0

否则,对于 k 大小适中的情况,我们从前往后遍历数组,同时维护当前区间的 \operatorname{mex} 值。根据 \operatorname{mex} 的定义,显然对于左端点相同的区间,右端点不断右移的过程中 \operatorname{mex} 单调不降。因此,如果 \operatorname{mex} 值在某一时刻超过了 k(一定是遍历到一个 k 导致的),则继续右移是没有意义的,此时清空维护的区间。如果 \operatorname{mex} 值在某一时刻等于 k,则将此时区间右端点改为 k

时间复杂度 \mathcal O(n)

引理一:存在一种最优方案,使得所有修改的位置都被改为 $k$。 证明:考虑任意一种最优方案,如果一个位置被修改,使得不存在 $\operatorname{mex}$ 为 $k$ 的区间,则将它改为 $k$ 也不存在。原因是包含 $k$ 的区间的 $\operatorname{mex}$ 一定不等于 $k$,而从这个位置将数列分开,根据方案合法性,剩余的部分的 $\operatorname{mex}$ 也一定不是 $k$。$\square

因此,我们只需要考虑将所有修改的位置都改为 k 的情况。

同样的理由,我们不需要考虑原数列中包含 k 的连续子段,因此可以将原数列按照 a_i=k 的位置分为若干段,每一段都不包含 k。另外,显而易见地,大于 ka_i 没有意义。综上,只要我们解决了子任务三(保证 a_i < k),原问题得解。

在子任务三中,一段区间的 \operatorname{mex} 等于 k,当且仅当这段区间包含 0\sim k-1 的所有数。我们找到最靠左的这样的区间,则这段区间内至少要有一个数被改为 k。在改为 k 后,从被修改的位置往右是一个子问题。

引理二:存在一种最优方案,使得所有被修改的位置都是上述区间的右端点。

证明:考虑任意一种最优方案,假设一次修改没有改右端点,则改右端点一定不劣。原因是这样可以尽量缩减问题的规模,由于 \operatorname{mex} 有单调性,这么操作并不会使得下一个修改位置提前,只可能使得下一个修改位置延后。\square

子任务三的贪心算法正确性得证。根据上述推理,原问题的贪心算法正确性得证。

代码:

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(int x=(y);x<=(z);x++)
#define per(x,y,z) for(int x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
#define likely(exp) __builtin_expect(!!(exp), 1)
#define unlikely(exp) __builtin_expect(!!(exp), 0)
using namespace std;
typedef long long ll;

mt19937 rnd(std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
int randint(int L, int R) {
    uniform_int_distribution<int> dist(L, R);
    return dist(rnd);
}

template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}

const int N = 1e6+5;

int n, k, a[N], b[N], cnt[N], mex, ans;

int read() {
    int x = 0, k = 1;
    char c = getchar();
    for(; !isdigit(c); c = getchar()) if(c == '-') k *= -1;
    for(; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return x * k;
}

void write(int x, char end = 0) {
    if(x < 0) putchar('-'), x = -x;
    if(x < 10) putchar(x ^ 48);
    else {
        write(x / 10);
        putchar((x % 10) ^ 48);
    }
    if(end) putchar(end);
}

int main() {
    n = read(); k = read();
    rep(i, 1, n) a[i] = read();
    if(!k) {
        rep(i, 1, n) if(a[i]) ++ans;
        rep(i, 1, n) b[i] = 0;
    }
    else if(k > n + 1) {
        rep(i, 1, n) b[i] = a[i];
    }
    else {
        rep(i, 1, n) b[i] = a[i];
        for(int l = 0, r = 1; r <= n; r++) {
            if(a[r] > k) continue;
            if(a[r] == k) {
                while(++l < r) if(a[l] < k) --cnt[a[l]];
                mex = 0;
            }
            else {
                ++cnt[a[r]];
                while(cnt[mex]) ++mex;
                if(mex == k) {
                    ++ans;
                    while(++l < r) if(a[l] < k) --cnt[a[l]];
                    --cnt[a[r]];
                    b[r] = k;
                    mex = 0;
                }
            }
        }
    }
    write(ans, '\n');
    rep(i, 1, n) write(b[i], " \n"[i==n]);
    return 0;
}