从二维偏序到三维偏序之cdq浅谈

· · 算法·理论

11.18 update 对文章漏洞频出而让管理审核多次深感抱歉,并将优化 DP 讲更详细了。

有时你会看到这种题。
怎么说,能想到思路吗?
其实这道题就在求一个东西:

二维偏序

啥!你不知道二维偏序。

二维偏序就是

任意两个点 (a_1,a_2)(b_1,b_2) ,存在这样的关系:

a_1\le b_1$ 且 $a_2\le b_2

我们可以称为 (a_1,a_2)\le(b_1,b_2)

这与题目不是挺符合的吗?
当我们把一个数看成 (i,a_i) 的一个点对,get到了吧 。
这时问题就成了询问比 (i,k) 小的点对的个数。

做法:先将询问区间按左右端点存入,然后再依次将 a_i 存在树状数组中,并在此过程中处理左右端点在此位置的区间。

我们不妨假设我们面前有一个数轴,那么,此时我们的操作可以看做这样:

众所周知,当我们把树状数组视为维护一个数轴时,我们可以利用树状数组求出小于 x 的数的个数。

那我们操作为 (l,r,k) 时,我们要先将 l 之前的数都存进去,查询一次得到在 [1,l-1] 中小于等于 k 的数的个数。
然后再将 r 之前的数存进去,再查询一次,得到 [1,r] 中小于等于 k 的数的个数。
显然,我们现在做个差分,就可以得到 [l,r] 中小于等于 k 的数的个数。

当然,有些时候数据会很大,这时候离散化一下就行了。

多说无益,上代码。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e6+4,mod=1e9+7;
int n,m,a[N],ans[N];
struct BIT{//简单的树状数组模板,应该都会吧
    int c[N];
    inline int lowbit(int x){return x&(-x);}
    void ad(int x,int k){
        while(x<=N)c[x]+=k,x+=lowbit(x);
    }
    int sum(int x){
        int res=0;
        while(x)res+=c[x],x-=lowbit(x);
        return res;
    }
}tr;
struct node{
    int x,d,id;
};
vector<node>v[N];
int main(){
    ios::sync_with_stdio(0);cin.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>a[i];
    for(int i=1;i<=m;i++){
        int l,r,k;
        cin>>l>>r>>k;
        v[l-1].push_back({k,-1,i});
        v[r].push_back({k,1,i});
    //与差分思想差不多,把区间分为两次查询
    //巧妙运用 -1 来差分
    }
    for(int i=1;i<=n;i++){
        tr.ad(a[i],1);//存入数组
        for(int j=0;j<v[i].size();j++){
            ans[v[i][j].id]+=v[i][j].d*tr.sum(v[i][j].x);
      //计算区间查询的结果
        }
    }
    for(int i=1;i<=m;i++)cout<<ans[i]<<"\n";
    return 0;
}

最终,我们可以发现每个区间只被访问 2 次,推算时间复杂度为 O(\max(n,m)\times \log(\max\{a_i\})),可以同过此题。

完结撒花

好了,想必大家都会二维偏序了吧~

那么我们再来看看这道题。

可以发现这是二维偏序的升级版,三维偏序

来,让我们一起想想怎么做。

首先,可以想到,我们可以通过按 a 排序来消掉一维。

那么我们肯定想要再消一维,这样我们就可以用树状数组了。

可是,如果再按照 b 排序,会把 a 给打乱了,我们就前功尽弃了,怎么办呢?
有没有在排序时可以不打乱 a 的排序方法呢?

还没想到吗? 很显然,的确有这么一种排序,它就是归并排序。

我们想,在归排时,我们是把数组分成两部分,因为这些数据已经被我们按 a 排序过了,所以,前面那一部分的 a 是绝对不会比后面的大的,此时我们再按 b 排,我们就可以得到一个完美的数组,此时我们再用两个指针去分别查询两部分,保持第二个指针及其后面的数的 b 绝对不小于第一个指针及其前面的数的 b,再用树状数组处理 c

这不就万事大吉了吗?

