P10856 Xor-Forces 题解

· · 题解

首先我们发现操作 1 具有结合律,于是我们要解决的就是每次给定一个 x,令 a'_i = a_{i\oplus x},查询 a' 下标在 [l, r] 内的子数组的颜色段个数。

有个显然的转化:区间颜色段个数等于区间长度减去相邻同色对数,只需考虑如何计算 l\lt i\leq r, a_{i\oplus x} = a_{(i - 1)\oplus x}i 个数即可。

考虑线段树,当询问跨过区间中点 m 时,可以直接判断 a_{m \oplus x}a_{(m - 1)\oplus x} 这一对是否相等,并递归到两个儿子去计算。

现在我们要解决的就是线段树上某个节点对应的答案,因为每次异或的 x 不同,一个节点对应的区间中的元素顺序会被打乱,每个节点对应的答案也不是唯一的,此时我们显然无法对于每个节点储存单一的答案。

这里给出一个结论:如果一个节点对应的区间长度为 l,那么对这个区间有用的 x 只有 l 种!

这是为什么呢?考虑询问的是 \texttt{0 1 2 3} 这一段区间,x = 5,此时区间长度 l4,区间内下标异或上 x 后结果是 \texttt{5 4 7 6},正好对应区间 \texttt{4 5 6 7}x = 1 时的答案!

我们发现,对于 x\geq l 的部分,我们其实相当于将该区节点对应区间的答案转化为了线段树上同一层另一个节点在 x' = x \bmod l 时的答案,而这个节点通过位运算是不难找到的。

于是对于每个节点,我们只需要维护 0\sim l - 1l 种不同的 l 对应的答案,查询时只需计算出另一个节点的位置,并查询其在 x' = x \bmod l 的答案即可。

由于线段树一共 \log n 层,每层节点的总长度之和为 n,于是时间复杂度和空间复杂度均为 \mathcal{O}(n\log n)

代码中线段树的实现是左闭右开。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int N = 3e5 + 5, LOG = 20;

int o, k, m, n;
int a[N], sum[N], M[LOG][N];
std::vector<int> f[N << 2];

#define ls(u) (u << 1)
#define rs(u) (u << 1 | 1)

void build(int u, int l, int r) {
    int len = r - l, bit = std::__lg(len); f[u].resize(len); M[bit][l] = u;
    if (len == 1) return f[u][0] = 0, void();
    int mid = (l + r) >> 1;
    build(ls(u), l, mid); 
    build(rs(u), mid, r);
    for (int x = 0; x < len; x++) {
        if (x < len / 2) f[u][x] = f[ls(u)][x] + f[rs(u)][x] + (a[mid ^ x] == a[(mid - 1) ^ x]);
        else f[u][x] = f[ls(u)][x - (len / 2)] + f[rs(u)][x - (len / 2)] + (a[mid ^ x] == a[(mid - 1) ^ x]);
    }
}

int query(int u, int l, int r, int ql, int qr, int y) {
    int bit = std::__lg(r - l); 
    if (ql <= l && r - 1 <= qr) return f[M[bit][l ^ (y >> bit << bit)]][y % (1 << bit)];
    int mid = (l + r) >> 1, L = 0, R = 0; bool fL = false, fR = false;
    if (ql < mid)  L = query(ls(u), l, mid, ql, qr, y), fL = true;
    if (qr >= mid) R = query(rs(u), mid, r, ql, qr, y), fR = true;
    return fL ? (fR ? L + R + (a[mid ^ y] == a[(mid - 1) ^ y]) : L) : R;
}

int main() {

    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> o >> k >> m; n = (1 << k);
    for (int i = 0; i < n; i++) std::cin >> a[i];

    build(1, 0, n);

    int lst = 0, y = 0;
    for (int i = 1; i <= m; i++) {
        int op; std::cin >> op;
        if (op == 1) {
            int x; std::cin >> x; x ^= (o * lst);
            y ^= x;
        }
        if (op == 2) {
            int l, r; std::cin >> l >> r; l ^= (o * lst), r ^= (o * lst);
            if (l > r) std::swap(l, r);
            std::cout << (lst = r - l + 1 - query(1, 0, n, l, r, y)) << "\n";
        }
    }

    return 0;
}