题解 P4243 【[JSOI2009]等差数列】

· · 题解

线段树

区间加等差数列显然用线段树

但是为了方便查询我们采用差分的方法

令v[i] = v[i+1] - v[i]

对于修改操作则变为了两个单点加,一个区间加。

而对于查询操作我们考虑维护六个信息 :

  1. 当前区间的左端点的数值 -> l
  2. 当前区间的右端点的数值 -> r
  3. 当前区间如果左右端点都不选有多少个等差数列 -> s[0]
  4. 当前区间如果只选左端点有多少个等差数列 -> s[1]
  5. 当前区间如果只选右端点有多少个等差数列 -> s[2]
  6. 当前区间如果左右端点都选有多少个等差数列 -> s[3]

那么答案就是查询区间的s[3]的值

如何合并呢?

信息1和2 直接继承左右区间即可

信息3,4,5,6是相似的 以下以s[0]为例

s[0] = min(lc->s[2]+rc->s[1]-(lc->r==rc->l), lc->s[0]+rc->s[1], lc->s[2]+rc->s[0]);

共有三种情况对于区间合并时合并处的端点选择,取个Min即可另外三个信息同样

查询和修改时存在一些细节需要注意具体见代码

本人的代码

#include <bits/stdc++.h>
#define R register
#define eps 1e-12
#define INF (1<<30)
#define LINF (1ll<<60)
#define LL long long
#define MM(x, y) memset(x, y, sizeof x)
#define Fo(i, x, y) for(R int i=x; i<=y; ++i)
#define Ro(i, x, y) for(R int i=x; i>=y; --i)
using namespace std;
template<typename T> inline T Max(R T x, R T y) {return x > y ? x : y;}
template<typename T> inline T Min(R T x, R T y) {return x < y ? x : y;}
template<typename T> inline void in(R T &x)
{
    static int ch; static bool flag;
    for(flag=false, ch=getchar(); ch<'0'||ch>'9'; ch=getchar()) flag |= ch=='-';
    for(x=0; ch>='0'&&ch<='9'; ch=getchar()) x = (x<<1) + (x<<3) + ch - '0';
    x = flag ? -x : x;
}
/*********************************Samle****************************************/

inline void minn(R int &x, R int y) {if(y < x) x = y;}
struct data{
    int s[4], l, r;
    data operator + (const data &y) const
    {
        data c; c.l = l; c.r = y.r;
        c.s[0] = s[2] + y.s[1] - (r == y.l);
        minn(c.s[0], s[0]+y.s[1]); minn(c.s[0], s[2]+y.s[0]);
        c.s[1] = s[3] + y.s[1] - (r == y.l);
        minn(c.s[1], s[1]+y.s[1]); minn(c.s[1], s[3]+y.s[0]);
        c.s[2] = s[2] + y.s[3] - (r == y.l);
        minn(c.s[2], s[2]+y.s[2]); minn(c.s[2], s[0]+y.s[3]);
        c.s[3] = s[3] + y.s[3] - (r == y.l);
        minn(c.s[3], s[3]+y.s[2]); minn(c.s[3], s[1]+y.s[3]);
        return c;
    }
};
#define lc (o << 1)
#define rc (lc | 1)
#define mid ((l + r) >> 1)
struct node{
    int val; data x;
}tr[500005]; int n, q, A[100005];
inline void pushdown(R int o)
{
    if(tr[o].val)
    {
        tr[lc].val += tr[o].val; tr[rc].val += tr[o].val;
        tr[lc].x.l += tr[o].val; tr[rc].x.l += tr[o].val;
        tr[lc].x.r += tr[o].val; tr[rc].x.r += tr[o].val;
        tr[o].val = 0;
    }
}
inline void build(R int o, R int l, R int r)
{
    if(l == r)
    {
        tr[o].x.s[0] = 0; tr[o].val = 0; tr[o].x.l = tr[o].x.r = A[l];
        tr[o].x.s[1] = tr[o].x.s[2] = tr[o].x.s[3] = 1;
        return;
    }
    build(lc, l, mid); build(rc, mid+1, r);
    tr[o].val = 0;
    tr[o].x = tr[lc].x + tr[rc].x;
}
inline void change(R int o, R int l, R int r, R int x, R int y, R int k)
{
    if(x <= l && y >= r)
    {
        tr[o].val += k;
        tr[o].x.l += k;
        tr[o].x.r += k;
        return;
    }
    pushdown(o);
    if(x <= mid) change(lc, l, mid, x, y, k);
    if(y > mid) change(rc, mid+1, r, x, y, k);
    tr[o].x = tr[lc].x + tr[rc].x;
}
inline data query(R int o, R int l, R int r, R int x, R int y)
{
    if(x <= l && y >= r) return tr[o].x;
    pushdown(o);
    if(y <= mid) return query(lc, l, mid, x, y);
    else if(x > mid) return query(rc, mid+1, r, x, y);
    else return query(lc, l, mid, x, mid) + query(rc, mid+1, r, mid+1, y);
}

int main()
{
    in(n);Fo(i, 1, n) in(A[i]); Fo(i, 1, n-1) A[i] = A[i+1] - A[i];
    build(1, 1, n-1);
    in(q);
    Fo(i, 1, q)
    {
        R char opt; R int s, t, a, b;
        scanf(" %c%d%d", &opt, &s, &t);
        if(opt == 'A')
        {
            in(a); in(b);
            if(s != 1) change(1, 1, n-1, s-1, s-1, a); // 因为1位置没有与前一个的差
            if(t != n) change(1, 1, n-1, t, t, -(a+b*(t-s))); // 不存在n+1与n的差
            if(s != t)change(1, 1, n-1, s, t-1, b);
        }
        else
        {
            if(s == t) // 只有一个数 ans一定是1
            {
                puts("1");continue;
            }
            data re = query(1, 1, n-1, s, t-1);
            printf("%d\n", re.s[3]);
        }
    }
    return 0;
}