珂朵莉树模板CF896C Willem, Chtholly and Seniorious题解

泥土笨笨

2021-06-23 15:45:25

Solution

## 0x00 前言 事情是这样的,有一天,教练群里面讨论这道题,我说这道题排名前 $10$ 的题解里面,有 $7$ 个都是错的,我打算写一个对的,避免同学们被误导。这时候,群里的 lxl 默默的来了一句: ![](https://cdn.luogu.com.cn/upload/image_hosting/9vd7ykl9.png) 好的,既然是 lxl 大神出的题,我就更有必要写一篇正确的题解,来表达我对他的膜拜了。 ## 0x01 为什么很多题解不对,照着写会RE 因为如果要用 `split` 操作,截取一段区间的时候,必须要先 `split(r+1)` ,再 `split(l)` ,否则会有 RE ,具体原因我后面会细说。请大家参考其他题解或者资料的时候,也注意这一点。 ## 0x02 什么是珂朵莉树 珂朵莉树,还有个名字叫老司机树(Old Driver Tree, ODT),是一个暴力数据结构。甚至都不一定可以将其称之为数据结构了,我们不妨认为它是一类题目的暴力做法,对于随机数据比较有效。 ## 0x03 珂朵莉树可以解决什么问题 有一类问题,对一个序列,进行一个**区间推平操作**。就是把一个范围内,比如 $[l,r]$ 范围内的数字变成同一个。可能除了推平以外,还夹杂其他操作。如果数据是随机的,就可以用珂朵莉树啦。比如这道题中的操作 $2$ ,将 $[l,r]$ 区间内的所有数都改成 $x$ ,这就是一个区间推平操作。 ## 0x04 珂朵莉树的基本原理 暴力的地方来喽,刚才不是提到有推平操作么?那么推平操作结束以后,被推平的区间内的每个数字都是相同的。其实经过若干次推平以后,我们可以看成,这个序列上的数字是一段一段的,每一小段里面数字相同,整个区间由若干个小段组成。类似这样: ![](https://cdn.luogu.com.cn/upload/image_hosting/k4fik6l0.png) 这个时候,我们定义一个结构体,用一个结构体变量,来表示每个数字相同的段。 ```cpp struct Node { ll l, r;//l和r表示这一段的起点和终点 mutable ll v;//v表示这一段上所有元素相同的值是多少 Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {} bool operator<(const Node &a) const { return l < a.l;//规定按照每段的左端点排序 } }; ``` 相关变量的含义,注释里面已经解释了。这里有个细节是, `v` 变量前面加个了 `mutable` 关键字。 `mutable` 的意思是,即使它是个常量,也允许修改v的值,具体我在下面区间修改的地方解释。 当每个数字相同的区间都用一个结构体变量表示以后,我们把这四段插入到一个 `set` 里面, `set` 会按照每段的左端点顺序进行排序,这样这个序列就维护好了,类似下图: ![](https://cdn.luogu.com.cn/upload/image_hosting/nd7klccn.png) 当然,对于本题,一开始的时候,每段都只有一个数,所以我们的set里面维护n个长度为1的段。 ## 0x05 核心操作 `split` 天下大势,分久必合,合久必分,珂朵莉树也一样。随着推平操作的进行,有一些位置被合并到了一个 `Node` 里面,但是也有可能一个 `Node` 要被拆开,其中的一部分要被改变值。 `split` 操作就是干这个用的,参数是一个位置 `pos` ,以 `pos` 去做切割,找到一个包含 `pos` 的区间,把它分成 `[l,pos-1]` , `[pos,r]` 两半。当然,如果 `pos` 本身就是一个区间的开头,就不用切割了,直接返回这个区间。 先看代码 ```cpp set<Node>::iterator split(int pos) { set<Node>::iterator it = s.lower_bound(Node(pos)); if (it != s.end() && it->l == pos) { return it; } it--; if (it->r < pos) return s.end(); ll l = it->l; ll r = it->r; ll v = it->v; s.erase(it); s.insert(Node(l, pos - 1, v)); //insert函数返回pair,其中的first是新插入结点的迭代器 return s.insert(Node(pos, r, v)).first; } ``` 首先,第一行里面的 `s` 是一个全局变量,是那个装 `node` 的 `set` 。大家知道 `set` 里面有个函数叫 `lower_bound` ,它的作用是返回跟参数相等的,或者比参数更大的第一个 `set` 中元素的位置,返回的是一个迭代器。 那么我们按照 `pos` 创建一个 `node` ,然后去查询,就找到了 `it` 这个位置。这个时候有三种情况,一种是我们正好找到了一个区间,它是以 `pos` 开头的,所以就对应了代码中的第一个 `if` 判断,这时候直接返回这个区间的迭代器 `it`。 还有两种情况是,我们找到的这个区间是正好比包含 `pos` 的区间大一点点的,或者`pos`太大了,超过了最后一个区间的右端点。不管怎样先把`it`往前挪一个格,然后这时候看看`it`的右端点,如果比`pos`小,说明是`pos`太大了,就直接返回`s`的`end()`迭代器。否则的话,现在`it`就是应该包含了`pos`的那个区间。这时候,我们要把它一分为二,把原来的那个区间删掉,然后插入两个新区间,分别是`[l,pos-1]`和`[pos,r]`。 这里还有个小技巧,`insert`这个函数是有返回值的,它返回的是一个`pair`,`pair`的第一个字段正好是新插入的那个`node`的位置的迭代器,所以`return`那个东西就行了。 ## 0x06 推平操作`assign` 刚刚的`split`作用是分,现在还需要一个相反的操作,就是合并。当出现对区间的推平操作的时候,我们可以把整个`set`中所有要被合并掉的`node`都删掉,然后插入一个新区间表示推平以后的结果。 ![](https://cdn.luogu.com.cn/upload/image_hosting/y63fx4wt.png) 如图,按照上面的例子,`set`里面有$4$个`node`,此时我们想进行一次推平操作,把`[2,8]`区间内的元素都改成$666$.首先我们发现,`[8,10]`是一个区间,那么需要先`split(9)`,把`[8,8]`和`[9,10]`拆成两个区间。同理,原来的`[1,2]`区间,也需要拆成`[1,1]`和`[2,2]`。 接下来,我们把要被合并的从$2$到$8$的所有`node`都从`set`里面删掉,再重新插入一个`[2,8]`区间就行了。删除某个范围内的元素可以用`set`的`erase`函数,这个函数接受两个迭代器`s`和`t`,把`[s,t)`范围内的东西都删掉。 代码如下: ```cpp void assign(ll l, ll r, ll x) { set<Node>::iterator itr = split(r + 1), itl = split(l); s.erase(itl, itr); s.insert(Node(l, r, x)); } ``` ## 0x07 推平操作里面RE的坑 现在说一下为啥大部分题解都不对,注意刚刚`assign`函数里面调用的那两次`split`,我是先`split(r+1)`,计算出`itr`,然后再`split(l)`,计算`itl`的。**这个顺序不能反**。 为啥不能反?举个具体例子,比如现在有个区间是`[1,4]`,我们想从里面截取`[1,1]`出来,那么我们需要调用两次`split`,分别是`split(2)`和`split(1)`。 假设先调用`split(1)`,如图中间的结果: ![](https://cdn.luogu.com.cn/upload/image_hosting/igxn8u1h.png) 现在的`itl`指向的还是原来的那个`node`,没有什么变化。但是当我们后续调用`itr`的时候,出事儿了。因为这时候,我们把原来的`[1,4]`区间删掉了,拆成了两份,`itr`指向的是后面那个,但是原来`itl`指向的那个已经被`erase`掉了。这时候用`itl`和`itr`调用`s.erase`的时候就会出问题,直接RE。 有同学说我顺序反了没RE啊,也AC。恭喜你,你人品好。这东西理论上会RE,但是实际上概率不大,我对拍了一下,大概1%的概率RE吧。但是人品不好的同学,可能上来就RE一片,而且是随机RE,同一个数据,一会儿能过,一会儿过不了。所以,还是别给自己找麻烦了。 ## 0x08 修改操作`add` 区间内每个数都加上`x`,这个实现方式和前面的推平差不多,我们还是找到这个区间的首尾,然后循环一遍区间内的每个`node`,把每个`node`的`v`都加上`x`就行 ```cpp void add(ll l, ll r, ll x) { set<Node>::iterator itr = split(r + 1), itl = split(l); for (set<Node>::iterator it = itl; it != itr; ++it) { it->v += x; } } ``` 这里是用一个迭代器`it`遍历每个位置,把每个位置的`v`都加`x`。大家发现前面提到的`mutable`的作用了么?因为这里`it`是个常量迭代器,它不能修改它指向的那个`node`,而我们这里要改`node`里面的`v`,所以就把`v`声明为`mutable`,就可以改了。否则会得到类似这样的编译错误: `error: cannot assign to return value because function 'operator->' returns a const value` ## 0x09 其他操作 其他操作都是类似的暴力操作。比如要找区间第$k$小,那么就把区间内所有的`node`拿出来,按照`v`从小到大排序,把每个`node`里面的区间长度相加,看看啥时候加够为止。这里就不细致展开,有问题可以去看代码。 ## 0x0A 复杂度 因为本题数据是随机的,所以每次`assign`操作的区间长度大概在$vmax/3$,所以经过很多次`assign`以后,区间个数不会太多,大概在`log`这个数量级上。这样每次暴力操作的复杂度差不多也是这个数量级。详细的分析,可以参考这篇博客: https://www.luogu.com.cn/blog/blaze/solution-cf896c ## 0x0B 代码时间 ```cpp #include <iostream> #include <set> #include <algorithm> #include <vector> #include <cstdio> using namespace std; typedef long long ll; const ll MOD = 1000000007; const ll MAXN = 100005; struct Node { ll l, r;//l和r表示这一段的起点和终点 mutable ll v;//v表示这一段上所有元素相同的值是多少 Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {} bool operator<(const Node &a) const { return l < a.l;//规定按照每段的左端点排序 } }; ll n, m, seed, vmax, a[MAXN]; set<Node> s; //以pos去做切割,找到一个包含pos的区间,把它分成[l,pos-1],[pos,r]两半 set<Node>::iterator split(int pos) { set<Node>::iterator it = s.lower_bound(Node(pos)); if (it != s.end() && it->l == pos) { return it; } it--; if (it->r < pos) return s.end(); ll l = it->l; ll r = it->r; ll v = it->v; s.erase(it); s.insert(Node(l, pos - 1, v)); //insert函数返回pair,其中的first是新插入结点的迭代器 return s.insert(Node(pos, r, v)).first; } /* * 这里注意必须先计算itr。 * 比如现在区间是[1,4],如果要add的是[1,2],如果先split(1) * 那么返回的itl是[1,4],但是下一步计算itr的时候会把这个区间删掉拆成[1,2]和[3,4] * 那么itl这个指针就被释放了 * */ void add(ll l, ll r, ll x) { set<Node>::iterator itr = split(r + 1), itl = split(l); for (set<Node>::iterator it = itl; it != itr; ++it) { it->v += x; } } void assign(ll l, ll r, ll x) { set<Node>::iterator itr = split(r + 1), itl = split(l); s.erase(itl, itr); s.insert(Node(l, r, x)); } struct Rank { ll num, cnt; bool operator<(const Rank &a) const { return num < a.num; } Rank(ll num, ll cnt) : num(num), cnt(cnt) {} }; ll rnk(ll l, ll r, ll x) { set<Node>::iterator itr = split(r + 1), itl = split(l); vector<Rank> v; for (set<Node>::iterator i = itl; i != itr; ++i) { v.push_back(Rank(i->v, i->r - i->l + 1)); } sort(v.begin(), v.end()); int i; for (i = 0; i < v.size(); ++i) { if (v[i].cnt < x) { x -= v[i].cnt; } else { break; } } return v[i].num; } ll ksm(ll x, ll y, ll p) { ll r = 1; ll base = x % p; while (y) { if (y & 1) { r = r * base % p; } base = base * base % p; y >>= 1; } return r; } ll calP(ll l, ll r, ll x, ll y) { set<Node>::iterator itr = split(r + 1), itl = split(l); ll ans = 0; for (set<Node>::iterator i = itl; i != itr; ++i) { ans = (ans + ksm(i->v, x, y) * (i->r - i->l + 1) % y) % y; } return ans; } ll rnd() { ll ret = seed; seed = (seed * 7 + 13) % MOD; return ret; } int main() { cin >> n >> m >> seed >> vmax; for (int i = 1; i <= n; ++i) { a[i] = (rnd() % vmax) + 1; s.insert(Node(i, i, a[i])); } for (int i = 1; i <= m; ++i) { ll op, l, r, x, y; op = (rnd() % 4) + 1; l = (rnd() % n) + 1; r = (rnd() % n) + 1; if (l > r) swap(l, r); if (op == 3) { x = (rnd() % (r - l + 1)) + 1; } else { x = (rnd() % vmax) + 1; } if (op == 4) { y = (rnd() % vmax) + 1; } if (op == 1) { add(l, r, x); } else if (op == 2) { assign(l, r, x); } else if (op == 3) { cout << rnk(l, r, x) << endl; } else { cout << calP(l, r, x, y) << endl; } } return 0; } ```