【题解】P4119 [Ynoi2018] 未来日记

· · 题解

更好的阅读体验

题意

P4119 [Ynoi2018] 未来日记

给定一个长度为 n 的序列和 m 个操作,每次操作可以:

  1. [l, r] 内的所有值 x 改为值 y

  2. 查询 [l, r] 内第 k 小的值

相同的值算多次。

1 \leq n, m, a_i \leq 10^5

思路

最初分块。

分块 + 值域分块 + 并查集。

显然思路是二分套树状数组求静态 k 小值,但是复杂度不对。

考虑用值域分块代替二分。

将值域 [1, 10^5] 分成 \sqrt{10^5} 块。

sum1_{i, j} 表示前 i 个块内,在第 j 个值域块内的值的个数,sum2_{i, j} 表示前 i 个块内值 j 的个数。这两个数组可以 \mathcal{O}(n \sqrt{n}) 预处理。

询问散块直接暴力 nth_element

询问 [l, r] 时另外处理 cnt1_{i} 表示散块内在第 i 个值域块内的值的个数,cnt2_{i} 表示散块内值 i 的个数。当块长取 \sqrt{n} 时,处理这两个数组复杂度是 \mathcal{O}(\sqrt{n})

假设询问第 k 小。

首先确定第 k 小值所处的值域块。具体实现可以令 sum 初始为 0,枚举值域块 i 时令 sum 不断累加 [l, r] 内第 i 个值域块内的值的个数(用上面处理的数组 \mathcal{O}(1) 求),当 sum \geq k 时说明第 k 小值在第 i 个值域块内。枚举值域块复杂度是 \mathcal{O}(\sqrt{n})

然后枚举值域块内的值,用类似上面的方法求出第 k 小值。复杂度也是 \mathcal{O}(\sqrt{n})

类似于 P4117 [Ynoi2018] 五彩斑斓的世界,维护 rt_{i, j}, val_{i, j}, pos_irt_{i, j} = k 表示 k 是第 i 个块内值 j 对应并查集的根,val_{i, j} = k 表示第 i 个块内根 j 对应的值为 k。令 bel_i 表示下标 i 所属的块的编号,则 pos[i] = rt[ bel[i] ][ a[i] ]

小优化,并查集的根编号不必用数组下标。

直接修改很难,不妨先用 $sum1$ 和 $sum2$ 差分出每块内的答案,对每块答案单独维护,最后再前缀和合并。显然不影响复杂度。 散块直接暴力修改重构即可,复杂度是 $\mathcal{O}(\sqrt{n})

修改整块 i 可以分成三种情况:

时间复杂度 \mathcal{O}((n + m) \sqrt{n})

代码

#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 5;
const int maxk = 320;

int n, m;
int block, tot;
int st[maxk], ed[maxk];
int a[maxn], bel[maxn];
int cnt1[maxk], sum1[maxk][maxk];
int cnt2[maxn], sum2[maxk][maxn];
int rt[maxk][maxn], val[maxk][maxn], pos[maxn];

void reduce(int idx) {
    for (int i = st[idx]; i <= ed[idx]; i++) {
        a[i] = val[idx][pos[i]];
    }
}

void build(int idx) {
    int cur = 0;
    for (int i = 1; i <= block; i++) {
        rt[idx][val[idx][i]] = 0;
    }
    for (int i = st[idx]; i <= ed[idx]; i++) {
        if (!rt[idx][a[i]]) {
            cur++;
            rt[idx][a[i]] = cur;
            val[idx][cur] = a[i];
        }
        pos[i] = rt[idx][a[i]];
    }
}

void modify(int l, int r, int x, int y) {
    for (int i = l; i <= r; i++) {
        if (a[i] == x) {
            sum1[bel[i]][bel[x]]--;
            sum1[bel[i]][bel[y]]++;
            sum2[bel[i]][x]--;
            sum2[bel[i]][y]++;
            a[i] = y;
        }
    }
}

void merge(int idx, int x, int y) {
    rt[idx][y] = rt[idx][x];
    val[idx][rt[idx][x]] = y;
    rt[idx][x] = 0;
}

