题解:P10513 括号

· · 题解

P10513 括号

我们考虑使用线段树实现,对于每个区间,可以维护区间内的最大子段和。

为了进行合并我们还需维护剩余的左/右括号个数,在合并时统计跨块的括号对。

具体来讲,左区间剩余的左括号可以与右区间剩余的右括号结合,数量为两者的 \min 值。

Node merge(Node x, Node y) {
    int minn = min(x.lv, y.rv);
    return{x.v + y.v + minn,x.lv + y.lv - minn,x.rv + y.rv - minn};
}

其中 v 是当前区间的答案,lvrv 分别是剩余的左/右括号个数。

对于区间取反操作,似乎不好直接维护,取反会导致我们维护好的括号全部失配,难以计算。

这里其实有个小 trick,我们可以维护两棵线段树,分别维护正常字符串和取反后的字符串,我们在更新时直接交换两颗树的节点即可。

可以发现,这两颗树只有 vlvrv 是不一样的,所以我们可以把这三个变量封成结构体,将它们缩至一棵树,以降低编程复杂度。

时间复杂度 O(n\log n)

代码实现

#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct Node {
    int v, lv, rv; // v: 匹配对数,lv: 左括号剩余,rv: 右括号剩余
};
Node a[N << 2], b[N << 2]; // a: 正常匹配,b: 反转后匹配
bool lazy[N << 2]; // 懒标记,表示是否翻转
char s[N]; // 输入字符串
// 合并两个节点信息
Node merge(Node x, Node y) {
    int minn = min(x.lv, y.rv);
    return{x.v + y.v + minn,x.lv + y.lv - minn,x.rv + y.rv - minn};
}
// 构建线段树
void build(int l, int r, int rt) {
    if (l == r) {
        if (s[l] == '(') a[rt] = {0, 1, 0},b[rt] = {0, 0, 1};
        else if (s[l] == ')') a[rt] = {0, 0, 1},b[rt] = {0, 1, 0};
        else a[rt] = b[rt] = {0, 0, 0};
    }
    else{
        int mid = (l + r) >> 1;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
        a[rt] = merge(a[rt << 1], a[rt << 1 | 1]);
        b[rt] = merge(b[rt << 1], b[rt << 1 | 1]);    
    }
}
// 下推懒标记
void pushdown(int rt) {
    if (lazy[rt]) {
        lazy[rt << 1] ^= 1;
        lazy[rt << 1 | 1] ^= 1;
        swap(a[rt << 1], b[rt << 1]);
        swap(a[rt << 1 | 1], b[rt << 1 | 1]);
        lazy[rt] = 0;
    }
}
// 区间翻转操作
void update(int L, int R, int l, int r, int rt) {
    if (R < l || r < L) return;
    if (L <= l && r <= R) {
        lazy[rt] ^= 1;
        swap(a[rt], b[rt]);
        return;
    }
    pushdown(rt);
    int mid = (l + r) >> 1;
    update(L, R, l, mid, rt << 1);
    update(L, R, mid + 1, r, rt << 1 | 1);
    a[rt] = merge(a[rt << 1], a[rt << 1 | 1]);
    b[rt] = merge(b[rt << 1], b[rt << 1 | 1]);
}
// 查询区间匹配对数
Node query(int L, int R, int l, int r, int rt) {
    if (R < l || r < L) return {0, 0, 0};
    if (L <= l && r <= R) return a[rt];
    pushdown(rt);
    int mid = (l + r) >> 1;
    return merge(query(L, R, l, mid, rt << 1), query(L, R, mid + 1, r, rt << 1 | 1));
}
int n,q;
int main() {
    cin >> n >> s + 1 >> q;
    build(1, n, 1);
    while (q--) {
        int op, x, y;
        cin >> op >> x >> y;
        if (op == 1) update(x, y, 1, n, 1);
        else cout << query(x, y, 1, n, 1).v << "\n";
    }
    return 0;
}