P8233题解
思路
线段树是个好东西,在序列问题上有奇效。但是线段树要求维护的东西是可并的,即已知区间 $[l,mid]$ 和 $[mid+1,r]$ 的信息,能够快速求出区间 $[l,r]$ 的信息。
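容易(很难)发现这题的答案是可并的。设一个区间被分成左半部分 $A$ 和右半部分 $B$,定义

$$sum = sum_A + sum_B + rmax_A \times lmax_B$$

$$lmax = \begin{cases} lmax_A + lmax_B, & lmax_A = len_A \\ lmax_A, & \text{otherwise} \end{cases}$$

$$rmax = \begin{cases} rmax_B + rmax_A, & rmax_B = len_B \\ rmax_B, & \text{otherwise} \end{cases}$$

其中 $sum$ 为区间内所有点都被覆盖的子区间个数(即答案,对 $10^9+7$ 取模),$lmax$ 为从区间左端点开始连续被覆盖的点数,$rmax$ 为以区间右端点结尾连续被覆盖的点数,$len$ 为区间包含的点数。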
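后两个结论是显然的(只有一侧整段都被覆盖时,边缘的连续段才能延伸到另一侧),第一个结论也不难理解:完全落在左半部分或右半部分内的合法子区间分别已经统计过;跨过中点的合法子区间,左端点只能落在左半部分末尾那段连续被覆盖的点中(共 $rmax$ 个),右端点只能落在右半部分开头那段连续被覆盖的点中(共 $lmax$ 个),两两组合即得二者的乘积。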
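这样就可以维护了。整段被覆盖时有 $lmax = rmax = len$,$sum = \dfrac{len(len-1)}{2} + len = \dfrac{len(len+1)}{2}$,所以区间覆盖可以直接打懒标记。看一看数据范围:坐标范围很大(用 long long 读入),不能直接对坐标建树,需要离散化。把所有端点排序去重得到 $a_1 < a_2 < \dots < a_m$,线段树的第 $2i-1$ 个叶子表示点 $a_i$(长度为 $1$),第 $2i$ 个叶子表示 $a_i$ 与 $a_{i+1}$ 之间的空隙(长度为 $a_{i+1} - a_i - 1$;空隙内部没有端点,只会被整体覆盖),共 $2m-1$ 个叶子。操作 $[a_x, a_y]$ 对应叶子区间 $[2x-1, 2y-1]$。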
知道了这些,代码就好写了。
代码
#include <bits/stdc++.h>
#define ll long long
// Heap-style child indices of segment-tree node x.
#define l(x) (x << 1)
#define r(x) (x << 1 | 1)
// Expands to the bare type name so that `pr{...}` is ordinary aggregate
// initialization. The former `(pair3)` turned `pr{...}` into a C99-style
// compound literal `(pair3){...}`, which ISO C++ does not allow (GCC/Clang
// accept it only as an extension; MSVC rejects it).
#define pr pair3
using namespace std;
const int maxn = 1e6 + 5, mod = 1e9 + 7;
// Segment-tree node over leaf positions [l, r].
//   tag  : lazy flag meaning "the whole range is covered"
//   sum  : number of fully covered sub-intervals, reduced modulo mod
//   lmax : covered points in the run starting at the left edge (not reduced)
//   rmax : covered points in the run ending at the right edge (not reduced)
//   len  : number of points the range stands for (not reduced)
// n < maxn operations give at most 2n distinct endpoints, i.e. at most
// 4n - 1 leaves, and a heap-indexed tree uses fewer than 4 * leaves slots,
// so maxn << 4 entries suffice. The former maxn << 5 doubled the static
// memory (about 1.5 GB) without ever being used.
struct node{
    int l, r, tag;
    ll sum, lmax, rmax, len;
}t[maxn << 4];
// Query result: the same summary as node, without range/tag bookkeeping.
struct pair3{
    ll lmax, rmax, sum, len;
};
// opt / L / R: the operations as read (main rewrites L, R to discretized indices).
int opt[maxn];
// a1: endpoints, sorted and de-duplicated in main (1-indexed); inv2: inverse of 2 mod `mod`.
ll L[maxn], R[maxn], a1[maxn << 1], inv2;
// Recompute node x from its two children.
// x's left-edge run is the left child's run, extended by the right child's
// left-edge run when the left child is entirely covered; the right-edge run
// is symmetric. Covered sub-intervals of x are those inside either child plus
// those crossing the midpoint, each pairing a point of the left child's
// right-edge run with a point of the right child's left-edge run.
void up(int x){
    const auto &a = t[x << 1];
    const auto &b = t[x << 1 | 1];
    auto &cur = t[x];
    cur.lmax = (a.lmax != a.len) ? a.lmax : a.lmax + b.lmax;
    cur.rmax = (b.rmax != b.len) ? b.rmax : b.rmax + a.rmax;
    const long long crossing = (a.rmax % mod) * (b.lmax % mod);
    cur.sum = (a.sum + b.sum + crossing) % mod;
}
// Push the lazy "fully covered" tag of node x down to its two children.
// Coverage is only ever added, never removed, so marking a child as covered
// is idempotent. The tag is now cleared after the push, so the same push is
// not repeated on every later visit (previously it stayed set forever).
// Leaves have no children. The old guard `l(x)` expanded to `x << 1`, which
// is non-zero for every node x >= 1 and therefore never excluded leaves, so
// the node's own range is checked instead.
void down(int x){
    if(!t[x].tag || t[x].l == t[x].r) return;
    // Make child c fully covered: every point is covered, so lmax = rmax = len,
    // and a run of len points has len*(len-1)/2 + len sub-intervals.
    auto cover = [](int c){
        t[c].tag = 1;
        t[c].lmax = t[c].rmax = t[c].len;
        t[c].sum = (((t[c].len % mod) * ((t[c].len - 1) % mod) % mod) * inv2 % mod + t[c].len % mod) % mod;
    };
    cover(l(x));
    cover(r(x));
    t[x].tag = 0;
}
// Build node x over leaf positions [l, r], filling in range bounds and lengths.
// Odd leaf 2k-1 stands for the single point a1[k]. Even leaf 2k stands for the
// a1[k+1] - a1[k] - 1 points strictly between a1[k] and a1[k+1].
// Coverage fields are not set here (t is a zero-initialized global).
void build(int x, int l, int r){
    t[x].l = l;
    t[x].r = r;
    if(l != r){
        const int mid = (l + r) >> 1;
        build(x << 1, l, mid);
        build(x << 1 | 1, mid + 1, r);
        t[x].len = t[x << 1].len + t[x << 1 | 1].len;
        return;
    }
    t[x].len = (l & 1) ? 1LL : a1[l / 2 + 1] - a1[l / 2] - 1;
}
// Cover every point in leaf range [l, r] below node x.
// A node lying entirely inside the range becomes fully covered:
// lmax = rmax = len and sum = len*(len-1)/2 + len (mod `mod`). Its lazy tag
// is set so down() can propagate the coverage to its children later.
// Otherwise any pending tag is pushed down, the children overlapping [l, r]
// are updated, and x is recomputed from them.
void update(int x, int l, int r){
    auto &cur = t[x];
    const bool inside = cur.l >= l && cur.r <= r;
    if(!inside){
        down(x);
        const int mid = (cur.l + cur.r) >> 1;
        if(l <= mid) update(x << 1, l, r);
        if(r > mid) update(x << 1 | 1, l, r);
        up(x);
        return;
    }
    const long long cnt = cur.len;
    cur.tag = 1;
    cur.lmax = cnt;
    cur.rmax = cnt;
    cur.sum = ((cnt % mod) * ((cnt - 1) % mod) % mod * inv2 % mod + cnt % mod) % mod;
}
// Summarize leaf range [l, r] intersected with node x's range. The result holds:
//   - the covered run lengths at both edges,
//   - the number of fully covered sub-intervals (mod `mod`),
//   - the point count.
// Partial answers are merged with the same rules as up().
// Results are built with standard aggregate initialization `pair3{...}`.
// The `pr` macro expanded to `(pair3){...}`, a compound literal that ISO C++
// rejects.
pair3 query(int x, int l, int r){
    if(l <= t[x].l && t[x].r <= r){
        return pair3{t[x].lmax, t[x].rmax, t[x].sum, t[x].len};
    }
    down(x);
    const int mid = (t[x].l + t[x].r) >> 1;
    if(l <= mid && r <= mid) return query(l(x), l, r);   // entirely in the left child
    if(l > mid && r > mid) return query(r(x), l, r);     // entirely in the right child
    if(l <= mid && r > mid){                              // straddles the midpoint
        const pair3 lc = query(l(x), l, r);
        const pair3 rc = query(r(x), l, r);
        // An edge run crosses into the other half only if its own half is fully covered.
        ll lmax = lc.lmax == lc.len ? lc.lmax + rc.lmax : lc.lmax;
        ll rmax = rc.rmax == rc.len ? rc.rmax + lc.rmax : rc.rmax;
        // Crossing sub-intervals: start in lc's right-edge run, end in rc's left-edge run.
        ll sum = (lc.sum + rc.sum + (lc.rmax % mod) * (rc.lmax % mod) % mod) % mod;
        return pair3{lmax, rmax, sum, lc.len + rc.len};
    }
    return pair3{0, 0, 0, 0};  // only reachable when l > r (empty range)
}
// Read n operations, discretize all endpoints, build the tree over the
// 2m - 1 point/gap leaves, then process the operations in input order.
// opt == 1 covers [L, R]. Any other opt prints the number of fully covered
// sub-intervals of [L, R] modulo mod.
int main(){
    int n;
    // With n <= 0 there are no endpoints (m == 0). build(1, 1, -1) would then
    // recurse on the range [1, 0] forever and crash, because that range never
    // narrows to a single leaf.
    if(scanf("%d", &n) != 1 || n <= 0) return 0;
    inv2 = mod - mod / 2;  // == (mod + 1) / 2, the inverse of 2 modulo the odd modulus
    for(int i = 1; i <= n; i ++){
        if(scanf("%d %lld %lld", &opt[i], &L[i], &R[i]) != 3) return 1;  // malformed input
        a1[i * 2 - 1] = L[i], a1[i * 2] = R[i];
    }
    // Coordinate compression: a1[1..m] holds the distinct endpoints in ascending order.
    sort(a1 + 1, a1 + 2 * n + 1);
    int m = unique(a1 + 1, a1 + 2 * n + 1) - a1 - 1;
    // Leaf 2k-1 is the point a1[k]; leaf 2k is the gap between a1[k] and a1[k+1].
    build(1, 1, 2 * m - 1);
    for(int i = 1; i <= n; i ++){
        L[i] = lower_bound(a1 + 1, a1 + m + 1, L[i]) - a1;
        R[i] = lower_bound(a1 + 1, a1 + m + 1, R[i]) - a1;
        if(opt[i] == 1) update(1, L[i] * 2 - 1, R[i] * 2 - 1);
        else printf("%lld\n", query(1, L[i] * 2 - 1, R[i] * 2 - 1).sum);
    }
    return 0;
}
补充
记得多多取模:$len$、$lmax$、$rmax$ 本身不能取模(要用它们判断 $lmax = len$ 等条件),但参与乘法前必须先对 $10^9+7$ 取模,否则乘积可能溢出 long long。