一些线段树 Trick

· · 算法·理论

收录了一些我觉得很好的,把某些操作用线段树维护的 Trick。

推荐所有人去写 QOJ964。

P11340 [COI 2019] TENIS

应该来说算是性质题。

我们发现,对于一段后缀,若其在三个数组中包含的元素完全相等,那么其中元素一定不符合条件。

原因很简单,它能到达的点仅局限于这一段后缀而无法走到前面的点。

以此类推,符合条件的一段点一定为第一个符合三个数组中所包含元素完全相等的前缀。

考虑维护这个东西。

第一想法是维护 nxt_i,然后找到第一个点使得 \max{nxt_i} = i,但是这个东西很没有可二分性,并不好做。

观察发现,这个东西类似线段并:可以把一个点覆盖的 [i, nxt_i] 中的点继续加入到集合中,直到无法加入。

那么就可以转化为区间加,然后线段树上二分第一个没有被覆盖的线段(即 0)就可以了。

Qoj #964 Excluded Min

题意转化为:对于每个询问求解满足 \sum_{i = l}^{r} [a_i \le \operatorname{mex}] \ge \operatorname{mex} 的最大的 \operatorname{mex}

Trick:

对于不好维护答案但容易检验的查询:

对于后者,需要保证枚举答案增减对当前考虑的所有询问的变化量,而为了方便修改,需要保证当前考虑的询问连续分布。

而此题保证这个条件的关键在于,不能连续分布的询问,可以暂时删去不考虑,因为可以证明这样的询问区间的答案一定 \le 包含其的询问区间的答案。

所以可以只考虑仅相交的所有区间,此时对某一个点进行修改,产生影响的询问在其按左端点排序后一定是一段连续的区间,可以用线段树维护了。

然后就是对于已经确定答案的区间要删去并加入其他可行的区间。

1:|____________|
  2:|______________|
     3:|____________________|

若删去第二个,那么取出待处理集合中满足 l_2 < l < l_3, r_1 < r < r_2 的区间中左端点最靠前的点。重复这样的操作直到无法加入新的点。

每个区间只会加入一次,删除一次,所以复杂度是对的。

可以考虑对于每一个询问维护 b_i = \sum_{l_i}^{r_i} [a_i \le \operatorname{mex}]

具体到实现上,为了方便更新连续询问的答案,将当前考虑的询问区间按左端点存在线段树 sgt1 中,然后不存在询问的位置设置为极小值,可以用区间加维护。以及需要实现整个树扫描后剔除已经满足 b_i \ge \operatorname{mex} 的询问,然后将待处理询问添加到线段树上。

维护待处理区间,可以将区间以右端点为下标存在线段树 sgt2 中,线段树上每个叶节点维护一个栈(栈内元素按照左端点升序),删除元素后能快速加入新的区间,每次对 (r_1, r_2) 查询其中左端点最靠前的点,然后删掉后加入到 sgt1 中。特别的如果新加入的区间也满足 b_i \ge \operatorname{mex} 可以直接确定答案,不用加入到 sgt1 中。

