题解 P6563 【[SBCOI2020] 一直在你身旁】

· · 题解

可能是既视感太强,我做这道时总是被二分查找的思想干扰X

解析

首先要想到一个暴力的区间 dp:

f(l, r) 表示确定目标在 [l, r] 范围内,最终确定目标最坏情况下所需的消耗。于是就有转移方程:f(l, r)=\min\limits_{l\leq k\leq r}\{\max(f(l, k), f(k+1, r))+a[k]\}(可能有越界X,越界状态直接无视就行了)

接下来我们要设法将这个方程优化成 O(n^2)

 

首先可以从定义得知 f 是满足区间包含单调性的。即有:

\forall l\leq l'\leq r, l\leq r'\leq r, l'\leq r'$,有 $f(l, r)\geq f(l', r')

所以对于一对 l, r,一定有一个 “临界点” m_{l, r}。在这临界点之前,\max 均取 f(k+1, r);在这临界点之后,\max 均取 f(l, k)

并且若固定 r,随着 l 的减小,m_{l, r} 单调不增;若固定 l,随着 r 的增大,m_{l, r} 单调不减

于是我们就可以先枚举一个端点,再枚举一个端点,并在第二层枚举内设一个指针保存临界点位置。可以知道该指针的移动是单调的,因此这部分总的复杂度还是两层枚举,O(n^2)

 

现在我们考虑确定了 l, r 及临界点 m_{l, r} 后,如何快速计算出 dp 的值

假设我们先枚举 l,再枚举 r,显然可以在每次的第二层循环用单调队列维护转移方程中的 \min\limits_{m_{l, r} < k <r}\{f(l, k)+a[k]\},但我们很难维护 \min\limits_{l\leq k\leq m_{l, r}}\{f(k+1, r)+a[k]\},因为第二层循环枚举的就是 r

注意到题目中的 a_1\leq a_2\leq\cdots\leq a_n,再结合区间包含单调性,可以发现 f(l, k)+a[k] 是随 k 单调不减的。因此只需先枚举 r,再枚举 l,同样用单调队列维护 \min\limits_{l\leq k\leq m_{l, r}}\{f(k+1, r)+a[k]\},至于 \min\limits_{m_{l, r} < k < r}\{f(l, k)+a[k]\} 直接取 k=m_{l, r}+1 即可

(单调队列部分具体可见代码,相关变量的定义都是一样的)

若不满足 a_1\leq a_2\leq\cdots\leq a_n

其实即使题目不保证 a_1\leq a_2\leq\cdots\leq a_n,也仍然是可做的

我们先倒序地枚举 l 再(正序)枚举 r,并同样设临界点 m_{l, r},和刚才一样地在每次的第二层循环用单调队列维护转移方程中的 \min\limits_{m_{l, r} < k < r}\{f(l, k)+a[k]\}

我们再对每个端点维护一个维护最小值的单调队列,设当前刚刚枚举到 l, r,对于端点 r 的单调队列中已统计(队列里不一定要有这些元素)的元素就为 \{f(k+1, r))+a[k]:l\leq k < r\}

若当前枚举到的 l, r 的临界点是 m_{l, r},就用 m_{l, r} 维护单调队列队头,将(在原地址)下标超过临界点的元素全部弹出。此时 r 的单调队列的队头就是 \min\limits_{l\leq k\leq m_{l, r}}\{f(k+1, r)+a[k]\} 的值

