[Ynoi2018] 未来日记 Solution

· · 题解

先考虑分块做区间第 k 大。

cnt_{i,j} 表示第 i 块中 j 的出现次数, sum1_{i,j} 表示前 i 块中 j 的出现次数。

然后我们发现,只用这些信息会多一个 \log

注意到值域是 1 \times 10^5 ,考虑权值分块。

再记 sum2_{i,j} 表示序列分块的前 i 块中,权值分块前 j 块中的树的出现次数。

询问时散块维护 s1,s2 两个数组,s1_i 累加权值分块前 i 块中的数,s2_i 累加 i 的出现次数。

利用 sum2 以及 s1 的信息得到第 k 大落在哪个权值块中 ,再利用 sum1s2 得到答案。

现在来考虑修改,不妨想象一下,x \to yy \to z 的结果是原先的 x,y \to z 。可以注意到这是一个树形结构,自然可以用并查集维护。

具体的,给序列块 i 中每个值 x 发一个代表,记作 rt_{i,x} ,对于一组修改 (i , x , y) ,若 rt_{i,y} 不存在,令 rt_{i,y} \to rt_{i,x}a_{rt_{i,x}} \to y ,反之令 fa_{rt_{i,x}} \to rt_{i,y}

散块并不适合套用这个方法,众所周知,边角暴力,直接重构 xy 的并查集子树。注意在所有修改后都需更新 cntsum1sum2

取序列块长 600 ,值域块长 320,以下是代码:

#include <cstdio>
#include <iostream>
#include <algorithm>

