线段树合并---题解 P3521 【[POI2011]ROT-Tree Rotations】
Vocalise
2020-10-23 20:38:00
线段树合并的一种变形应用。
没翻题解做出来了,感动。
本题序列中的结构实际就是多个单点按照次序合并。
而我们求的逆序对,联想 CDQ 分治,正是可以拆分区间解决的问题。
求逆序对最小值,发现对于每一次合并考虑左右两个区间之间的贡献即可。
显然,不解释。
所以我们可以每次贪心取左右两区间之间交换与不交换所得逆序对数的较小值。
求逆序对若朴素 $\mathcal O(n)$ 归并合并,复杂度是 $\mathcal O(n^2)$ 的,会 TLE 80pts。
可以构造每次合并一个单点的数据,这样每个点会合并 $\mathcal O(n)$ 次。
[记录](https://www.luogu.com.cn/record/40348932)
```cpp
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
typedef long long ll;
const int N = 200001;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9');
return x * f;
}
int n,p[N],q1[N],q2[N],t; ll ans;
void Solve() {
int x = read();
if(x != 0) return void(p[++t] = x);
int l = t + 1;
Solve(); int m = t;
Solve(); int r = t;
int i = l, j = m + 1, c = l; ll s1 = 0;
while(i <= m && j <= r) if(p[i] <= p[j]) q1[c++] = p[i++];
else q1[c++] = p[j++], s1 += m - i + 1;
while(i <= m) q1[c++] = p[i++];
while(j <= r) q1[c++] = p[j++], s1 += m - i + 1;
i = l, j = m + 1, c = l; ll s2 = 0;
while(i <= m && j <= r) if(p[j] <= p[i]) q2[c++] = p[j++];
else q2[c++] = p[i++], s2 += r - j + 1;
while(j <= r) q2[c++] = p[j++];
while(i <= m) q2[c++] = p[i++], s2 += r - j + 1;
if(s1 < s2) {
ans += s1;
for(int i = l;i <= r;i++) p[i] = q1[i];
} else {
ans += s2;
for(int i = l;i <= r;i++) p[i] = q2[i];
}
return;
}
int main() {
n = read();
Solve();
std::printf("%lld\n",ans);
return 0;
}
```
于是,问题变成了如何快速维护一定信息,使我们可以快速对于两个序列,求出两个序列分别对于另一个序列的逆序对数。
树上递归地合并和值有关的信息,线段树合并。
那么我们维护了这两个序列的权值线段树。
但是怎么求呢?
在合并的过程中计算。
对于线段树上每个点,维护其大小 $s$,其中左区间节点数量 $s1$,和右区间节点数量 $s2$。
这样我们合并完成一个节点时,将其左儿子中 $s2$ 的值和右儿子中 $s1$ 的值之积加入贡献即可。
这个方法的实质是在权值一维上再将序列分治。
至此,我们在 $\mathcal O(n\log n)$ 时间中解决了这个问题。
---
下面讲讲实现中的问题。
对于线段树上点的 $s1$ 和 $s2$ 值,都只是暂时的。在每一次合并中都一定会被重新计算。
该解法中似乎需要两次按不同顺序合并取小值,但是实际上合并之后的线段树结构是相同的,因此可以一次合并分别维护两个值,并直接使用原来的结点。
这样空间是 $\mathcal O(n\log n)$ 的,但是须稍大。
```cpp
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
typedef long long ll;
const int N = 200001;
const int LOGN = 19;
const int NODE = N * LOGN;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9');
return x * f;
}
int n,cnt; ll ans;
int s[NODE],s1[NODE],s2[NODE],lc[NODE],rc[NODE];
void New(int &rt,int x,int l,int r) {
s[rt = ++cnt] = 1;
if(l == r) return;
int m = (l + r) >> 1;
if(x <= m) New(lc[rt],x,l,m);
else New(rc[rt],x,m + 1,r);
return;
}
struct Pair {
ll x,y;
Pair(ll _x = 0,ll _y = 0) : x(_x), y(_y) {}
friend Pair operator +(const Pair &x,const Pair &y) {
return Pair(x.x + y.x,x.y + y.y);
}
};
Pair Merge(int &x,int y,int l,int r) {
if(!x || !y) {
if(x) s1[x] = s[x], s2[x] = 0;
else x = y, s2[x] = s[x], s1[x] = 0;
return Pair();
}
if(l == r) return s1[x] += s1[y], s2[x] += s2[y], s[x] += s[y], Pair();
int m = (l + r) >> 1;
Pair r1 = Merge(lc[x],lc[y],l,m);
Pair r2 = Merge(rc[x],rc[y],m + 1,r);
s1[x] = s1[lc[x]] + s1[rc[x]];
s2[x] = s2[lc[x]] + s2[rc[x]];
s[x] = s[lc[x]] + s[rc[x]];
return Pair(r1.x + r2.x + 1ll * s2[lc[x]] * s1[rc[x]],r1.y + r2.y + 1ll * s1[lc[x]] * s2[rc[x]]);
}
int Solve() {
int x = read();
if(x != 0) {
int rt = 0; New(rt,x,1,n);
return rt;
}
int lc = Solve(), rc = Solve();
Pair s = Merge(lc,rc,1,n);
ans += std::min(s.x,s.y);
return lc;
}
int main() {
n = read();
Solve();
std::printf("%lld\n",ans);
return 0;
}
```