题解 P5283 【[十二省联考2019]异或粽子】

Nemlit

2019-04-29 09:27:23

Solution

## [原文地址](https://www.cnblogs.com/bcoier/p/10788524.html) ### 前置芝士:[可持久化Trie](https://www.luogu.org/problemnew/show/P4735) & [堆](https://www.luogu.org/problemnew/show/P2048) 类似于超级钢琴,我们用堆维护一个四元组$(st, l, r, pos)$表示以$st$为起点,终点在$[l, r]$内,里面的最大值的位置为$pos$ 我们维护一个小根堆(堆顶最大),权值为st-pos的异或和,每一次找出最大的并删掉 所谓删,就是把一个区间从pos处分裂 即:$(st, l, r)->(st, l, pos - 1) (st, pos + 1, r)$ 这样重新维护pos值即可 维护pos值时,我们需要维护区间内与x的异或值最大,不难想到可持久化$Trie$,~~于是只需要把超级钢琴中的$RMQ$变成可持久化$Trie$即可~~ ``` #include<bits/stdc++.h> using namespace std; #define il inline #define re register #define debug printf("Now is Line : %d\n",__LINE__) #define file(a) freopen(#a".in","r",stdin);freopen(#a".out","w",stdout) #define ll long long il ll read() { re ll x = 0, f = 1; re char c = getchar(); while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();} while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar(); return x * f; } #define rep(i, s, t) for(re int i = s; i <= t; ++ i) #define maxn 500005 int n, m; ll sum[maxn], ans; namespace Trie { struct PAX { int id, s, ch[2]; }e[maxn * 50]; int cnt, rt[maxn]; il void insert(int&k, int kk, int bit, int id, ll val) { k = ++ cnt; e[k] = e[kk], ++ e[k].s; if(bit == -1) return(void)(e[k].id = id); int c = (val >> bit) & 1; insert(e[k].ch[c], e[kk].ch[c], bit - 1, id, val); } il int query(int l, int r, int bit, ll val) { if(bit == -1) return e[r].id; int c = (val >> bit) & 1; if(e[e[r].ch[c ^ 1]].s > e[e[l].ch[c ^ 1]].s) return query(e[l].ch[c ^ 1], e[r].ch[c ^ 1], bit - 1, val); return query(e[l].ch[c], e[r].ch[c], bit - 1, val); } } using namespace Trie; struct node { int st, l, r, pos; il bool operator < (const node a) const { return (sum[pos] ^ sum[st - 1]) < (sum[a.pos] ^ sum[a.st - 1]); } node(int St, int L, int R) { st = St, l = L, r = R, pos = query(rt[l - 1], rt[r], 32, sum[st - 1]); } }; priority_queue<node>q; signed main() { file(a); n = read(), m = read(); rep(i, 1, n) sum[i] = sum[i - 1] ^ read(); rep(i, 1, n) insert(rt[i], rt[i - 1], 32, i, sum[i]); rep(i, 1, n) q.push(node(i, i, n)); while(m --) { node t = q.top(); q.pop(); ans += sum[t.pos] ^ sum[t.st - 1]; if(t.l < t.pos) q.push(node(t.st, t.l, t.pos - 1)); if(t.pos < t.r) q.push(node(t.st, t.pos + 1, t.r)); } printf("%lld", ans); return 0; } ```