inline int read() {
    register char c = getchar();
    register int x = 0, f = 1;
    while(c < '0' || c > '9') {
        if (c == '-')
            f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return x * f;
}

const int N = 1e5 + 5;
const int M = 170;
const int S = 320;

int n, m, bl, a[N];
int top = 0, sta[N];
int L[M], R[M], bel[N], s1[S], s2[N]; 
int cnt[M][N], sumc[M][N], sums[M][S];

int fa[N], rt[M][N];
int find(int x) { return (fa[x] == x ? x : fa[x] = find(fa[x])); }

void Build(int b) {
    for (register int i = L[b]; i <= R[b]; i++) {
        if (!rt[b][a[i]])
            rt[b][a[i]] = i;
        else
            fa[i] = rt[b][a[i]];
        cnt[b][a[i]]++;
    }
}

void ReBuild(int b, int l, int r, int x, int y) {
    register int tmp = 0;
    top = 0;
    rt[b][x] = rt[b][y] = 0;
    for (register int i = L[b]; i <= R[b]; i++) {
        a[i] = a[find(i)];
        if (a[i] == x || a[i] == y)
            sta[++top] = i;
    }
    for (register int i = l; i <= r; i++) 
        if (a[i] == x)
            a[i] = y, tmp++;
    cnt[b][x] -= tmp, cnt[b][y] += tmp;
    for (register int i = 1; i <= top; i++)
        fa[sta[i]] = sta[i];
    for (register int i = 1; i <= top; i++) {
        if (!rt[b][a[sta[i]]])
            rt[b][a[sta[i]]] = sta[i];
        else
            fa[sta[i]] = rt[b][a[sta[i]]]; 
    }
    for (register int i = b; i <= bl; i++) {
        sumc[i][x] -= tmp, sumc[i][y] += tmp;
        if (bel[x] != bel[y])
            sums[i][bel[x]] -= tmp, sums[i][bel[y]] += tmp;
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (register int i = 1; i <= n; i++) {
//      scanf("%d", &a[i]);
        a[i] = read();
        fa[i] = i;
    }
    int Siz = 600, Sizv = 317, Mv = 1e5;
    bl = (n - 1) / Siz + 1; 
    for (register int i = 1; i <= Mv; i++)
        bel[i] = (i - 1) / Sizv + 1;
    for (register int i = 1; i <= bl; i++) {
        L[i] = (i - 1) * Siz +1;
        R[i] = std :: min(i * Siz, n);
        Build(i);
        for (register int j = 1; j <= Sizv; j++)
            sums[i][j] = sums[i - 1][j];
        for (register int j = 1; j <= Mv; j++)
            sumc[i][j] = sumc[i - 1][j] + cnt[i][j];
        for (register int j = L[i]; j <= R[i]; j++)
            sums[i][bel[a[j]]]++;       
    }
    while (m--) {
        register int opt, x, y, l, r, k;
//      scanf("%d", &opt);
        opt = read();
        if (opt == 1) {
//          scanf("%d%d%d%d", &l, &r, &x, &y);
            l = read(), r = read(), x = read(), y = read();
            if (x == y)
                continue;
            register int lb = (l - 1) / Siz + 1;
            register int rb = (r - 1) / Siz + 1;
            if (lb == rb)
                ReBuild(lb, l, r, x, y);
            else {
                ReBuild(lb, l, R[lb], x, y);
                ReBuild(rb, L[rb], r, x, y);
                register int tmp = 0, tmps = 0;
                for (register int i = lb + 1; i <= rb - 1; i++) {
                    if (rt[i][x]) {
                        if (!rt[i][y])
                            rt[i][y] = rt[i][x], a[rt[i][x]] = y;
                        else
                            fa[rt[i][x]] = rt[i][y];
//                      tmps += tmp = cnt[i][x];
                        rt[i][x] = 0;
                        tmp = cnt[i][x];
                        tmps += tmp;
                        cnt[i][y] += tmp, cnt[i][x] = 0; 
                    }
                    sumc[i][x] -= tmps, sumc[i][y] += tmps;
                    if (bel[x] != bel[y])
                        sums[i][bel[x]] -= tmps, sums[i][bel[y]] += tmps;
                }
                for (register int i = rb; i <= bl; i++) {
                    sumc[i][x] -= tmps, sumc[i][y] += tmps;
                    if (bel[x] != bel[y])
                        sums[i][bel[x]] -= tmps, sums[i][bel[y]] += tmps;
                }
            }
        }
        else {
//          scanf("%d%d%d", &l , &r, &k);
            l = read(), r = read(), k = read();
            register int lb = (l - 1) / Siz + 1;
            register int rb = (r - 1) / Siz + 1;
            if (lb == rb) {
                for (register int i = l; i <= r; i++) {
                    a[i] = a[find(i)];
                    s1[bel[a[i]]]++;
                    s2[a[i]]++;
                }
                register int vl, vr, sum = 0;
                for (register int i = 1; i <= Sizv; i++) {
                    sum += s1[i];
                    if (sum >= k) {
                        sum -= s1[i];
                        vl = (i - 1) * Sizv + 1, vr = i * Sizv;
                        break;
                    }
                }
                for (register int i = vl; i <= vr; i++) {
                    sum += s2[i];
                    if (sum >= k) {
                        printf("%d\n", i);
                        break;
                    }
                }
                for (register int i = l; i <= r; i++)
                    s2[a[i]]--, s1[bel[a[i]]]--;
            }
            else {
                for (register int i = l; i <= R[lb]; i++) {
                    a[i] = a[find(i)];
                    s1[bel[a[i]]]++;
                    s2[a[i]]++;
                }
                for(register int i = L[rb]; i <= r; i++) {
                    a[i] = a[find(i)];
                    s1[bel[a[i]]]++;
                    s2[a[i]]++;
                }
                register int vl, vr, sum = 0;
                for (register int i = 1; i <= Sizv; i++) {
                    sum += s1[i] + sums[rb - 1][i] - sums[lb][i];
                    if (sum >= k) {
                        sum -= s1[i] + sums[rb - 1][i] - sums[lb][i];
                        vl = (i - 1) * Sizv + 1, vr = i * Sizv;
                        break;
                    }
                }
                for (register int i = vl; i <= vr; i++) {
                    sum += s2[i] + sumc[rb - 1][i] - sumc[lb][i];
                    if (sum >= k) {
                        printf("%d\n", i);
                        break;
                    } 
                }
                for (register int i = l; i <= R[lb]; i++) {
                    s1[bel[a[i]]]--;
                    s2[a[i]]--;
                }
                for (register int i = L[rb]; i <= r; i++) {
                    s1[bel[a[i]]]--;
                    s2[a[i]]--;
                }
            }
        }
    }
    return 0;
}