P8233题解

· · 题解

思路

线段树是个好东西,在序列问题上有奇效。但是线段树要求维护的东西是可并的。即已知区间 [l,k] 和区间 [k,r] 的答案后,可以求出区间 [l,r] 的答案。

容易(很难)发现这题的答案是可并的。定义 sum,lmax,rmax,len 分别表示某个区间的答案、从左端点开始的最长黑色连通块、从右端点开始的最长黑色连通块、区间长度,则不难发现:

sum_x=sum_l+sum_r+rmax_l\times lmax_r lmax_x=\begin{cases} lmax_l & lmax_l<len_l \\ lmax_l+lmax_r & lmax_l=len_l\end{cases} rmax_x=\begin{cases} rmax_r & rmax_r<len_r \\ rmax_r+rmax_l & rmax_r=len_r\end{cases}

其中 x 代表当前区间, l,r 分别代表其左子区间和右子区间。

后两个结论是显然的,第一个结论也不难理解:x 的答案必定包含 l 的答案和 r 的答案,且多出了跨越 l,r 两个区间的答案。

这样就可以维护了。看一看数据范围:1\leq l\leq r\leq 10^{18},显然是不太行的,要离散化。离散化后应该将原序列拆分为若干区间来维护,这些区间包括了 n 此操作所涉及的点及相邻两点之间的区间。注意一定要将点和区间分开维护,因为可能给出的操作是有可能只包含一个单点的。

知道了这些代码就好写了。

代码

#include <bits/stdc++.h>
#define ll long long
#define l(x) (x << 1)
#define r(x) (x << 1 | 1)
#define pr (pair3)

using namespace std;

const int maxn = 1e6 + 5, mod = 1e9 + 7;

struct node{
    int l, r, tag;
    ll sum, lmax, rmax, len;
}t[maxn << 5];
struct pair3{
    ll lmax, rmax, sum, len;
};
int opt[maxn];
ll L[maxn], R[maxn], a1[maxn << 1], inv2;

void up(int x){
    t[x].lmax = t[l(x)].lmax;
    if(t[l(x)].lmax == t[l(x)].len) t[x].lmax += t[r(x)].lmax;
    t[x].rmax = t[r(x)].rmax;
    if(t[r(x)].rmax == t[r(x)].len) t[x].rmax += t[l(x)].rmax;
    t[x].sum = (t[l(x)].sum + t[r(x)].sum + (t[l(x)].rmax % mod) * (t[r(x)].lmax % mod)) % mod;
}
void down(int x){
    if(t[x].tag && l(x)){
        t[l(x)].tag = 1;
        t[l(x)].lmax = t[l(x)].rmax = t[l(x)].len;
        t[l(x)].sum = (((((t[l(x)].len % mod) * ((t[l(x)].len - 1) % mod) % mod) % mod) * inv2) % mod + t[l(x)].len % mod) % mod;
        t[r(x)].tag = 1;
        t[r(x)].lmax = t[r(x)].rmax = t[r(x)].len;
        t[r(x)].sum = (((((t[r(x)].len % mod) * ((t[r(x)].len - 1) % mod) % mod) % mod) * inv2) % mod + t[r(x)].len % mod) % mod;
    }
}
void build(int x, int l, int r){
    t[x].l = l, t[x].r = r;
    if(l == r){
        if(l & 1){
            t[x].len = 1;
        }else{
            t[x].len = a1[l / 2 + 1] - a1[l / 2] - 1;
        }
        return;
    }
    int mid = l + r >> 1;
    build(l(x), l, mid);
    build(r(x), mid + 1, r);
    t[x].len = t[l(x)].len + t[r(x)].len;
}
void update(int x, int l, int r){
    if(l <= t[x].l && t[x].r <= r){
        t[x].tag = 1;
        t[x].lmax = t[x].rmax = t[x].len;
        t[x].sum = (((((t[x].len % mod) * ((t[x].len - 1) % mod) % mod) % mod) * inv2) % mod + t[x].len % mod) % mod;
        return;
    }
    down(x);
    int mid = t[x].l + t[x].r >> 1;
    if(l <= mid) update(l(x), l, r);
    if(r >= mid + 1) update(r(x), l, r);
    up(x);
}
pair3 query(int x, int l, int r){
    if(l <= t[x].l && t[x].r <= r){
        return pr{t[x].lmax, t[x].rmax, t[x].sum, t[x].len};
    }
    down(x);
    int mid = t[x].l + t[x].r >> 1;
    if(l <= mid && r <= mid){
        return query(l(x), l, r);
    }else if(l > mid && r > mid){
        return query(r(x), l, r);
    }else if(l <= mid && r > mid){
        pair3 lc = query(l(x), l, r), rc = query(r(x), l, r);
        ll lmax = lc.lmax;
        if(lc.lmax == lc.len) lmax += rc.lmax;
        ll rmax = rc.rmax;
        if(rc.rmax == rc.len) rmax += lc.rmax;
        ll sum = (lc.sum + rc.sum + (lc.rmax % mod) * (rc.lmax % mod) % mod) % mod;
        return pr{lmax, rmax, sum, lc.len + rc.len};
    }else{
        return pr{0, 0, 0, 0};
    }
}

int main(){
    int n;
    scanf("%d", &n);
    inv2 = mod - mod / 2;
    for(int i = 1; i <= n; i ++){
        scanf("%d %lld %lld", &opt[i], &L[i], &R[i]);
        a1[i * 2 - 1] = L[i], a1[i * 2] = R[i];
    }
    sort(a1 + 1, a1 + 2 * n + 1);
    int m = unique(a1 + 1, a1 + 2 * n + 1) - a1 - 1;
    build(1, 1, 2 * m - 1);
    for(int i = 1; i <= n; i ++){
        L[i] = lower_bound(a1 + 1, a1 + m + 1, L[i]) - a1;
        R[i] = lower_bound(a1 + 1, a1 + m + 1, R[i]) - a1;
        if(opt[i] == 1) update(1, L[i] * 2 - 1, R[i] * 2 - 1);
        else printf("%lld\n", query(1, L[i] * 2 - 1, R[i] * 2 - 1).sum);
    }
    return 0;
}

补充

记得多多取模。