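// P11771 题解 (solution to problem P11771)
// 分五种情况 (the answer is split into five cases of contributions); the code below accumulates them into `ans`.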
#include<bits/stdc++.h>
using namespace std;
// All arithmetic below uses 32-bit unsigned integers, so every sum wraps modulo 2^32.
// NOTE(review): presumably the problem asks for the answer modulo 2^32 — confirm.
#define uint unsigned int
// Children of segment-tree node u (heap layout: main always starts at root 1, children 2u and 2u+1).
// These macros read whatever variable named `u` is in scope where they are used.
#define lson (u<<1)
#define rson (u<<1|1)
// Capacity bound for the per-element arrays and the BITs; the segment-tree arrays hold 4*N nodes.
const int N=500007;
// n: number of input values; m: number of distinct values after coordinate compression.
int n,m;
// Scalars:
//   ans     - accumulated answer, printed at the end of main
//   s       - during the first pass, the sum of the values processed so far
//   A, B, C - input parameters; main clamps them (C=min(C,A+B), then A=min(A,C), B=min(B,C))
//   K       - multiplier used by add(u,v); main sets it to C-B, -B and -A for its three passes
// Per-element arrays:
//   x[i]    - i-th input value; main overwrites it with its 1-based rank in y
//   y[1..m] - sorted distinct input values
// Segment tree over ranks 1..m. Every inserted point j carries a value and a counter; per node u:
//   cnt[u]   - number of points inserted in the subtree
//   len[u]   - sum of y[rank] over those points (a value sum, not an interval length)
//   val[u]   - sum of the points' values
//   sum[u]   - sum of the points' counters
//   tag[u]   - pending counter increment not yet pushed to the children (see add(u,v))
//   extag[u] - pending per-point value increment not yet pushed to the children (see addex)
uint ans,s,A,B,C,K,x[N],y[N],cnt[N<<2],len[N<<2],val[N<<2],tag[N<<2],sum[N<<2],extag[N<<2];
// addex: add v to the value of every point stored under node u.
// The node total grows by v per point (cnt[u]*v). The increment is remembered in
// extag[u] so that pushdown() can forward it to the children later.
void addex(int u,uint v){
    extag[u]+=v;
    val[u]+=cnt[u]*v;
}
// add (node form): increment the counter of every point under node u by v.
// Each point j's value also grows by K*y_j*v, so the node totals change by
// sum += cnt[u]*v and val += K*len[u]*v. The increment is kept in tag[u] for pushdown().
// K is read when the tag is applied to each node, so K must not change while tags
// are pending; main clears the whole tree before it assigns a new K.
void add(int u,uint v){
    tag[u]+=v;
    sum[u]+=v*cnt[u];
    val[u]+=v*K*len[u];
}
void pushdown(int u){
if (tag[u]){
add(lson,tag[u]);add(rson,tag[u]);tag[u]=0;
}
if (extag[u]){
addex(lson,extag[u]);addex(rson,extag[u]);extag[u]=0;
}
}
// add (range form): for every point whose rank lies in [L,R], add v to its value
// and increment its counter by 1. Its value also gains K*y_j through add(u,1).
// Node u covers ranks [l,r]. Fully covered nodes are updated lazily. Partially
// covered nodes push their tags down, recurse, and then rebuild val/sum from the
// children (cnt/len are unaffected by range updates).
void add(int u,int l,int r,int L,int R,uint v){
    const bool covered=(l>=L)&&(R>=r);
    if (covered){
        addex(u,v);
        add(u,1);
        return;
    }
    pushdown(u);
    const int mid=(l+r)>>1;
    if (L<=mid) add(lson,l,mid,L,R,v);
    if (mid<R) add(rson,mid+1,r,L,R,v);
    sum[u]=sum[lson]+sum[rson];
    val[u]=val[lson]+val[rson];
}
// modify: insert one point with rank x, initial value v and initial counter c.
// The walk goes top-down from node u (covering ranks [l,r]) to the leaf for x.
// Every node on the path has its aggregates updated first. Its pending tags are then
// pushed to the children, which do not yet contain the new point, so the tags only
// affect points inserted earlier. Leaves are never pushed down (they have no children).
void modify(int u,int l,int r,int x,uint v,uint c){
    for (;;){
        val[u]+=v;
        sum[u]+=c;
        len[u]+=y[x];
        cnt[u]+=1;
        if (l==r) return;
        const int mid=(l+r)>>1;
        pushdown(u);
        if (x<=mid){
            u=lson;
            r=mid;
        }else{
            u=rson;
            l=mid+1;
        }
    }
}
// rangeQuery: shared traversal behind getcnt/getsum/getval (previously three identical copies).
// Returns the total of arr[] over the nodes that exactly cover ranks [L,R] inside node u's
// interval [l,r]. Lazy tags are pushed down along the visited paths, so reading arr after
// the push sees the children's up-to-date totals. arr must be one of the segment-tree
// arrays (cnt, sum or val). Requires L<=R; an empty range is not handled, and callers
// guard against it.
static uint rangeQuery(const uint* arr,int u,int l,int r,int L,int R){
    if (L<=l&&r<=R) return arr[u];
    const int mid=(l+r)>>1;
    pushdown(u);
    if (R<=mid) return rangeQuery(arr,lson,l,mid,L,R);
    if (L>mid) return rangeQuery(arr,rson,mid+1,r,L,R);
    return rangeQuery(arr,lson,l,mid,L,R)+rangeQuery(arr,rson,mid+1,r,L,R);
}
// getcnt: number of inserted points whose rank lies in [L,R].
uint getcnt(int u,int l,int r,int L,int R){
    return rangeQuery(cnt,u,l,r,L,R);
}
// getsum: sum of the per-point counters over ranks [L,R].
uint getsum(int u,int l,int r,int L,int R){
    return rangeQuery(sum,u,l,r,L,R);
}
// getval: sum of the per-point values over ranks [L,R].
uint getval(int u,int l,int r,int L,int R){
    return rangeQuery(val,u,l,r,L,R);
}
// output: debug dump. Prints "l r val cnt sum len" for node u and then, in pre-order,
// for every node below it. Lazy tags are pushed down on the way so children print
// up-to-date totals.
void output(int u,int l,int r){
    cout<<l<<' '<<r<<' '<<val[u]<<' '<<cnt[u]<<' '<<sum[u]<<' '<<len[u]<<endl;
    if (l!=r){
        pushdown(u);
        const int mid=(l+r)>>1;
        output(lson,l,mid);
        output(rson,mid+1,r);
    }
}
// BIT: Fenwick tree over ranks 1..m (m is the global count of distinct values).
// add(x, v) adds v at rank x; query(x) returns the sum over ranks 1..x.
// main uses X for counts and Y for value sums of the elements seen so far.
struct BIT{
    uint cnt[N];
    // Reset every slot to zero.
    void clear(){std::fill(cnt,cnt+N,0u);}
    // Point update: climb to every node whose range contains x.
    void add(int x,uint v){
        for (int i=x;i<=m;i+=i&-i) cnt[i]+=v;
    }
    // Prefix sum: strip the lowest set bit until the index reaches zero.
    uint query(int x){
        uint acc=0;
        for (int i=x;i!=0;i-=i&-i) acc+=cnt[i];
        return acc;
    }
}X,Y;
int main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin>>n>>A>>B>>C;C=min(C,A+B);A=min(A,C);B=min(B,C);
for (int i=1;i<=n;++i){cin>>x[i];y[i]=x[i];}
sort(y+1,y+1+n);m=unique(y+1,y+1+n)-y-1;K=C-B;
for (int i=1;i<=n;++i){
x[i]=lower_bound(y+1,y+1+m,x[i])-y;
// if (x[i]<m) cout<<getval(1,1,m,x[i]+1,m)<<' '<<getsum(1,1,m,x[i]+1,m)<<endl;
if (x[i]<m) ans+=getval(1,1,m,x[i]+1,m)-getsum(1,1,m,x[i]+1,m)*y[x[i]]*C;
add(1,1,m,1,x[i],B*y[x[i]]);
uint c=i-1-X.query(x[i]),v=s-Y.query(x[i]);
// cout<<i<<' '<<c<<' '<<v<<endl;
modify(1,1,m,x[i],(v-c*y[x[i]])*A+c*y[x[i]]*C,c);
s+=y[x[i]];
X.add(x[i],1);Y.add(x[i],y[x[i]]);
// output(1,1,m);
// cout<<ans<<endl;
}
X.clear();Y.clear();
memset(val,0,sizeof(val));memset(tag,0,sizeof(tag));memset(sum,0,sizeof(sum));memset(len,0,sizeof(len));memset(cnt,0,sizeof(cnt));memset(extag,0,sizeof(extag));
K=-B;
for (int i=n;i;--i){
ans+=getval(1,1,m,x[i],m);
add(1,1,m,1,x[i],B*y[x[i]]);
modify(1,1,m,x[i],0,0);
}
X.clear();Y.clear();
memset(val,0,sizeof(val));memset(tag,0,sizeof(tag));memset(sum,0,sizeof(sum));memset(len,0,sizeof(len));memset(cnt,0,sizeof(cnt));memset(extag,0,sizeof(extag));
K=-A;
for (int i=n;i;--i){
ans+=A*y[x[i]]*getsum(1,1,m,1,x[i])+getval(1,1,m,1,x[i]);
add(1,1,m,x[i],m,0);
modify(1,1,m,x[i],0,0);
}
cout<<ans<<endl;
return 0;
}