题解 P6477 【[NOI Online #2 提高组]子序列问题(民间数据)】
update 5.9 ,更新一下题解,请管理大大友善通过
--------以下为正文--------
我觉得大佬们的题解都比较复杂
因为线段树/树状数组里存的都是很复杂的东西
我就比较直接,线段树里存的就是 f(l,r)
因为真的不需要那些很复杂的拆分求单点贡献
1. O(n) 倒序枚举 l (顺序枚举 r 也行,看个人习惯)
2. 维护一个 last[i] 表示i这个颜色上一次出现的位置。
若这个颜色前面没有出现,那么我让
last[i] 默认为n+1
3. 线段树第r 个位置存的是以当前枚举的 l 终点为 r 的f(l,r) 。
我们要求的是
\sum_{l=1}^n\sum_{r=l}^n f(l,r)^2 这里我们先求
\sum_{l=1}^n\sum_{r=l}^n f(l,r) 那么很显然对于每一个不同的
l 求query(1,1,n) 线段树优化的就是后面那个
\sum_{r=l}^n 的过程枚举到一个
l 时,则线段树里存的是\sum f(l,r) 最终要求答案的是
\sum f(l,r)^2 这个其实就在下一步就能解决
4.每当要枚举下一个 l 之前,就把当前的 l 的贡献加上
令
ans[l] 是\sum_{r=l}^nf(l,r)^2 怎么样从
\sum_{r=l}^nf(l,r) 求ans[l] 呢?来一波学过初中数学就能看懂的东西
\because(x+1)^2=x^2+2x+1
\therefore(x+1)^2-x^2=2x+1 更新后-更新前=
2x+1 那么当线段树里的
x 要更新+1 时实际上要求的答案增加了
2\sum{x+1} 也就是增加了
2\sum{x}+2len (len 是区间长度), 在代码中体现为*$ans[l]=ans[l+1]+2query(1,1,last[a[l]]-1)+last[a[l]]-l$**
其实
query(1,1,last[a[l]-1]) 就是\sum{x} last[a[l]]-l$ 就是 $2len 最后
update(1,l,last[a[l]]-1) 因为
a[l] 这个颜色对答案的影响最多只能到last[a[l]]-1 后面的区间已经有了这个颜色,就不会重复增加
5.每个 l 的答案加起来就是总答案
ans[l]=\sum_{r=l}^n f[l][r] (线段树O(logn) )
ANS=\sum_{l=1}^n ans[l] (枚举l O(n) )
说到这,我们发现
其实没有必要开数组ans,只要一边求一边统计就好啦
我的代码里
复杂度
最后注意一下要开long long
运算要取膜
还有细节问题看代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define ll long long
#define int long long
const int N=1e6+5,mod=1e9+7;
inline int read(){
int x=0; char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=getchar();}
return x;
}
int n,a[N],b[N];
#define ls p<<1
#define rs p<<1|1
#define mid ((l(p)+r(p))>>1)
struct Seg{
int l,r,sum,add;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
#define len(x) (r(x)-l(x)+1)
}tree[N<<2];
void build(int p,int l,int r){
l(p)=l,r(p)=r;
if(l==r)return;
build(ls,l,mid);
build(rs,mid+1,r);
}
inline void pushdown(int p){
sum(ls)+=add(p)*len(ls);
sum(rs)+=add(p)*len(rs);
add(ls)+=add(p);
add(rs)+=add(p);
add(p)=0;
}
void update(int p,int l,int r,int d){
if(l<=l(p)&&r(p)<=r){
sum(p)+=len(p)*d;
add(p)+=d;
return;
}
if(add(p))pushdown(p);
if(l<=mid)update(ls,l,r,d);
if(r>mid)update(rs,l,r,d);
sum(p)=sum(ls)+sum(rs);
}
int query(int p,int l,int r){
if(l<=l(p)&&r(p)<=r)return sum(p);
if(add(p))pushdown(p);
int ans=0;
if(l<=mid)ans+=query(ls,l,r);
if(r>mid)ans+=query(rs,l,r);
return ans;
}
int last[N];
signed main(){
//freopen("sequence.in","r",stdin);
//freopen("sequence.out","w",stdout);
n=read();
for(int i=1;i<=n;i++)a[i]=b[i]=read();
sort(b+1,b+1+n);
int len=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+1+len,a[i])-b,last[i]=n+1;
build(1,1,n);
ll op=0,ans=0;
for(int l=n;l>=1;l--){
op=(op+2*query(1,1,last[a[l]]-1)+(last[a[l]]-l))%mod;
update(1,l,last[a[l]]-1,1);
last[a[l]]=l;
ans=(ans+op)%mod;
}
cout<<ans<<endl;
}