这道题维护不包含区间集合的写法非常有意思,神题。

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 10;
typedef long long ll;
typedef pair<int, int> pi;
#define endl '\n'
int n, qcnt;
int a[N], ans[N];
vector<int> pos[N];
struct qry {
    int l, r, id;
    bool operator < (const qry y) const {
        if (l == y.l) return r < y.r;
        return l < y.l;
    }
} q[N], y[N];
// 维护:<= x 的 a[i]
namespace bit {
    int c[N];
    void add (int x, int y) {
        for (; x <= n; x += x & -x) c[x] += y;
    }
    int gsum (int x) {
        int res = 0;
        for (; x; x -= x & -x) res += c[x];
        return res;
    }
    int qry (int l, int r) {
        if (l > n) return -1e9;
        return gsum(r) - gsum(l - 1);
    }
};
int mex;
set<qry> s, s2; // 维护当前考虑的集合,方便求前后继
// 维护对应位置
namespace sgt2 {
    vector<qry> st[N];
    qry mi[N << 2];
    qry mrg (qry x, qry y) {
        return x.l < y.l ? x : y;
    }
    void build (int p, int l, int r) {
        if (l == r) {
            reverse(st[l].begin(), st[l].end());
            if (st[l].empty()) mi[p] = qry{n + 1, 0};
            else mi[p] = st[l].back();
        } else {
            int mid = l + r >> 1;
            build (p << 1, l, mid), build (p << 1 | 1, mid + 1, r);
            mi[p] = mrg(mi[p << 1], mi[p << 1 | 1]);
        }
    }
    void del (int p, int l, int r, int x) {
        if (l == r) {
            st[l].pop_back();
            if (st[l].empty()) mi[p] = qry{n + 1, 0};
            else mi[p] = st[l].back();
            return;
        }
        int mid = l + r >> 1;
        if (mid >= x) del (p << 1, l, mid, x);
        else del (p << 1 | 1, mid + 1, r, x);
        mi[p] = mrg(mi[p << 1], mi[p << 1 | 1]);
    }
    qry find (int p, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) return mi[p];
        if (l > qr || r < ql) return {n + 1, 0};
        return mrg(find(p << 1, l, l + r >> 1, ql, qr), find(p << 1 | 1, (l + r >> 1) + 1, r, ql, qr));
    }
};
// 维护对应位置询问的 b[i] = qry(l[i], r[i]) - mex
namespace sgt {
    qry a[N];
    int ma[N << 2], tag[N << 2];
    int mar[N << 2];
    void psd (int p) {
        ma[p << 1] += tag[p], ma[p << 1 | 1] += tag[p];
        tag[p << 1] += tag[p], tag[p << 1 | 1] += tag[p];
        tag[p] = 0;
    }
    void build (int p, int l, int r) {
        ma[p] = -1e9;
        if (l < r) {
            int mid = l + r >> 1;
            build (p << 1, l, mid), build (p << 1 | 1, mid + 1, r);
        }
    }
    void upd (int p, int l, int r, int x, qry q) {
        if (l == r) {
            a[l] = q;
            ma[p] = bit::qry(q.l, q.r);
            mar[p] = q.r;
            return;
        }
        psd(p);
        int mid = l + r >> 1;
        if (mid >= x) upd (p << 1, l, mid, x, q);
        else upd (p << 1 | 1, mid + 1, r, x, q);
        ma[p] = max(ma[p << 1], ma[p << 1 | 1]);
        mar[p] = max(mar[p << 1], mar[p << 1 | 1]);
    }
    vector<qry> deled;
    void del (int p, int l, int r, int x) {
        if (l == r) {
            if (ma[p] >= x) {
                // 加新的点进来
                ans[a[l].id] = x;
                auto pos = s.find(a[l]);
                s2.erase({a[l].r, a[l].l, a[l].id});
                upd (1, 1, n, l, {n + 1, 0});
                qry lst = *prev(pos);
                qry nxt = *next(pos);
                s.erase(pos);
                while (1) {
                    qry tmp = sgt2::find(1, 1, n, lst.r + 1, nxt.r - 1);
                    if (lst.l < tmp.l && nxt.l > tmp.l) {
                        sgt2::del(1, 1, n, tmp.r);
                        int res = bit::qry(tmp.l, tmp.r);
                        if (res >= x) ans[tmp.id] = x;
                        else {
                            upd (1, 1, n, tmp.l, tmp);
                            s.insert(tmp);
                            s2.insert({tmp.r, tmp.l, tmp.id});
                            lst = tmp;
                        }
                    } else break;
                }
            }
            return;
        }
        psd(p);
        int mid = l + r >> 1;
        if (ma[p << 1] >= x) del (p << 1, l, mid, x);
        if (ma[p << 1 | 1] >= x) del (p << 1 | 1, mid + 1, r, x);
    }
    void add (int p, int l, int r, int ql, int qr, int x) {
        if (ql <= l && r <= qr) ma[p] += x, tag[p] += x;
        else {
            psd(p);
            int mid = l + r >> 1;
            if (mid >= ql) add (p << 1, l, mid, ql, qr, x);
            if (mid < qr) add (p << 1 | 1, mid + 1, r, ql, qr, x);
            ma[p] = max(ma[p << 1], ma[p << 1 | 1]);
        }
    }
    int findfst (int p, int l, int r, int x) {
        if (l == r) return l;
        int mid = l + r >> 1;
        if (mar[p << 1] > x) return findfst (p << 1, l, mid, x);
        return findfst (p << 1 | 1, mid + 1, r, x);
    }
    void delnum (int pos) {
        qry r = *prev(s.upper_bound({pos, n + 1, 0})); // 最后一个 
        qry l = *s2.lower_bound({pos, 0, 0});
        swap(l.l, l.r);
        if (l.l != 0 && r.l != n + 1) add (1, 1, n, l.l, r.l, -1);
    }
};

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> qcnt;
    for (int i = 1; i <= n; i ++) cin >> a[i], pos[a[i]].push_back(i);
    for (int i = 1; i <= qcnt; i ++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
    }
    s.insert({0, 0});
    s.insert({n + 1, n + 1});
    s2.insert({0, 0});
    s2.insert({n + 1, n + 1});
    sort (q + 1, q + 1 + qcnt, [](qry x, qry y) {
        if (x.l == y.l) return x.r > y.r;
        return x.l < y.l;
    });
    qry lst = {0, 0};
    sgt::build(1, 1, n);
    for (int i = 1; i <= n; i ++) bit::add(i, 1);
    for (int i = 1; i <= qcnt; i ++) {
        if (q[i].l > lst.l && q[i].r > lst.r) {
            sgt::upd(1, 1, n, q[i].l, q[i]);
            s.insert(q[i]);
            s2.insert({q[i].r, q[i].l, q[i].id});
            lst = q[i];
        } else {
            sgt2::st[q[i].r].push_back(q[i]);
        }
    }
    sgt2::build(1, 1, n);
    for (mex = 5e5; mex >= 0; mex --) {
        for (int v : pos[mex])
            sgt::delnum(v), bit::add(v, -1);
        sgt::del(1, 1, n, mex);
    }
    for (int i = 1; i <= qcnt; i ++) cout << ans[i] << endl;
    return 0;
}