最后,再将 f(l, r)+a[l-1]r 的单调队列的队尾插入即可(注意此时要求 l\geq 1

注意到这里的 m_{l, r}l 的减小单调不增,因此我们对单调队列队头的维护总是合法的

我们对每个 r 维护的单调队列的复杂度消耗都是 O(n),总的复杂度也就是 O(n^2)

(也可以先正序地枚举 r 再(倒序)枚举 l,做法也是差不多的)

但这种方法在过程中需要维护 O(n) 个单调队列,可能空间有点紧。我还没试过,不知道会不会爆qaq。试过了,的确会炸,而且大数组寻址、维护多一倍的单调队列带来的常数也不小(

CODE

满足 a_1\leq a_2\leq\cdots\leq a_n 的做法的代码

#include <cstdio>
#include <iostream>
#define ll long long
using std::pair;
using std::min;
typedef pair<ll, int> pad;

const int MAXN =7110;

/*------------------------------Monotone queue------------------------------*/

pad q[MAXN];
int head, tail;

inline ll front(){
    if(head < tail)
        return q[head].first;
    else
        return 0x3f3f3f3f3f3f3f3f;
}

void modifyfront(int nowtime){
    while(head < tail && q[head].second >= nowtime)
        ++head;
}

void push(pad newelement){
    while(head < tail && q[tail-1].first >= newelement.first)
        --tail;
    q[tail++] =newelement;
}

/*------------------------------Main------------------------------*/

inline int read(){
    int x =0; char c =getchar();
    while(c < '0' || c > '9') c =getchar();
    while(c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (48^c), c =getchar();
    return x;
}

ll dp[MAXN][MAXN];
int a[MAXN];

int main(){
    for(int t =0, T =read(); t < T; ++t){
        int n =read();
        for(int i =0; i < n; ++i)
            a[i] =read();
        for(int r =1; r < n; ++r){
            head =tail =0;
            /*mid: 转移方程分界点*/
            for(int l =r-1, mid =r-1; l >= 0; --l){
                while(mid >= l && dp[l][mid] >= dp[mid+1][r])
                    --mid;
                modifyfront(mid+1);
                dp[l][r] =min(dp[l][mid+1]+a[mid+1], front());
                if(l > 0)/*避免最后越界*/
                    push(pad(dp[l][r]+a[l-1], l-1));
            }
        }
        printf("%lld\n", dp[0][n-1]);
    }
}

 

不满足 a_1\leq a_2\leq\cdots\leq a_n 的做法的代码

/*不要求 ai 单调不减的版本*/
/*注意会炸空间*/
#include <cstdio>
#include <iostream>
#define ll long long
using std::pair;
using std::min;
typedef pair<ll, int> pad;

const int MAXN =7101;

/*------------------------------Monotone queue------------------------------*/

pad q1[MAXN], q2[MAXN][MAXN];
int head1, tail1, head2[MAXN], tail2[MAXN];

inline ll front(const int &head, const int &tail, const pad *q){
    if(head < tail)
        return q[head].first;
    else
        return 0x3f3f3f3f3f3f3f3f;
}

void modifyfront1(const int &nowtime){
    while(head1 < tail1 && q1[head1].second <= nowtime)
        ++head1;
}

void modifyfront2(const int &nowtime, int &head, const int &tail, const pad *q){
    while(head < tail && q[head].second >= nowtime)
        ++head;
}

void push(const pad &newelement, const int &head, int &tail, pad *q){
    while(head < tail && q[tail-1].first >= newelement.first)
        --tail;
    q[tail++] =newelement;
}

/*------------------------------Main------------------------------*/

inline int read(){
    int x =0; char c =getchar();
    while(c < '0' || c > '9') c =getchar();
    while(c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (48^c), c =getchar();
    return x;
}

ll dp[MAXN][MAXN];
int a[MAXN];

int main(){
    for(int t =0, T =read(); t < T; ++t){
        int n =read();
        for(int i =0; i < n; ++i)
            a[i] =read(), head2[i] =tail2[i] =0;
        for(int l =n-2; l >= 0; --l){
            head1 =tail1 =0;
            /*由于实现上跳过了 l == r 的部分,所以这里要对单调队列初始化*/
            /*初始化: [l+1, l+1] ( 这其实是在替左端点为 l+1 ( 上一次的 l ) 时的 l == r 的时刻初始化 )*/
            push(pad(0+a[l], l), head2[l+1], tail2[l+1], q2[l+1]);
            /*初始化: [l, l]*/
            push(pad(0+a[l], l), head1, tail1, q1);
            /*mid: 转移方程分界点*/
            /*这里的 mid 和题解略有不同,这里 [mid, r-1] 均取 dp[l][k],[l, mid-1] 均取 dp[k+1][r]*/
            for(int r =l+1, mid =l; r < n; ++r){
                while(mid < r && dp[l][mid] <= dp[mid+1][r])
                    ++mid;
                modifyfront1(mid-1);
                modifyfront2(mid, head2[r], tail2[r], q2[r]);
                dp[l][r] =min(front(head1, tail1, q1), front(head2[r], tail2[r], q2[r]));
                push(pad(dp[l][r]+a[r], r), head1, tail1, q1);
                if(l > 0)/*避免最后越界*/
                    push(pad(dp[l][r]+a[l-1], l-1), head2[r], tail2[r], q2[r]);
            }
        }
        printf("%lld\n", dp[0][n-1]);
    }
}