P7823题解

· · 题解

P7823题解

这道题还是一个比较好的线段树优化题,思维过程还是很流畅的(只是有点难调)。

n^{\smash{3}} 做法

f_{i,j,t} 为前 t 个时间点,a_1k_ia_2k_j 的最小代价,暴力转移即可。

n^{\smash{2}} 做法

发现当前一种做法在时间点为t时,a_1,a_2中必有一个取到k_t,且两者之间不区分。故而可以压缩掉一维。

f_{i,j} 为时间点 i ,其中一个 ak_i ,另一个 ak_j 的最小代价,转移方程:

f[i][j]=f[i-1][j]+ \lvert k[i]-k[i-1]\rvert \text{ if }(0\leq j\leq i-2)

需要注意的是,j=i-1 的情况下,f_{i,j} 可能由前面每一个 f_{i-1,j'} 得到(等价于位于 i-1a 不动,位于 j' 的动 ),此时转移方程变为:

f[i][i-1]=\ \min \lbrace f[i-1][j']+ \lvert k[i]-k[j']\rvert \rbrace \text{ } (0 \leq j'\leq i-2)

代码:

void bf(){
    memset(f,0x3f,sizeof(f));
    ans = 0x3fffffffffffffff;
    f[0][0]=0;
    for(int i = 1; i <= n; i++){
        for(int j = 0; j < i; j++){
            f[i][j]=min(f[i][j],f[i-1][j]+abs(a[i]-a[i-1]));
            f[i][i-1]=min(f[i][i-1],f[i-1][j]+abs(a[i]-a[j])); 
        }
    }
    for(int i = 0; i <= n; i++)ans = min(ans,f[n][i]);
    cout << ans;
}

n\log n 做法

注意到转移式一为区间加法,自然想到可以用线段树维护这一段的数据。每次操作到一个数是将值加上即可。

而对于第二个式子,因为在已经做完前 i 个数之后,只存在 i 个状态,对于第 i+1 个数,我们只需要求出这个式子的值,将其作为一个新的点 i+1 插入到线段树中就可以了。

然而,只是单纯的区间最小值用普通线段树可以维护,但式子带可变的绝对值难以用普通的线段树维护。我们需要找到一种方法以去除绝对值。

因此,自然想到可以将每一个 k_i 排序后建树维护。令 rk_i 为其排序后从小到大的排名,则对于比其排名小的部分,绝对值变成了k_i-k_j,对于排名比其大的位置,绝对值变成了 k_j-k_i 。而对于一个点来说,k_i,-k_i 不会改变。对于每个点我们可以将这一部分的权值先计算出来。具体的,我们维护三个权值:val_i ,lval_i=val_i-k_i,rval_i=val_i+k_i 分别代表其初始权值,排名小的时候的权值,排名大的时候的权值。这样操作到第i个数时,根据绝对值计算其排名左边区间的最小值\ \min\lbrace val_j+k_i-k_j=lval_j+k_i\rbrace ,右边区间的最小值\ \min\lbrace val_j+k_j-k_i=rval_j+k_i\rbrace ,两者再取最小值就可以得到这个点的权值了。最后将这个点插入就行了。

实现方面,线段树只需要实现区间加法(这部分可以直接将权值加到根节点的标记上),单点修改,查询区间 lval,rval 的最值就行了。这部分都是\log n 的。

代码见下。

#include<bits/stdc++.h>
using namespace std;
int n;
long long f[1005][1005],a[1000005],rnk[1000005],ans;
struct mai{
    long long val;
    int id;
}b[1000005];
bool cmp(mai a, mai b){
    return a.val == b.val ? a.id < b.id : a.val < b.val;
}
long long inf = 0x33ffffffffffffff;
void bf(){
    memset(f,0x3f,sizeof(f));
    ans = 0x3fffffffffffffff;
    f[0][0]=0;
    for(int i = 1; i <= n; i++){
        for(int j = 0; j < i; j++){
            f[i][j]=f[i-1][j]+abs(a[i]-a[i-1]);
            f[i][i-1]=min(f[i][i-1],f[i-1][j]+abs(a[i]-a[j])); 
        }
    }
    for(int i = 0; i <= n; i++)ans = min(ans,f[n][i]);
    cout << ans;
}

int ls(int x){
    return (x << 1);
}

int rs(int x){
    return ((x << 1)|1);
}

struct seg{
    long long lval,rval,lazy,ans;
}t[5000005];
int tot;

void push_up(int x){
    t[x].lval = min(t[ls(x)].lval,t[rs(x)].lval);
    t[x].rval = min(t[ls(x)].rval,t[rs(x)].rval);
    t[x].ans = min(t[ls(x)].ans,t[rs(x)].ans);
}

void push_down(int x){
    if(t[x].lazy){
        t[ls(x)].ans+=t[x].lazy;
        t[rs(x)].ans+=t[x].lazy;
        t[ls(x)].lazy+=t[x].lazy;
        t[ls(x)].lval+=t[x].lazy;
        t[ls(x)].rval+=t[x].lazy;
        t[rs(x)].lazy+=t[x].lazy;
        t[rs(x)].lval+=t[x].lazy;
        t[rs(x)].rval+=t[x].lazy;
        t[x].lazy=0;
    }
}

void build(int l, int r, int x){
    if(l==r){
        t[x].lval = t[x].rval = t[x].ans = inf;
        return;
    }
    int mid = (l+r)>>1;
    build(l,mid,ls(x));
    build(mid+1,r,rs(x));
    push_up(x);
}

void insert(int x, int pos, int id, long long val, int l, int r){
    if(l == r){
        t[x].lazy = 0;
        t[x].lval = val-a[id];
        t[x].rval = val+a[id];
        t[x].ans = val;
        return;
    }
    push_down(x);
    int mid = (l+r)>>1;
    if(pos <= mid) insert(ls(x),pos,id,val,l,mid);
    else insert(rs(x),pos,id,val,mid+1,r);
    push_up(x);
}

long long askl(int x, int ql, int qr, int l, int r){
    if(qr < ql)return inf;
    if(ql <= l && r <= qr){
        return t[x].lval;
    }
    int mid = (l+r)>>1;
    long long ret = inf;
    push_down(x);
    if(ql <= mid)ret = min(ret,askl(ls(x),ql,qr,l,mid));
    if(qr > mid )ret = min(ret,askl(rs(x),ql,qr,mid+1,r));
    push_up(x);
    return ret;
}

long long askr(int x, int ql, int qr, int l, int r){
    if(qr < ql)return inf;
    if(ql <= l && r <= qr){
        return t[x].rval;
    }
    int mid = (l+r)>>1;
    long long ret = inf;
    push_down(x);
    if(ql <= mid)ret = min(ret,askr(ls(x),ql,qr,l,mid));
    if(qr > mid )ret = min(ret,askr(rs(x),ql,qr,mid+1,r));
    push_up(x);
    return ret;
}

void sol(){
    for(int i = 1; i <= n; i++) b[i].val=a[i],b[i].id = i;
    sort(b+1,b+1+n,cmp);
    for(int i = 1; i <= n; i++)rnk[b[i].id]=i;
    build(0,n,1);
    insert(1,0,0,0,0,n);
    for(int i = 1; i <= n; i++){
        int rk = rnk[i];
        long long val = min(askl(1,0,rk-1,0,n)+a[i],askr(1,rk+1,n,0,n)-a[i]);
        t[1].lazy += abs(a[i]-a[i-1]);
        t[1].ans += abs(a[i]-a[i-1]);
        t[1].lval += abs(a[i]-a[i-1]);
        t[1].rval += abs(a[i]-a[i-1]);
        insert(1,rnk[i-1],i-1,val,0,n);

    }
    cout << t[1].ans;
}

int main(){
    freopen("P7823.in","r",stdin);
    freopen("P7823.out","w",stdout);
    scanf("%d",&n);
    for(int i = 1; i <= n; i++)scanf("%lld",&a[i]);
    sol();
    return 0;
}