P9403 [POI 2020/2021 R3] Les Bitérables

· · 题解

题意简明,不再阐述。

首先可以对当前两行(假设为第 ii+1 行)的情况分类。

此时可以分为三种情况。

一种是从 0 处调 x 件物品(0\leq x\leq s_{i+1}),这 x 件物品显然对应第 i+1p_1,p_2,...,p_x,然后第 i 行的后 s_i+x-s_{i+1} 个物品去到 d

另一种是从 d 处调 x 件物品(0\leq x \leq s_{i+1}),这 x 件物品显然对应第 i+1p_{s_{i+1}-x+1},p_{s_{i+1}-x+2},...,p_{s_{i+1}},然后第 i 行的前 s_i+x-s_{i+1} 个物品去到 0。

还有一种是从 0 处调 x 件物品,从 d 处调 y 件物品(0\leq x,0\leq y,x+y+s_i=s_{i+1}),分别对应 i+1p_1,p_2,...,p_xp_{s_{i+1}-y+1},p_{s_{i+1}-y+2},p_{s_{i+1}},第 i 行的物品则对应第 i+1 行的 p_{x+1},p_{x+2},...,p_{s_{i+1}-y}

发现最小代价无论在哪种情况中取到其代价随 x 变化所形成的图像都是单谷的(第三种情况看作 y=s_{i+1}-s_{i}-x),于是可以三分,对于每一种情况都三分 x 找到最小代价最后三种情况取最小即可。

此时也是分三种情况,前两种情况与 s_i\leq s_{i+1} 时的前两种情况相同,第三种情况变为调前 x 个元素到 0,调后 y 个元素到 d0<x,0<y,s_i-x-y=s_{i+1})。

最小代价无论在哪种情况中取到其代价随 x 变化所形成的图像也是单谷的,同上三分求最小代价即可。

时间复杂度 O(n\log_3n)

#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define int long long
#define db double
#define endl '\n'
#define lowbit(x) x&-x
#define intz(x,a) memset(x,a,sizeof(x))
const int N=5e5+5; 
int s[N];vector<int>p[N];
signed main(){int n,d;cin>>n>>d;
    for(int i=1;i<=n;i++){cin>>s[i];p[i].resize(s[i]+5);
        for(int j=1;j<=s[i];j++)cin>>p[i][j];
    }
    for(int i=1;i<n;i++){
        if(s[i]<=s[i+1]){int l=0,r=s[i+1],ans=(1ll<<63)-1;
            while(l<=r){int mid0=l+(r-l)/3,mid1=r-(r-l)/3,sum0=0,sum1=0;
                for(int j=1;j<=s[i+1];j++)
                    if(j<=mid0)sum0+=p[i+1][j];
                    else if(j<=mid0+s[i])sum0+=abs(p[i+1][j]-p[i][j-mid0]);
                    else sum0+=d-p[i+1][j];
                for(int j=s[i+1]-mid0+1;j<=s[i];j++)sum0+=min(p[i][j],d-p[i][j]);
                for(int j=1;j<=s[i+1];j++)
                    if(j<=mid1)sum1+=p[i+1][j];
                    else if(j<=mid1+s[i])sum1+=abs(p[i+1][j]-p[i][j-mid1]);
                    else sum1+=d-p[i+1][j];
                for(int j=s[i+1]-mid1+1;j<=s[i];j++)sum1+=min(p[i][j],d-p[i][j]);
                if(sum0>=sum1)l=mid0+1,ans=min(ans,sum1);else r=mid1-1,ans=min(ans,sum0);
            }l=0,r=s[i+1];
            while(l<=r){int mid0=l+(r-l)/3,mid1=r-(r-l)/3,sum0=0,sum1=0;
                for(int j=s[i+1];j;j--)
                    if(j>=s[i+1]-mid0+1)sum0+=d-p[i+1][j];
                    else if(j>=s[i+1]-mid0-s[i]+1)sum0+=abs(p[i+1][j]-p[i][j-s[i+1]+mid0+s[i]]);
                    else sum0+=p[i+1][j];
                for(int j=1;j<=-s[i+1]+mid0+s[i];j++)sum0+=min(p[i][j],d-p[i][j]);
                for(int j=s[i+1];j;j--)
                    if(j>=s[i+1]-mid1+1)sum1+=d-p[i+1][j];
                    else if(j>=s[i+1]-mid1-s[i]+1)sum1+=abs(p[i+1][j]-p[i][j-s[i+1]+mid1+s[i]]);
                    else sum1+=p[i+1][j];
                for(int j=1;j<=-s[i+1]+mid1+s[i];j++)sum1+=min(p[i][j],d-p[i][j]); 
                if(sum0>=sum1)l=mid0+1,ans=min(ans,sum1);else r=mid1-1,ans=min(ans,sum0);
            }cout<<ans<<endl;
        }else{int l=0,r=s[i],ans=(1ll<<63)-1;
            while(l<=r){int mid0=l+(r-l)/3,mid1=r-(r-l)/3,sum0=0,sum1=0;
                for(int j=1;j<=s[i];j++)
                    if(j<=mid0)sum0+=p[i][j];
                    else if(j<=mid0+s[i+1])sum0+=abs(p[i][j]-p[i+1][j-mid0]);
                    else sum0+=d-p[i][j];
                for(int j=s[i]-mid0+1;j<=s[i+1];j++)sum0+=min(p[i+1][j],d-p[i+1][j]);
                for(int j=1;j<=s[i];j++)
                    if(j<=mid1)sum1+=p[i][j];
                    else if(j<=mid1+s[i+1])sum1+=abs(p[i][j]-p[i+1][j-mid1]);
                    else sum1+=d-p[i][j];
                for(int j=s[i]-mid1+1;j<=s[i+1];j++)sum1+=min(p[i+1][j],d-p[i+1][j]);
                if(sum0>=sum1)l=mid0+1,ans=min(ans,sum1);else r=mid1-1,ans=min(ans,sum0);
            }l=0,r=s[i];
            while(l<=r){int mid0=l+(r-l)/3,mid1=r-(r-l)/3,sum0=0,sum1=0;
                for(int j=s[i];j;j--)
                    if(j>=s[i]-mid0+1)sum0+=d-p[i][j];
                    else if(j>=s[i]-mid0-s[i+1]+1)sum0+=abs(p[i][j]-p[i+1][j-s[i]+mid0+s[i+1]]);
                    else sum0+=d-p[i][j];
                for(int j=1;j<=-s[i]+mid0+s[i+1];j++)sum0+=min(p[i+1][j],d-p[i+1][j]);
                for(int j=s[i];j;j--)
                    if(j>=s[i]-mid1+1)sum1+=d-p[i][j];
                    else if(j>=s[i]-mid1-s[i+1]+1)sum1+=abs(p[i][j]-p[i+1][j-s[i]+mid1+s[i+1]]);
                    else sum1+=d-p[i][j];
                for(int j=1;j<=-s[i]+mid1+s[i+1];j++)sum1+=min(p[i+1][j],d-p[i+1][j]);
                if(sum0>=sum1)l=mid0+1,ans=min(ans,sum1);else r=mid1-1,ans=min(ans,sum0);
            }
            cout<<ans<<endl;
        }
    }
    return 0;
}