题解 P3380 【【模板】二逼平衡树(树套树)】

2017-12-29 22:19:59


显然这道题是没有人写pbds的,于是我们用线段树套pbds的红黑树过一发。

然而贼慢,3000多ms,不过不用写平衡树还是比较高兴的。

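由于 pbds 的 tree 不允许重复的键,这里用结构体 Node(v, id) 把数值 v 和一个全局递增的编号 id(即代码中的 times)绑在一起,使相同的数值也能同时存在于树中;查询时则用 Node(val, 0) 和 Node(val, times + 1) 作为哨兵键,分别定位到所有值等于 val 的元素之前和之后。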

于是下面会给出代码细节部分。

白送一个红黑树是不是有些兴奋呢?

#include <cstdio>
#include <cctype>
#include <climits>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Capacity of num[], the per-position value array.
#define MAXN 50005
// Capacity of lc[], rc[] and stree[], i.e. the segment-tree node storage.
#define MAXM 876543
// Sentinel magnitude: Get_Nxt returns INF and Get_Pre returns -INF when no candidate exists.
#define INF INT_MAX
#define Finline __inline__ __attribute__((always_inline))// force inlining
// Buffered replacement for getchar() on stdin.
// Input is fetched with fread in fixed-size chunks and handed out one byte at
// a time; the buffer is refilled whenever it runs dry, so input of any length
// is consumed completely instead of being cut off after the first chunk.
// Returns the next input byte, or EOF (converted to char) once stdin is exhausted.
inline char get_char(){
    static char buf[1 << 16];
    static char *p1 = buf, *p2 = buf;
    if (p1 == p2) {
        // Buffer consumed (or not filled yet): pull in the next chunk.
        p2 = buf + fread(buf, 1, sizeof(buf), stdin);
        p1 = buf;
        if (p1 == p2) return EOF;   // nothing more to read
    }
    return *p1++;
}
// Reads the next decimal integer from the buffered input.
// Everything up to the first digit or '-' is skipped, an optional leading '-'
// negates the result, and parsing stops at the first non-digit character.
// Returns 0 when the input ends before any number is found.
inline int read(){
    const char eof = static_cast<char>(EOF);   // get_char() reports end of input as EOF narrowed to char
    char c = get_char();
    // The <cctype> classifiers require a value representable as unsigned char,
    // hence the casts before every isdigit call.
    while (c != eof && c != '-' && !isdigit(static_cast<unsigned char>(c))) c = get_char();
    if (c == eof) return 0;
    bool negative = false;
    if (c == '-') {
        negative = true;
        c = get_char();
    }
    int num = 0;
    while (isdigit(static_cast<unsigned char>(c))) {
        num = num * 10 + (c - '0');
        c = get_char();
    }
    return negative ? -num : num;
}
// Lowers `a` to `b` when `b` is strictly smaller; otherwise `a` is left untouched.
inline void upmin(int &a, const int &b){
    if (b < a) {
        a = b;
    }
}
// Raises `a` to `b` when `b` is strictly larger; otherwise `a` is left untouched.
inline void upmax(int &a, const int &b){
    if (b > a) {
        a = b;
    }
}
// Tree key: a value `v` tagged with an `id`.
// Keys are ordered by value first and by id second, so two keys carrying the
// same value but different ids are distinct and can live in the same tree.
struct Node{
    int v, id;
    Node(int a, int b) : v(a), id(b) {}
    bool operator < (Node tar) const {
        if (v != tar.v) return v < tar.v;
        return id < tar.id;   // equal values: the id breaks the tie
    }
};
// Red-black tree keyed by Node with the order-statistics node update (provides order_of_key).
typedef tree<Node, null_type, less<Node>, rb_tree_tag, tree_order_statistics_node_update> Tree;
// n, m    : the two integers read first in main (sequence length and number of operations).
// lc, rc  : left / right child index of each segment-tree node, filled in by Build.
// cnt     : number of segment-tree nodes allocated so far.
// times   : counter that supplies a fresh id for every Node inserted.
// num     : value currently stored at each position (1-based).
// max_num : largest value inserted so far; upper bound of the binary search in Get_Kth.
int n, m, lc[MAXM], rc[MAXM], cnt, times, num[MAXN], max_num;
int tmp;// scratch: receives the root index from Build; the rank queries also store their last result here
// One tree per segment-tree node; Insert adds a value to every node on the path to its position.
Tree stree[MAXM];
Tree::iterator it;// shared scratch iterator for tree lookups
// Allocates the segment-tree node covering [l, r] and, recursively, its children.
// The new node's index is taken from the global counter `cnt` and written to
// `node`; child indices are recorded in lc[] / rc[] instead of being derived
// from the parent index, so only the nodes actually needed are numbered.
inline void Build(int &node, int l, int r){
    node = ++cnt;
    if (l != r) {
        const int mid = (l + r) >> 1;
        Build(lc[node], l, mid);
        Build(rc[node], mid + 1, r);
    }
}
// Adds `val` (the value at position k) to the tree of every segment-tree node
// on the path from `node`, which covers [l, r], down to the leaf for k.
// Each insertion gets a fresh id from `times`; max_num is raised to `val` if needed.
inline void Insert(int node, int l, int r, int k, int val){
    for (;;) {
        stree[node].insert(Node(val, ++times));
        if (l == r) break;
        const int mid = (l + r) >> 1;
        // Descend into the half that contains position k.
        if (k <= mid) {
            node = lc[node];
            r = mid;
        } else {
            node = rc[node];
            l = mid + 1;
        }
    }
    upmax(max_num, val);
}
// Counts the elements with value <= val among positions [ls, rs].
// `node` is the segment-tree node covering [l, r]; [ls, rs] must lie inside it.
inline int Get_Rank(int node, int l, int r, int ls, int rs, int val){
    if (ls == l && rs == r) {
        // The key (val, times + 1) sorts after every stored copy of val, so
        // duplicates of val are included in the count.
        tmp = stree[node].order_of_key(Node(val, times + 1));
        return tmp;
    }
    const int mid = (l + r) >> 1;
    if (rs <= mid) return Get_Rank(lc[node], l, mid, ls, rs, val);
    if (mid < ls) return Get_Rank(rc[node], mid + 1, r, ls, rs, val);
    // The query straddles the midpoint: split it and add both halves.
    const int left_part = Get_Rank(lc[node], l, mid, ls, mid, val);
    const int right_part = Get_Rank(rc[node], mid + 1, r, mid + 1, rs, val);
    return left_part + right_part;
}
// Counts the elements with value strictly less than val among positions [ls, rs].
// `node` is the segment-tree node covering [l, r]; [ls, rs] must lie inside it.
inline int Get_Rank_Lager(int node, int l, int r, int ls, int rs, int val){
    if (ls == l && rs == r) {
        // The key (val, 0) sorts before every stored copy of val (ids start
        // at 1), so elements equal to val are excluded from the count.
        tmp = stree[node].order_of_key(Node(val, 0));
        return tmp;
    }
    const int mid = (l + r) >> 1;
    if (rs <= mid) return Get_Rank_Lager(lc[node], l, mid, ls, rs, val);
    if (mid < ls) return Get_Rank_Lager(rc[node], mid + 1, r, ls, rs, val);
    // The query straddles the midpoint: split it and add both halves.
    const int left_part = Get_Rank_Lager(lc[node], l, mid, ls, mid, val);
    const int right_part = Get_Rank_Lager(rc[node], mid + 1, r, mid + 1, rs, val);
    return left_part + right_part;
}
// Finds the k-th smallest value among positions [ls, rs] by binary search on
// the answer: the result is the smallest candidate in [0, max_num + 1] for
// which at least k elements of the range are <= that candidate.
inline int Get_Kth(int ls, int rs, int k){
    int lo = 0;
    int hi = max_num + 1;
    while (lo < hi) {
        const int guess = (lo + hi) >> 1;
        const int not_above = Get_Rank(1, 1, n, ls, rs, guess);
        if (not_above >= k) {
            hi = guess;        // guess is large enough; try smaller values
        } else {
            lo = guess + 1;    // too few elements <= guess
        }
    }
    return lo;
}
// Replaces the value at position k with `val`.
// In every segment-tree node on the path from `node` (covering [l, r]) down to
// the leaf for k, one stored copy of the old value num[k] is removed and `val`
// is inserted with a fresh id. Afterwards max_num and num[k] are updated.
inline void Change(int node, int l, int r, int k, int val){
    for (;;) {
        // Erase through an iterator so that exactly one copy of the old value
        // disappears, even when the same value is stored several times.
        it = stree[node].lower_bound(Node(num[k], 1));
        stree[node].erase(it);
        stree[node].insert(Node(val, ++times));
        if (l == r) break;
        const int mid = (l + r) >> 1;
        // Descend into the half that contains position k.
        if (k <= mid) {
            node = lc[node];
            r = mid;
        } else {
            node = rc[node];
            l = mid + 1;
        }
    }
    // num[k] is overwritten only now: every level above needed the old value.
    upmax(max_num, val);
    num[k] = val;
}
inline int Get_Pre(int node, int l, int r, int ls, int rs, int val){
    if(l == ls && r == rs){
        it = stree[node].lower_bound(Node(val, 0));
        it--;//注意--,lower_bound允许等于
        return it == stree[node].end() ? -INF : it -> v;
    }
    int mid = (l + r) >> 1;
    if(rs <= mid) return Get_Pre(lc[node], l, mid, ls, rs, val);
    if(ls > mid) return Get_Pre(rc[node], mid + 1, r, ls, rs, val);
    return max(Get_Pre(lc[node], l, mid, ls, mid, val), Get_Pre(rc[node], mid + 1, r, mid + 1, rs, val));
}
// Returns the smallest value strictly greater than val among positions [ls, rs],
// or INF when no such value exists.
// `node` is the segment-tree node covering [l, r]; [ls, rs] must lie inside it.
inline int Get_Nxt(int node, int l, int r, int ls, int rs, int val){
    if (ls == l && rs == r) {
        // The key (val, times + 1) sorts after every stored copy of val, so
        // upper_bound lands on the first element with a larger value.
        it = stree[node].upper_bound(Node(val, times + 1));
        if (it == stree[node].end()) return INF;   // no larger value in this node
        return it -> v;
    }
    const int mid = (l + r) >> 1;
    if (rs <= mid) return Get_Nxt(lc[node], l, mid, ls, rs, val);
    if (mid < ls) return Get_Nxt(rc[node], mid + 1, r, ls, rs, val);
    // The query straddles the midpoint: the successor is the smaller of the two halves' answers.
    return min(Get_Nxt(lc[node], l, mid, ls, mid, val), Get_Nxt(rc[node], mid + 1, r, mid + 1, rs, val));
}
// Reads n and m, builds the structure from the n initial values, then
// processes m operations, each introduced by an operation code:
//   1 l r x : print the rank of x in [l, r] (elements smaller than x, plus one)
//   2 l r k : print the k-th smallest value in [l, r]
//   3 p x   : set the value at position p to x
//   4 l r x : print the predecessor of x in [l, r]
//   5 l r x : print the successor of x in [l, r]
// Any other code consumes no further input and prints nothing.
int main(){
    n = read();
    m = read();
    Build(tmp, 1, n);
    for (int pos = 1; pos <= n; ++pos) {
        num[pos] = read();
        Insert(1, 1, n, pos, num[pos]);
    }
    for (int q = 1; q <= m; ++q) {
        const int op = read();
        if (op == 3) {
            const int pos = read();
            const int val = read();
            Change(1, 1, n, pos, val);
            continue;
        }
        if (op != 1 && op != 2 && op != 4 && op != 5) continue;
        // All four query operations share the same three-argument shape.
        const int l = read();
        const int r = read();
        const int val = read();
        int answer;
        if (op == 1) {
            answer = Get_Rank_Lager(1, 1, n, l, r, val) + 1;
        } else if (op == 2) {
            answer = Get_Kth(l, r, val);
        } else if (op == 4) {
            answer = Get_Pre(1, 1, n, l, r, val);
        } else {
            answer = Get_Nxt(1, 1, n, l, r, val);
        }
        printf("%d\n", answer);
    }
    return 0;
}