题解:P7294 [USACO21JAN] Minimum Cost Paths P

· · 题解

很显然的整体 dp 的形式。

我们设 dp_{i,j} 表示考虑了前 i 列,走到第 j 行的代价。

dp_{i,j}=\min(dp_{i-1,j}+j^2,dp_{i,j-1}+c_i)

因为贡献函数 j^2 是下凸函数,可以猜测,固定 idp_{i,j} 是一个下凸函数。也就是 dp_{i,j} 的差分数组单调不减。

每次从 i-1\to i,对于 dp_{i} 先是直接继承 dp_{i-1},然后对于每个位置打上一个 j^2 的标记。

现在考虑列内的转移。上面说了 dp_{i,j} 的差分数组单调递增,而行内观察形式其实就是差分数组对于 c_i 进行一个 chkmin。由于差分数组单调,于是我们直接二分找到变化点进行修改即可。

由于第二维是 O(n) 的,我们不能直接维护。考虑使用单调栈维护转折点。转折点之间都是一条一次函数。

直接的思路是维护一个差分数组,这样子非常方便更新。但是由于不是问 (1,1)\to (n,m) 的代价,而是中间有 q 个询问,我们不可能每次询问都累加一次差分数组吧。

于是考虑维护 f_i(x)=(i-j)\times x^2+c_j\times x-b 的形式。这个 (i-j)\times x^2 就是上面说的一个 +x^2 的 tag,其中 j 表示上次该点被 chkmin c_j 更新的时间点 j。由于上次是被 c_j 更新,所以就是形式就是 val_p+c_j\times (x-p),其中 p 表示前面一个断点,拆开就是 c_j\times x+b。综上所述,我们只需要在单调栈中对于每个断点维护一个三元组 (x,j,b) 即可。

每次修改由于我们不是直接维护的差分数组,所以可以直接用 f_i(x)-f_i(x-1) 得到差分值。对于后缀推平,我们可以直接从后缀开始不断 check 并且 pop,由于每列只会进队列 O(1) 个的点,所有点也只会被删一次,所以复杂度均摊是正确的。pop 到满足题意之后要二分寻找一下,然后插入新的端点。

每次询问就是离线处理,二分找到询问点在栈中哪两个端点之中,然后求值。

时间复杂度 O((m+q)\log n)

#include<bits/stdc++.h>
#define pb emplace_back
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef long long ll;
const int maxn=2e5+10;
void cmax(int &x,int y){ x=x>y?x:y; }
void cmin(int &x,int y){ x=x<y?x:y; }
ll V(int x){ return 1ll*x*x; } int c[maxn];
struct node{
    int x,tim; ll deta;
    bool operator < (const node &rhs) const { return mp(x,tim)<mp(rhs.x,rhs.tim);}
    ll get(int x,int t){ return V(x)*(t-tim)+1ll*x*c[tim]+deta; }
}st[maxn]; int top=0;
int n,m,Q; vector<pair<int,int> > q[maxn]; ll ans[maxn];
int find(int x){  return upper_bound(st+1,st+1+top,(node){x+1,0,0})-st-1; }
int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=m;i++) cin>>c[i];
    cin>>Q;
    for(int i=1;i<=Q;i++){
        int x,y; cin>>x>>y;
        if(n==1) cout<<y<<"\n";
        q[y].pb(x,i);
    }
    if(n==1) return 0;
    for(int i=1;i<=m;i++){
        if(i==1) st[++top]=(node){1,1,-c[1]};
        else{
            while(top){
                int x=st[top].x;
                ll v1=st[top].get(x,i);
                ll v2=st[top].get(x+1,i);
                if(v1+c[i]<v2){ top--; continue; }
                if(st[top].get(n-1,i)+c[i]>=st[top].get(n,i)) break;
                int l=x+1,r=n-1;
                while(l<r){
                    int mid=(l+r)>>1;
                    if(st[top].get(mid,i)+c[i]<st[top].get(mid+1,i)) r=mid;
                    else l=mid+1;
                }
                v1=st[top].get(l,i);
                st[++top]=(node){l,i,v1-1ll*l*c[i]};
                break;
            }
            if(!top) st[++top]=(node){1,i,i-1-c[i]};
        }
        for(auto z:q[i]) ans[z.se]=st[find(z.fi)].get(z.fi,i);
    }
    for(int i=1;i<=Q;i++) cout<<ans[i]<<"\n";
    return 0;
}