上代码。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int mod=1e9+7,N=2e5+4;
struct BIT{
    int n;
    vector<ll>c;
    BIT(int _n):n(_n),c(_n+2){}
    inline int lowbit(int x){return x&(-x);}
    void add(int x,ll k){
        while(x<=n)c[x]+=k,x+=lowbit(x);
    }
    ll sum(int x){
        ll res=0;
        while(x)res+=c[x],x-=lowbit(x);
        return res;
    }
};
BIT B(N);
struct node{int a,b,c,ans,cnt;}x[N],y[N];
int n,k,f[N],len;
bool cmp1(node x,node y){
    return x.a==y.a?(x.b==y.b?x.c<y.c:x.b<y.b):x.a<y.a;
//方便后面离散化
}
bool cmp2(node x,node y){
    return x.b<y.b;
}
void cdq(int l,int r){
    if(l==r)return;
    int mid=(l+r)/2;
    cdq(l,mid),cdq(mid+1,r);//先在小的范围内处理
    sort(x+l,x+mid+1,cmp2);//直接排序,少打几行代码
    sort(x+mid+1,x+r+1,cmp2);
    int i=l,j=mid+1;
    for(;j<=r;j++){
        while(x[i].b<=x[j].b&&i<=mid)B.add(x[i].c,x[i].cnt),i++;
    //将前部分的b小于当前位置的b的存入树状数组
        x[j].ans+=B.sum(x[j].c);
    }
    for(int t=l;t<i;t++)B.add(x[t].c,-x[t].cnt);//记得清空树状数组
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0);
    cin>>n>>k;
    for(int i=1;i<=n;i++)cin>>y[i].a>>y[i].b>>y[i].c;
    sort(y+1,y+1+n,cmp1);//处理a
    for(int i=1,w=0;i<=n;i++){//数据比较大,需要离散化
        w++;
        if(y[i].a!=y[i+1].a||y[i].b!=y[i+1].b||y[i].c!=y[i+1].c)x[++len]=y[i],x[len].cnt=w,w=0;//沿用上一题思路处理区间
    }
    cdq(1,len);
    for(int i=1;i<=len;i++)f[x[i].cnt+x[i].ans-1]+=x[i].cnt;//与它相同的也算
    for(int i=0;i<n;i++)cout<<f[i]<<"\n";
    return 0;
}

我们惊奇地发现,这就是 cdq 分治。

总而言之,cdq 分治就是一种神奇的思想。

好的,如果你到这里都听懂了,恭喜你又学会了一个新思想。

奖励你们一个简单的题目

咋一看,是不是二维偏序?
诶不对,它还有操作时间的限制。
欸,真头大...

不过, 其实我们可以将操作时间也看成一维,这样不就是三维偏序了吗!
改改就行了。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
using VI=vector<int>;
using PI=pair<int,int>;
const int maxn=2e6+5;
ll w,len,ans[maxn];
struct BIT{
    ll c[maxn];
    inline int lowbit(int x){return x&(-x);}
    void add(int x,ll k){
        for(;x<=maxn-5;x+=lowbit(x))c[x]+=k;
    }
    ll ask(int x){
        ll res=0;
        for(;x;x-=lowbit(x))res+=c[x];
        return res;
    }
    void clear(int x){
        if(x<=0)return;
        for(;x<=maxn-5;x+=lowbit(x))c[x]=0;
    }
}tree;
struct node{int x,y,tot,o;ll cnt;}a[maxn],t[maxn];
void cdq(int l,int r){
    if(l>=r)return;
    int mid=(l+r)>>1;
    cdq(l,mid);cdq(mid+1,r);
    int i=l,j=mid+1,le=l;
    while(i<=mid&&j<=r){
        if(a[i].x<=a[j].x){
            if(a[i].o==1)tree.add(a[i].y,a[i].cnt);
            t[le]=a[i];le++;i++;
        }
        else{
            if(a[j].o==2)ans[a[j].cnt]+=tree.ask(a[j].y);
            else if(a[j].o==3)ans[a[j].cnt]-=tree.ask(a[j].y);
            t[le]=a[j];j++;le++;
        }
    }
    while(i<=mid){
        if(a[i].o==1)tree.add(a[i].y,a[i].cnt);
        t[le]=a[i];le++;i++;
    }
    while(j<=r){
        if(a[j].o==2)ans[a[j].cnt]+=tree.ask(a[j].y);
        else if(a[j].o==3)ans[a[j].cnt]-=tree.ask(a[j].y);
        t[le]=a[j];le++;j++;
    }
    for(int k=l;k<=r;k++)a[k]=t[k];
    for(int k=l;k<=r;k++)tree.clear(a[k].y);
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    int o,sum=0;cin>>o>>w;
    while(1){
        cin>>o;if(o==3)break;
        ll x,y,xx,yy;
        if(o==1){
            cin>>x>>y>>xx;
            a[++len].o=1;a[len].tot=len;
            a[len].x=x;a[len].y=y;
            a[len].cnt=xx;
        }
        else{//查询操作拆分成四个前缀和 
            cin>>x>>y>>xx>>yy;sum++;
            a[++len].o=2;a[len].x=xx;a[len].y=yy;a[len].tot=len;a[len].cnt=sum;
            a[++len].o=3;a[len].x=x-1;a[len].y=yy;a[len].tot=len;a[len].cnt=sum;
            a[++len].o=3;a[len].x=xx;a[len].y=y-1;a[len].tot=len;a[len].cnt=sum;
            a[++len].o=2;a[len].x=x-1;a[len].y=y-1;a[len].tot=len;a[len].cnt=sum;
        }
    }
    cdq(1,len);
    for(int i=1;i<=sum;i++)cout<<ans[i]<<"\n";
    return 0;
}

