P11771 题解

· · 题解

分五种情况:

$a_i\le a_k\le a_j$:答案为 $\min\{B,C\}(a_j-a_k)$。 $a_j\le a_k\le a_i$:答案为 $\min\{A,C\}(a_i-a_k)$。 $a_k\le a_i\le a_j$:答案为 $\min\{A+B,C\}(a_i-a_k)+\min\{B,C\}(a_j-a_i)$。 $a_k\le a_j\le a_i$:答案为 $\min\{A+B,C\}(a_j-a_k)+\min\{A,C\}(a_i-a_j)$。 令 $(A,B,C)\leftarrow (\min\{A,C\},\min\{B,C\},\min\{A+B,C\})$ 即可去掉式子里的根号。 对于第 2,3 种情况,可以在 $i$ 处记录答案,此时需要在权值线段树上维护 $(j,k)$ 的个数和答案中只和 $j,k$ 有关的部分。 对于第 4,5 种情况,可以在 $k$ 处记录答案,需要维护 $(i,j)$ 的个数和 $\max\{0,a_i-a_j\}\cdot A+\max\{0,a_j-a_i\}\cdot B+\min\{a_i,a_j\}\cdot C$,可以用树状数组记录 $\ge x$ 的 $a_i$ 的个数和它们的和。 时间复杂度 $O(n\log n)$,需要卡常。$\color{white}{下面给的这份代码常数不太能过。}
#include<bits/stdc++.h>
using namespace std;
#define uint unsigned int
#define lson (u<<1)
#define rson (u<<1|1)
const int N=500007;
int n,m;
uint ans,s,A,B,C,K,x[N],y[N],cnt[N<<2],len[N<<2],val[N<<2],tag[N<<2],sum[N<<2],extag[N<<2];
void addex(int u,uint v){val[u]+=v*cnt[u];extag[u]+=v;}
void add(int u,uint v){val[u]+=K*len[u]*v;tag[u]+=v;sum[u]+=cnt[u]*v;}
void pushdown(int u){
    if (tag[u]){
        add(lson,tag[u]);add(rson,tag[u]);tag[u]=0;
    }
    if (extag[u]){
        addex(lson,extag[u]);addex(rson,extag[u]);extag[u]=0;
    }
}
void add(int u,int l,int r,int L,int R,uint v){
    if (L<=l&&r<=R){addex(u,v);add(u,1);return;}
    int mid=l+r>>1;pushdown(u);
    if (L<=mid) add(lson,l,mid,L,R,v);
    if (R>mid) add(rson,mid+1,r,L,R,v);
    val[u]=val[lson]+val[rson];sum[u]=sum[lson]+sum[rson];
}
void modify(int u,int l,int r,int x,uint v,uint c){
    ++cnt[u];len[u]+=y[x];sum[u]+=c;val[u]+=v;
    if (l==r) return;
    int mid=l+r>>1;pushdown(u);
    if (x<=mid) modify(lson,l,mid,x,v,c);
    else modify(rson,mid+1,r,x,v,c);
}
uint getcnt(int u,int l,int r,int L,int R){
    if (L<=l&&r<=R) return cnt[u];
    int mid=l+r>>1;pushdown(u);
    if (R<=mid) return getcnt(lson,l,mid,L,R);
    if (L>mid) return getcnt(rson,mid+1,r,L,R);
    return getcnt(lson,l,mid,L,R)+getcnt(rson,mid+1,r,L,R);
}
uint getsum(int u,int l,int r,int L,int R){
    if (L<=l&&r<=R) return sum[u];
    int mid=l+r>>1;pushdown(u);
    if (R<=mid) return getsum(lson,l,mid,L,R);
    if (L>mid) return getsum(rson,mid+1,r,L,R);
    return getsum(lson,l,mid,L,R)+getsum(rson,mid+1,r,L,R);
}
uint getval(int u,int l,int r,int L,int R){
    if (L<=l&&r<=R) return val[u];
    int mid=l+r>>1;pushdown(u);
    if (R<=mid) return getval(lson,l,mid,L,R);
    if (L>mid) return getval(rson,mid+1,r,L,R);
    return getval(lson,l,mid,L,R)+getval(rson,mid+1,r,L,R);
}
void output(int u,int l,int r){
    cout<<l<<' '<<r<<' '<<val[u]<<' '<<cnt[u]<<' '<<sum[u]<<' '<<len[u]<<endl;
    if (l==r) return;pushdown(u);
    int mid=l+r>>1;
    output(lson,l,mid);output(rson,mid+1,r);
}
struct BIT{
    uint cnt[N];
    void clear(){memset(cnt,0,sizeof(cnt));}
    void add(int x,uint v){
        while(x<=m){cnt[x]+=v;x+=x&-x;}
    }
    uint query(int x){
        uint sum=0;
        while(x){sum+=cnt[x];x&=x-1;}
        return sum;
    }
}X,Y;
int main(){
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    cin>>n>>A>>B>>C;C=min(C,A+B);A=min(A,C);B=min(B,C);
    for (int i=1;i<=n;++i){cin>>x[i];y[i]=x[i];}
    sort(y+1,y+1+n);m=unique(y+1,y+1+n)-y-1;K=C-B;
    for (int i=1;i<=n;++i){
        x[i]=lower_bound(y+1,y+1+m,x[i])-y;
//      if (x[i]<m) cout<<getval(1,1,m,x[i]+1,m)<<' '<<getsum(1,1,m,x[i]+1,m)<<endl;
        if (x[i]<m) ans+=getval(1,1,m,x[i]+1,m)-getsum(1,1,m,x[i]+1,m)*y[x[i]]*C;
        add(1,1,m,1,x[i],B*y[x[i]]);
        uint c=i-1-X.query(x[i]),v=s-Y.query(x[i]);
//      cout<<i<<' '<<c<<' '<<v<<endl;
        modify(1,1,m,x[i],(v-c*y[x[i]])*A+c*y[x[i]]*C,c);
        s+=y[x[i]];
        X.add(x[i],1);Y.add(x[i],y[x[i]]);
//      output(1,1,m);
//      cout<<ans<<endl;
    }
    X.clear();Y.clear();
    memset(val,0,sizeof(val));memset(tag,0,sizeof(tag));memset(sum,0,sizeof(sum));memset(len,0,sizeof(len));memset(cnt,0,sizeof(cnt));memset(extag,0,sizeof(extag));
    K=-B;
    for (int i=n;i;--i){
        ans+=getval(1,1,m,x[i],m);
        add(1,1,m,1,x[i],B*y[x[i]]);
        modify(1,1,m,x[i],0,0);
    }
    X.clear();Y.clear();
    memset(val,0,sizeof(val));memset(tag,0,sizeof(tag));memset(sum,0,sizeof(sum));memset(len,0,sizeof(len));memset(cnt,0,sizeof(cnt));memset(extag,0,sizeof(extag));
    K=-A;
    for (int i=n;i;--i){
        ans+=A*y[x[i]]*getsum(1,1,m,1,x[i])+getval(1,1,m,1,x[i]);
        add(1,1,m,x[i],m,0);
        modify(1,1,m,x[i],0,0);
    }
    cout<<ans<<endl;
    return 0;
}