[20231130] 取名字

无名想给自己的模拟赛取个好听的名字,所以她开始用硬币决定名字。

无名有 n 枚硬币,第 i 枚硬币正面权值是 A_i,背面权值是 B_i,一开始都是正面朝上,并按标号排成一排。

无名做了 m 次操作,第 i 次她会将第 l_i 枚到第 r_i 枚硬币中朝上权值不超过 c_i 的硬币翻面。无名想知道最后硬币的朝上的权值和,这就是一个好听的模拟赛名字。

首先直接做限制太多并不太好维护,而且实际翻转的硬币位置不连续。尝试去观察这个修改的性质。

发现当 c_i \in [\min(A_i, B_i), \max(A_i, B_i)) 时,一定会将硬币变到 \max(A_i, B_i) 的状态。

不妨去找到最后一个满足这个条件的修改,然后对于硬币 i 之后的状态就可以通过统计 c_i \ge \max(A_i, B_i) 的修改的数量来判断。

转化后发现,有两重限制 l < i < r, c_i \in [\min(A_i, B_i), \max(A_i, B_i))

区间包含限制可以通过扫描线去除,而后面一个可以通过树套树维护,外层线段树以操作时间为下标,内层线段树以值域为下标,每次找最后一个修改相当于在外层线段树上二分。