void update(int l, int r, int x, int y) {
    if ((x == y) || (sum2[bel[r]][x] - sum2[bel[l] - 1][x] == 0)) {
        return;
    }
    for (int i = bel[n]; i >= bel[l]; i--) {
        sum1[i][bel[x]] -= sum1[i - 1][bel[x]];
        sum1[i][bel[y]] -= sum1[i - 1][bel[y]];
        sum2[i][x] -= sum2[i - 1][x];
        sum2[i][y] -= sum2[i - 1][y];
    }
    if (bel[l] == bel[r]) {
        reduce(bel[l]);
        modify(l, r, x, y);
        build(bel[l]);
        for (int i = bel[l]; i <= bel[n]; i++) {
            sum1[i][bel[x]] += sum1[i - 1][bel[x]];
            sum1[i][bel[y]] += sum1[i - 1][bel[y]];
            sum2[i][x] += sum2[i - 1][x];
            sum2[i][y] += sum2[i - 1][y];
        }
    } else {
        reduce(bel[l]);
        modify(l, ed[bel[l]], x, y);
        build(bel[l]);
        reduce(bel[r]);
        modify(st[bel[r]], r, x, y);
        build(bel[r]);
        for (int i = bel[l] + 1; i < bel[r]; i++) {
            if (!sum2[i][x]) {
                continue;
            } else if (sum2[i][y]) {
                reduce(i);
                modify(st[i], ed[i], x, y);
                build(i);
            } else {
                sum1[i][bel[y]] += sum2[i][x];
                sum1[i][bel[x]] -= sum2[i][x];
                sum2[i][y] += sum2[i][x];
                sum2[i][x] = 0;
                merge(i, x, y);
            }
        }
        for (int i = bel[l]; i <= bel[n]; i++) {
            sum1[i][bel[x]] += sum1[i - 1][bel[x]];
            sum1[i][bel[y]] += sum1[i - 1][bel[y]];
            sum2[i][x] += sum2[i - 1][x];
            sum2[i][y] += sum2[i - 1][y];
        }
    }
}

int query(int l, int r, int k) {
    int ans, sum = 0;
    if (bel[l] == bel[r]) {
        reduce(bel[l]);
        for (int i = l; i <= r; i++) {
            cnt2[i] = a[i];
        }
        nth_element(cnt2 + l, cnt2 + l + k - 1, cnt2 + r + 1);
        ans = cnt2[l + k - 1];
        for (int i = l; i <= r; i++) {
            cnt2[i] = 0;
        }
        return ans;
    }
    reduce(bel[l]);
    for (int i = l; i <= ed[bel[l]]; i++) {
        cnt1[bel[a[i]]]++;
        cnt2[a[i]]++;
    }
    reduce(bel[r]);
    for (int i = st[bel[r]]; i <= r; i++) {
        cnt1[bel[a[i]]]++;
        cnt2[a[i]]++;
    }
    for (int i = 1; i <= bel[100000]; i++) {
        if ((sum + cnt1[i] + sum1[bel[r] - 1][i] - sum1[bel[l]][i]) >= k) {
            for (int j = (i - 1) * block + 1; j <= i * block; j++) {
                if ((sum + cnt2[j] + sum2[bel[r] - 1][j] - sum2[bel[l]][j]) >= k) {
                    for (int k = l; k <= ed[bel[l]]; k++) {
                        cnt1[bel[a[k]]]--;
                        cnt2[a[k]]--;
                    }
                    for (int k = st[bel[r]]; k <= r; k++) {
                        cnt1[bel[a[k]]]--;
                        cnt2[a[k]]--;
                    }
                    return j;
                } else {
                    sum += (cnt2[j] + sum2[bel[r] - 1][j] - sum2[bel[l]][j]);
                }
            }
        } else {
            sum += (cnt1[i] + sum1[bel[r] - 1][i] - sum1[bel[l]][i]);
        }
    }
}

int main() {
    int opt, l, r, x, y, k;
    scanf("%d%d", &n, &m);
    block = sqrt(n);
    tot = ceil(n * 1.0 / block);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    for (int i = 1; i < maxn; i++) {
        bel[i] = (i - 1) / block + 1;
    }
    for (int i = 1; i <= tot; i++) {
        st[i] = (i - 1) * block + 1;
        ed[i] = i * block;
    }
    ed[tot] = n;
    for (int i = 1; i <= tot; i++) {
        build(i);
    }
    for (int i = 1; i <= tot; i++) {
        for (int j = 1; j < maxk; j++) {
            sum1[i][j] = sum1[i - 1][j];
        }
        for (int j = 1; j < maxn; j++) {
            sum2[i][j] = sum2[i - 1][j];
        }
        for (int j = st[i]; j <= ed[i]; j++) {
            sum1[i][bel[a[j]]]++;
            sum2[i][a[j]]++;
        }
    }
    for (int i = 1; i <= m; i++) {
        scanf("%d", &opt);
        if (opt == 1) {
            scanf("%d%d%d%d", &l, &r, &x, &y);
            update(l, r, x, y);
        } else {
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", query(l, r, k));
        }
    }
    return 0;
}