题解 P1637 【三元上升子序列】
windows250 · · 题解
看到都是树状数组的解法,这里发篇线段树的题解,希望对像我一样初学线段树的同学有帮助.
很容易想到 三元上升子序列的个数=每个数前比它小的数的个数*每个数后比它大的数的个数 之和
我们只需要维护 每个数前比它小的数的个数 和 每个数后比它大的数的个数 并记录下来就可以了,
那么怎样用线段树去维护呢?以一组数据 An={ 1,4,5,3 } 为例:
-
首先我们开个桶sum[n]记录每个数的出现次数,那么一开始这个桶就是 0 0 0 0 0
-
插入A1=1,此时的桶为 1 0 0 0 0 ,查询A1前有没有比它小的数: 1 0 0 0 0,好吧一个都没有,记smaller[1]=0;
-
插入A2=4,此时的桶为 1 0 0 1 0 ,查询A2前有没有比它小的数:(1 0 0)1 0,发现此时1比4小,smaller[2]=1;
-
插入A3=5,此时的桶为 1 0 0 1 1 ,查询A3前有没有比它小的数:(1 0 0 1)1,发现此时1,4比5小,smaller[3]=2;
-
插入A4=3,此时的桶为 1 0 1 1 1 ,查询A4前有没有比它小的数:(1 0)1 1 1,发现此时1比3小,smaller[4]=1;
按上面的操作步骤,即每次更新,将sum[A[n]]+1,然后sum[1]~sum[A[n]-1]的和即是small[n]的值,这也是用线段树求逆序对的方法
那么用同样的方法求出所有数后比它大的数,得到bigger[1~4],最后ans=ans+smaller[i]*bigger[i],1<=i<=4
当然如果数据的间隔很大,比如 99 1 9999 999999 99999999999 我们当然没必要开那么大的桶,记住相对大小关系即可:2 1 3 4 5
这种操作就是离散化啦。
详细见代码:
#include<bits/stdc++.h>
#define MAXN 30001
#define ls root<<1,l,mid
#define rs root<<1|1,mid+1,r
#define in(x) x=read()
using namespace std;
inline int read()
{
int X=0,w=1;char ch=getchar();
while(ch>'9' || ch<'0'){if(ch=='-') w=-1;ch=getchar();}
while(ch<='9' && ch>='0') X=(X<<3)+(X<<1)+ch-'0',ch=getchar();
return X*w;
}
struct data
{
int num,pos;
}st[MAXN];
int n,cnt;
int sum[MAXN<<2],num[MAXN],bigger[MAXN],smaller[MAXN];
long long ans;
inline void push_up(int root)
{
sum[root]=sum[root<<1]+sum[root<<1|1];
}
inline void update(int root,int l,int r,int pos)
{
if(l==r && l==pos){sum[root]++;return;}
int mid=(l+r)>>1;
if(mid>=pos) update(ls,pos);
if(mid<pos) update(rs,pos);
push_up(root);
}
inline int query(int root,int l,int r,int L,int R)
{
if(L<=l && r<=R) return sum[root];
int mid=(l+r)>>1,total=0;
if(mid>=L) total+=query(ls,L,R);
if(mid<R) total+=query(rs,L,R);
return total;
}
inline bool mcomp(const data &a,const data &b)
{
return a.num<b.num;
}
int main()
{
in(n);
for(int i=1;i<=n;i++) in(st[i].num),st[i].pos=i;
sort(st+1,st+n+1,mcomp);
for(int i=1;i<=n;i++)
{
if(st[i].num>st[i-1].num) cnt++;
num[st[i].pos]=cnt;
}
//离散化,注意两个数相同的情况
for(int i=1;i<=n;i++)
{
if(num[i]>1) smaller[i]=query(1,1,n,1,num[i]-1);
update(1,1,n,num[i]);
}
//查询一个数前比它小的数的个数
memset(sum,0,sizeof(sum));
//注意清空sum数组
for(int i=n;i>=1;i--)
{
if(num[i]<n) bigger[i]=query(1,1,n,num[i]+1,n);
update(1,1,n,num[i]);
}
//查询一个数后比它大的数的个数
for(int i=1;i<=n;i++) ans+=(bigger[i]*smaller[i]);
//乘法原理 & 加法原理
printf("%lld",ans);
return 0;
}
//By windows250