然后就做完了。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
const int M = 1.9e7 + 10;
typedef long long ll;
typedef pair<int, int> pi;
namespace Fread
{...}
using namespace Fread;
namespace Fwrite
{...}
using namespace Fwrite;
#define getchar Fread::getchar
#define putchar Fwrite::putchar
namespace Fastio
{...}
using namespace Fastio;
#define cin Fastio::cin
#define cout Fastio::cout
#define endl Fastio::endl
int n;
int num = 1e9;
int a[N], b[N];
struct node {
    int p, v, w, t;
}opt[N << 1];
int idx;
int rt[N << 2], ls[M], rs[M], sum[M];
void upd (int &p, int l, int r, int x, int v) {
    if (!p) p = ++idx;
    sum[p] += v;
    if (l == r) return;
    int mid = l + r >> 1;
    if (mid >= x) upd (ls[p], l, mid, x, v);
    else upd (rs[p], mid + 1, r, x, v);
}
int qry (int p, int l, int r, int ql, int qr) {
    if (!p || l > qr || r < ql || ql > qr) return 0;
    if (ql <= l && r <= qr) return sum[p];
    return qry (ls[p], l, l + r >> 1, ql, qr) + qry(rs[p], (l + r >> 1) + 1, r, ql, qr);
}
void updata (int p, int l, int r, int x, int v, int op) {
    upd (rt[p], 1, num, v, op);
    if (l == r) return;
    int mid = l + r >> 1;
    if (mid >= x) updata (p << 1, l, mid, x, v, op);
    else updata (p << 1 | 1, mid + 1, r, x, v, op);
}
int find (int p, int l, int r, int ql, int qr) {
    if (l == r) {
        return 0;
    }
    int mid = l + r >> 1;
    if (qry(rt[p << 1 | 1], 1, num, ql, qr - 1)) return find(p << 1 | 1, mid + 1, r, ql, qr);
    return find(p << 1, l, mid, ql, qr) + qry(rt[p << 1 | 1], 1, num, qr, num);
}
bool cmp (node x, node y) {
    return x.p < y.p;
}
int fst[N];

int main() {
    ios::sync_with_stdio(0);
    cin >> n;
    vector<int> lsh;
    for (int i = 1; i <= n; i ++) cin >> a[i], lsh.push_back(a[i]);
    for (int i = 1; i <= n; i ++) cin >> b[i], lsh.push_back(b[i]);
    int m;
    cin >> m;
    for (int i = 1; i <= m; i ++) {
        int l, r, v;
        cin >> l >> r >> v;
        lsh.push_back(v);
        opt[2 * i - 1] = node{l, v, 1, i};
        opt[2 * i] = node{r + 1, v, -1, i};
    }
    sort(lsh.begin(), lsh.end());
    lsh.erase(unique(lsh.begin(), lsh.end()), lsh.end());
    num = lsh.size();
    sort (opt + 1, opt + 1 + 2 * m, cmp);
    auto _lsh = [lsh](int& x) ->void{
        x = lower_bound(lsh.begin(), lsh.end(), x) - lsh.begin() + 1;
    };
    for (int i = 1; i <= 2 * m; i ++) _lsh(opt[i].v);
    int j = 1;
    ll ans = 0;
    for (int i = 1; i <= n; i ++) {
        while (j <= 2 * m && opt[j].p <= i) updata(1, 1, m, opt[j].t, opt[j].v, opt[j].w), j ++;
        int A = a[i], B = b[i];
        _lsh(A), _lsh(B);
        if (qry(rt[1], 1, num, min(A, B), max(A, B) - 1) == 0) {
            int res = qry(rt[1], 1, num, max(A, B), num);
            ans += ((res & 1) ? b[i] : a[i]);
        } else {
            int res = find(1, 1, m, min(A, B), max(A, B));
            if (a[i] < b[i]) ans += ((res & 1) ? a[i] : b[i]);
            else ans += ((res & 1) ? b[i] : a[i]);
        }
    }
    cout << ans;
    return 0;
}