题解 P6477 【[NOI Online #2 提高组]子序列问题(民间数据)】

· · 题解

update 5.9 ,更新一下题解,请管理大大友善通过

--------以下为正文--------

我觉得大佬们的题解都比较复杂

因为线段树/树状数组里存的都是很复杂的东西

我就比较直接,线段树里存的就是 f(l,r)

因为真的不需要那些很复杂的拆分求单点贡献

1. O(n) 倒序枚举 l (顺序枚举 r 也行,看个人习惯)

2. 维护一个 last[i] 表示i这个颜色上一次出现的位置。

若这个颜色前面没有出现,那么我让 last[i] 默认为 n+1

3. 线段树第r个位置存的是以当前枚举的 l 终点为 rf(l,r)

我们要求的是\sum_{l=1}^n\sum_{r=l}^n f(l,r)^2

这里我们先求\sum_{l=1}^n\sum_{r=l}^n f(l,r)

那么很显然对于每一个不同的 lquery(1,1,n)

线段树优化的就是后面那个 \sum_{r=l}^n的过程

枚举到一个l时,则线段树里存的是\sum f(l,r)

最终要求答案的是 \sum f(l,r)^2

这个其实就在下一步就能解决

4.每当要枚举下一个 l 之前,就把当前的 l 的贡献加上

ans[l]\sum_{r=l}^nf(l,r)^2

怎么样从\sum_{r=l}^nf(l,r)ans[l]呢?

来一波学过初中数学就能看懂的东西

\because(x+1)^2=x^2+2x+1

\therefore(x+1)^2-x^2=2x+1

更新后-更新前=2x+1

那么当线段树里的x要更新+1

实际上要求的答案增加了 2\sum{x+1}

也就是增加了2\sum{x}+2len (len 是区间长度), 在代码中体现为

*$ans[l]=ans[l+1]+2query(1,1,last[a[l]]-1)+last[a[l]]-l$**

其实query(1,1,last[a[l]-1]) 就是 \sum{x}

last[a[l]]-l$ 就是 $2len

最后 update(1,l,last[a[l]]-1)

因为a[l]这个颜色对答案的影响最多只能到last[a[l]]-1

后面的区间已经有了这个颜色,就不会重复增加

5.每个 l 的答案加起来就是总答案

ans[l]=\sum_{r=l}^n f[l][r] (线段树O(logn))

ANS=\sum_{l=1}^n ans[l] (枚举l O(n))

说到这,我们发现

其实没有必要开数组ans,只要一边求一边统计就好啦

我的代码里op就是那个ans[l]

复杂度O(nlogn)

最后注意一下要开long long

运算要取膜

还有细节问题看代码

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define ll long long
#define int long long
const int N=1e6+5,mod=1e9+7;
inline int read(){
    int x=0; char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=getchar();}
    return x;
}
int n,a[N],b[N];
#define ls p<<1
#define rs p<<1|1
#define mid ((l(p)+r(p))>>1)
struct Seg{
    int l,r,sum,add;
    #define l(x) tree[x].l
    #define r(x) tree[x].r
    #define sum(x) tree[x].sum
    #define add(x) tree[x].add
    #define len(x) (r(x)-l(x)+1)
}tree[N<<2];
void build(int p,int l,int r){
    l(p)=l,r(p)=r;
    if(l==r)return;
    build(ls,l,mid);
    build(rs,mid+1,r);
}
inline void pushdown(int p){
    sum(ls)+=add(p)*len(ls);
    sum(rs)+=add(p)*len(rs);
    add(ls)+=add(p);
    add(rs)+=add(p);
    add(p)=0;
}
void update(int p,int l,int r,int d){
    if(l<=l(p)&&r(p)<=r){
        sum(p)+=len(p)*d;
        add(p)+=d;
        return;
    }
    if(add(p))pushdown(p);
    if(l<=mid)update(ls,l,r,d);
    if(r>mid)update(rs,l,r,d);
    sum(p)=sum(ls)+sum(rs);
}
int query(int p,int l,int r){
    if(l<=l(p)&&r(p)<=r)return sum(p);
    if(add(p))pushdown(p);
    int ans=0;
    if(l<=mid)ans+=query(ls,l,r);
    if(r>mid)ans+=query(rs,l,r);
    return ans; 
}
int last[N];
signed main(){

    //freopen("sequence.in","r",stdin);
    //freopen("sequence.out","w",stdout);
    n=read();
    for(int i=1;i<=n;i++)a[i]=b[i]=read();
    sort(b+1,b+1+n);
    int len=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+1+len,a[i])-b,last[i]=n+1;
    build(1,1,n);
    ll op=0,ans=0;
    for(int l=n;l>=1;l--){
        op=(op+2*query(1,1,last[a[l]]-1)+(last[a[l]]-l))%mod;
        update(1,l,last[a[l]]-1,1);
        last[a[l]]=l;
        ans=(ans+op)%mod;
    }
    cout<<ans<<endl;
}