CDQ 优化 DP

CDQ 分治是一种用于解决高维偏序问题或优化动态规划的算法思想,特别擅长处理带时间轴的多维偏序问题。其核心是通过分治降维,将动态问题转化为静态问题处理。

当 DP 式满足以下条件时可以用 cdq 优化。

此有一图。

关键操作与三维偏序差不多,将第二维排序后,用树状数组计算第三维贡献,并合并结果。

此有一例题。

别看这是个紫,你要有信心 A 了他的好吧。

我们可以发现每一次操作一定是全部霍霍完才最优。

所以可以简单得到 DP 递推式

f_i=\max(f_{i-1},\max\{x_j\times a_i,y_j\times b_i\})

其中,x_i=\frac{f_i\times r_i}{(a_i\times r_i+b_i)}y_i=\frac {f_i}{(a_i\times r_i+b_i)}

转成斜率形式后发现并不单调。

\frac{y_j-y_k}{x_j-x_k}>-\frac{a_i}{b_i}

那怎么做?

首先,我们可以将每一天的信息存储下来,按 -\frac{a_i}{b_i} 排序,保证凸壳有序。

然后将区间分成两块,并按照天数分组,保证左边的天数比右边小,处理 dp 数组也保证先处理时间靠前的。

这时,我们就可以用 cdq 分治来处理,计算前半部分对后半部分的贡献。

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=1e5+4;
const double eqs=1e-9,inf=1e9;
struct node{
    int p;double x,y;
}q[N],tmp[N];
double f[N],a[N],b[N],R[N];//f为第i天的价值 
int s[N];
bool cmp(node x,node y){return (a[x.p]/b[x.p])<(a[y.p]/b[y.p]);} 
double sl(int x,int y){//斜率 
    if(q[x].x==q[y].x)return inf;
    return (q[y].y-q[x].y)/(q[y].x-q[x].x);
}
void cdq(int l,int r){
    if(l==r){
        f[l]=max(f[l],f[l-1]);
        q[l].x=f[l]/(a[l]*R[l]+b[l])*R[l];
        q[l].y=f[l]/(a[l]*R[l]+b[l]);
        return;
    }
    int mid=l+r>>1,lp=l,rp=mid+1,tp=1;
    for(int i=l;i<=r;i++)//按天数分组,保证用前去算后 
        if(q[i].p<=mid)tmp[lp++]=q[i];
        else tmp[rp++]=q[i];
    for(int i=l;i<=r;i++)q[i]=tmp[i];
    cdq(l,mid);//先处理左区间
    int t=0,p=1;
    for(int i=l;i<=mid;i++){
        while(t>1&&sl(s[t],i)>sl(s[t-1],s[t]))--t;
        s[++t]=i;
    }
    for(int i=mid+1;i<=r;i++){//计算左对右的影响 
        while(p<t&&sl(s[p],s[p+1])>-a[q[i].p]/b[q[i].p])++p;//找到第一个满足条件的点,这是最优转移点
        f[q[i].p]=max(f[q[i].p],q[s[p]].x*a[q[i].p]+q[s[p]].y*b[q[i].p]);
    }
    cdq(mid+1,r);
    lp=l;rp=mid+1;tp=l;
    while(lp<=mid&&rp<=r)//归并 
        if(q[lp].x<q[rp].x)tmp[tp++]=q[lp++];
        else tmp[tp++]=q[rp++];
    while(lp<=mid)tmp[tp++]=q[lp++];
    while(rp<=r)tmp[tp++]=q[rp++];
    for(int i=l;i<=r;i++)q[i]=tmp[i];
}
int main(){
    ios::sync_with_stdio(false),cin.tie(0);
    int n;double s;
    cin>>n>>s;
    for(int i=1;i<=n;i++){
        cin>>a[i]>>b[i]>>R[i];
        q[i].p=i;
        q[i].x=s/(a[q[i].p]*R[q[i].p]+b[q[i].p])*R[q[i].p];
        q[i].y=s/(a[q[i].p]*R[q[i].p]+b[q[i].p]);
        f[i]=s;//初始化 
    }
    sort(q+1,q+1+n,cmp);
    cdq(1,n);
    printf("%.3lf",f[n]);
    return 0;
}

如果你没听懂?
那实在不好意思,实力有限,建议 read 这个。