整体DP—从勤拿少取到一气呵成

· · 算法·理论

可能更好的阅读体验

0. 前言

这可能是全网比较全面的整体 DP 思想叙述?话说回来到底什么是整体 DP 啊嘞 (。ŏ_ŏ)?

你在做 DP 题的时候。如果熟练的话,常常会写出正确的 DP 方程,却因为状态太多或过多的重复转移而超时。而整体 DP 的思想就是一种化繁为简的思路,我们把原本需要逐个处理的状态或子问题,放进一个整体里批量处理。本蒟蒻遇到的情况大致可以分为以下两类:

上述的所有过程,都体现整体 DP 的思想中的核心——整体:用一次整体操作代替多次操作,避免重复计算

想这些概括的词很累的呜呜呜

1. 转移整体化

1.1 概述

转移整体化是什么呢?你看上面不说人话的大范围概括,其实就是通过类似数据结构维护序列的方式将 DP 状态中的一维压入数据结构,并通过批量操作(整体修改、整体查询)优化。其中最常见的就是线段树合并优化转移。

同时不难发现,我们由于是要求对整体进行转移,我们需要的是这个转移具有一定的固定性,即有大量相同或者说相似转移的 dp,这样才可以优化。若某一维(通常是后一维)的转移具有较强的共性时,可以考虑利用整体 dp 优化。

一般的,使用整体 dp 的题目有以下几个步骤:

  1. 写出朴素 DP;
  2. 发现朴素 DP 转移具有大量重复同样的操作,或将朴素 DP 通过前缀和等方式将转移具有一定的固定性。
  3. 出 dp 状态中转移具有共性的一维,使用数据结构维护这一维。具体地,随着其它维的变化,在数据结构上执行各种修改操作,动态维护此时此刻,当压进数据结构一维的下标为某个值时的 dp 值。

接下来我们将会以例题来详细解释这种整体转移的优化。

1.2 例题

线段树维护整体位移—P9400

显然的 DP,设 f(i,j) 表示考虑到前 i 个数,最后的 j 个大于 b 的方案数,有转移:

\begin{aligned} f(i,j)&=f(i-1,j-1)\cdot v_{1} & j>0 \\ f(i,0) & = \sum\limits_{j=0}^{a_{i}-1} f(i-1,j)\cdot v_{2} \end{aligned}

其中 v_{1}=\max(0,r_{i}-\max\{l_{i}-1,b\})v_{2}=(r_{i}-l_{i}+1)-\max(0,r_{i}-\max\{l_{i}-1,b\})。时间复杂度 O(n^2),考虑优化,注意到第二个转移可以通过前缀和 O(1) 求出,那么不难发现所有转移都由上一层转移过来,并且所有转移 O(1) 进行。考虑整体 DP,对于第一个方程可以看作整体向后移动一次然后整体乘上 v_{1},对于第二个转移就是更新以前的值乘上权值单点更新。

故我们需要维护一个支持单点插、单点删、区间位移、区间乘、区间求和的数据结构,用 FHQ-Treap 简单可以做到,但是有更简单的做法,由于序列长度始终为 a,且位移最多 n 次,因此我们用线段树维护一个长度为 n+a 的序列,类似于滑动窗口一样的维护当前的有效区间,每次操作先算出新的 f(i,0),然后移动指针整体向左移动,将左侧新增的位置设置为 f(i,0),具体实现见代码,时间复杂度 O(n\log n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=5e5+15,MOD=998244353;
int n,a,b,L[MN],R[MN];

struct Segment{
    #define ls p<<1
    #define rs p<<1|1
    struct Node{
        int l,r,val,tag;
    }t[MN<<2];

    void pushup(int p){
        t[p].val=(t[ls].val+t[rs].val)%MOD;
    }

    void domul(int p,int k){
        t[p].val=t[p].val*k%MOD;
        t[p].tag=t[p].tag*k%MOD;
    }

    void pushdown(int p){
        if(t[p].tag!=1){
            domul(ls,t[p].tag);
            domul(rs,t[p].tag);
            t[p].tag=1;
        }
    }

    void build(int p,int l,int r){
        t[p].l=l;
        t[p].r=r;
        t[p].tag=1;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }

    void modify(int p,int fl,int fr,int k){
        if(t[p].l>=fl&&t[p].r<=fr){
            domul(p,k);
            return;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        if(mid>=fl) modify(ls,fl,fr,k);
        if(mid<fr) modify(rs,fl,fr,k);
        pushup(p);
    }

    void change(int p,int pos,int k){
        if(t[p].l==t[p].r){
            t[p].val=k;
            return;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        if(mid>=pos) change(ls,pos,k);
        else change(rs,pos,k);
        pushup(p);
    }

    int query(int p,int fl,int fr){
        if(t[p].l>=fl&&t[p].r<=fr){
            return t[p].val;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1,ret=0;
        if(mid>=fl) (ret+=query(ls,fl,fr))%=MOD;
        if(mid<fr) (ret+=query(rs,fl,fr))%=MOD;
        return ret;
    }
}sg;

signed main(){
    cin>>n>>a>>b;
    for(int i=1;i<=n;i++){
        cin>>L[i]>>R[i];
    }
    sg.build(1,1,n+a);
    int ql=n+1,qr=n+a;
    sg.change(1,n+1,1);
    for(int i=1;i<=n;i++){
        int sum=sg.query(1,ql,qr);
        ql--,qr--;
        int val=max(0ll,R[i]-max(b,L[i]-1));
        if(ql+1<=qr){
            sg.modify(1,ql+1,qr,val);
        }
        val=max(0ll,min(b,R[i])-L[i]+1)*sum%MOD;
        sg.change(1,ql,val);
    }
    cout<<sg.query(1,ql,qr)<<'\n';
    return 0;
}

线段树维护复杂转移—P8476

显然有一个 O(nV^2) 的 DP 就是设 f(i,j) 表示前 i 个数最后一个 b_{i-1} 的值为 j 的最小答案。显然有转移:

f(i,j)=\min\limits_{k=j}^n f(i-1,k)+w(i,j)

其中 w(i,j) 表示将 a_{i} 改为 j 的代价。

注意到 V=10^9 很难泵,不过我们可以通过离散化 a 将时间复杂度做到 O(n^3)。考虑优化,注意到这个 \min 操作是一个后缀 \min,而且转移都是逐层转移的。考虑整体 DP,可以用线段树简单维护这个转移因为没有位移操作,但注意到 w 是一个分段函数,我们考虑 a_{i}=j 的位置作为分界点,对于这个位置之前的所有下标操作代价都是 C,区间价即可。

而对于后面的所有下标 x,每个位置需要加上 b_{x}-a_{i},首先 -a_{i} 的部分可以区间加。现在问题在于如何快速处理 b_{x} 的操作,注意到后缀 \min 会导致 DP 值单调递增,所以修改后的最小值一定取在区间的最左端点处。可以直接打 Tag 即可。

现在问题还有一个后缀 \min,不难发现,修改后 j<a_{i} 和 j>a_{i} 的位置分别单调递增,于是直接找到 j\ge a_{i} 部分的最小值,并在 j<a_{i} 二分找到大于右侧最小值的部分,区间赋值抹平即可。时间复杂度 O(n\log n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=2e5+15;
int n,tot,C,a[MN],b[MN];

struct Segment{
    #define ls p<<1
    #define rs p<<1|1
    struct Node{
        int l,r,mn,mx,cov,add1,add2;
    }t[MN<<2];

    void pushup(int p){
        t[p].mn=min(t[ls].mn,t[rs].mn);
        t[p].mx=max(t[ls].mx,t[rs].mx);
    }

    void docov(int p,int k){
        t[p].cov=k;
        t[p].mn=t[p].mx=k;
        t[p].add1=t[p].add2=0;
    }

    void doadd1(int p,int k){
        t[p].add1+=k;
        t[p].mn+=1ll*k*b[t[p].l];
        t[p].mx+=1ll*k*b[t[p].r];
    }

    void doadd2(int p,int k){
        t[p].add2+=k;
        t[p].mn+=k;
        t[p].mx+=k;
    }

    void pushdown(int p){
        if(t[p].l==t[p].r) return;
        if(~t[p].cov){
            docov(ls,t[p].cov);
            docov(rs,t[p].cov);
            t[p].cov=-1;
        }
        if(t[p].add1){
            doadd1(ls,t[p].add1);
            doadd1(rs,t[p].add1);
            t[p].add1=0;
        }
        if(t[p].add2){
            doadd2(ls,t[p].add2);
            doadd2(rs,t[p].add2);
            t[p].add2=0;
        }
    }

    void build(int p,int l,int r){
        t[p].l=l;
        t[p].r=r;
        t[p].cov=-1;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }

    void cover(int p,int fl,int fr,int k){
        if(t[p].l>=fl&&t[p].r<=fr){
            docov(p,k);
            return;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        if(mid>=fl) cover(ls,fl,fr,k); 
        if(mid<fr) cover(rs,fl,fr,k);
        pushup(p);
    }

    void add1(int p,int fl,int fr,int k){
        if(t[p].l>=fl&&t[p].r<=fr){
            doadd1(p,k);
            return;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        if(mid>=fl) add1(ls,fl,fr,k);
        if(mid<fr) add1(rs,fl,fr,k);
        pushup(p);
    }

    void add2(int p,int fl,int fr,int k){
        if(t[p].l>=fl&&t[p].r<=fr){
            doadd2(p,k);
            return;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        if(mid>=fl) add2(ls,fl,fr,k);
        if(mid<fr) add2(rs,fl,fr,k);
        pushup(p);
    }

    int querymn(int p,int fl,int fr){
        if(t[p].l>=fl&&t[p].r<=fr){
            return t[p].mn;
        }
        pushdown(p);
        int mid=(t[p].l+t[p].r)>>1;
        int ret=1e18;
        if(mid>=fl) ret=min(ret,querymn(ls,fl,fr));
        if(mid<fr) ret=min(ret,querymn(rs,fl,fr));
        return ret;
    }

    int binary(int p,int k){
        if(t[p].mx<k) return t[p].r+1;
        if(t[p].l==t[p].r) return t[p].l;
        pushdown(p);
        if(t[ls].mx>=k) return binary(ls,k);
        else return binary(rs,k);
    }
}sg;

signed main(){
    cin>>n>>C;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        b[i]=a[i];
    }
    sort(b+1,b+1+n);
    tot=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;i++){
        a[i]=lower_bound(b+1,b+tot+1,a[i])-b;
    }
    sg.build(1,1,tot);
    for(int i=1;i<=n;i++){
        if(a[i]!=1){
            sg.add2(1,1,a[i]-1,C); 
        }
        sg.add2(1,a[i],tot,-b[a[i]]);
        sg.add1(1,a[i],tot,1); 
        int rm=sg.querymn(1,a[i],tot); 
        int pos=sg.binary(1,rm);
        if(pos<a[i]) sg.cover(1,pos,a[i]-1,rm);
    }
    cout<<sg.t[1].mn<<'\n';
    return 0;
}

线段树合并维护—P6733

首先考虑发掘性质,题目叽里咕噜看不懂,但是有几个性质挺好:

第二个性质是第一个性质的推论,其本质就点明了子问题的设计,即限制影响的设计。

两个性质启示我们 DP 状态的设计应当包含这个限制影响范围。设 f(u,i) 表示 u 子树内限制祖先深度最远到了深度为 j 的祖先,其他我们都保证合法的方案数。

转移考虑一个一个子树合并,利用性质 2 我们可以枚举边权设置为为 0/1 转移。

时间复杂度为 O(n^2),但是发现转移是一个区间形式的转移,考虑用线段树优化这一过程,第一个式子是全局乘,第二个式子是前后缀乘法,可以利用线段树合并动态维护值优化这一过程。时间复杂度 O(n\log n)

在一些和祖先带相关限制的树形 DP 中,我们会遇到一种现象就是子树内部能解决一部分限制,但有些限制不能在当前子树内解决,只能依赖于祖先去兜底。所以 DP 状态不能只描述当前子树内部已经解决的情况,还必须记录子树内尚未解决、但需要祖先去兜底的残余需求。在下一道例题中我们会再次叙述。

如果你还需要一些线段树合并维护转移的题目,我们还有更厉害的:P5298 [PKUWC2018] Minimax

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=2e6+15,MOD=998244353;
int n,m,mx[MN],rt[MN];
vector<int> adj[MN];

struct Segment{
    #define ls t[p].lson
    #define rs t[p].rson
    struct Node{
        int lson,rson,val,tag=1;
    }t[MN<<3];
    int tot,tmp;

    void pushup(int p){
        t[p].val=(t[ls].val+t[rs].val)%MOD;
    }

    void domul(int p,int k){
        t[p].val=t[p].val*k%MOD;
        t[p].tag=t[p].tag*k%MOD;
    }

    void pushdown(int p){
        if(t[p].tag!=1){
            domul(ls,t[p].tag);
            domul(rs,t[p].tag);
            t[p].tag=1;
        }
    }

    void update(int &p,int l,int r,int pos,int k){
        if(!p) p=++tot;
        if(l==r){
            t[p].val=k;
            return;
        }
        pushdown(p);
        int mid=(l+r)>>1;
        if(mid>=pos) update(ls,l,mid,pos,k);
        else update(rs,mid+1,r,pos,k);
        pushup(p);
    }

    int query(int p,int l,int r,int fl,int fr){
        if(l>=fl&&r<=fr){
            return t[p].val;
        }
        pushdown(p);
        int mid=(l+r)>>1,ret=0;
        if(mid>=fl) (ret+=query(ls,l,mid,fl,fr))%=MOD;
        if(mid<fr) (ret+=query(rs,mid+1,r,fl,fr))%=MOD;
        return ret;
    }

    int merge(int x,int y,int l,int r,int s1,int s2){
        if(!x&&!y) return 0;
        if(!x){
            domul(y,s2);
            return y;
        }
        if(!y){
            domul(x,s1+tmp);
            return x;
        }
        if(l==r){
            t[x].val=(t[x].val*(s1+t[y].val+tmp)%MOD+t[y].val*s2)%MOD;
            return x;
        }
        int mid=(l+r)>>1;
        pushdown(x);
        pushdown(y);
        t[x].rson=merge(t[x].rson,t[y].rson,mid+1,r,(s1+t[t[y].lson].val)%MOD,(s2+t[t[x].lson].val)%MOD);
        t[x].lson=merge(t[x].lson,t[y].lson,l,mid,s1,s2);
        pushup(x);
        return x;
    }

}sg;

namespace Tree{
    int dep[MN];

    void dfs1(int u,int pre){
        dep[u]=dep[pre]+1;
        for(auto v:adj[u]){
            if(v==pre) continue;
            dfs1(v,u);
        }
    }

    void dfs2(int u,int pre){
        sg.update(rt[u],0,n,mx[u],1);
        for(auto v:adj[u]){
            if(v==pre) continue;
            dfs2(v,u);
            sg.tmp=sg.query(rt[v],0,n,0,dep[u]);
            rt[u]=sg.merge(rt[u],rt[v],0,n,0,0);
        }
    }

}using namespace Tree;

signed main(){
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs1(1,0);
    cin>>m;
    for(int i=1;i<=m;i++){
        int u,v;
        cin>>u>>v;
        mx[v]=max(mx[v],dep[u]);
    }
    dfs2(1,0);
    cout<<sg.query(rt[1],0,n,0,0);

    return 0;
}

用堆维护—CF671D

祖先限制,考虑 DP,设 f(u,i) 表示 u 子树内全部覆盖,其限制向上延伸到了深度为 j 的祖先的最小花费,其中 j 满足 j\le dep_{i}。那么有转移方程:

f(u,i)=\min\limits_{v\in son(u)} \{f(u,i)+\min\limits_{j=0}^{dep_{i}}f(v,j),f(v,i)+\min\limits_{j=0}^{dep_{i}}f(u,j)\}

即我们考虑合并子节点的答案,答案合并的时候可能从 f(u) 或者 f(v) 一个产生。我们不难发现有两个前缀 \min 的操作,并且操作为子树合并取 \min 操作,考虑线段树合并,对合并时可能产生的情况进行分类讨论即可,时间复杂度 O(n\log n),但是显然有点过于难了不是吗?

我们还有跟简单的方法,注意到我们操作每次取操作的时前缀 \min,而且在子树合并过程中 dep_{i} 单调不升,我们可以考虑用堆来维护这个操作,对于 j>dep_{i} 的操作我们考虑懒惰删除(我们用的时候在排除不合法状态)。因此我们在每个节点上维护一个堆,堆里装所有的第二维状态和值即可。使用左偏树可以做到 O(n\log n),使用堆或 set 加启发式合并可以做到 O(n\log^2 n)

#include<bits/stdc++.h>
#define int long long
#define pir pair<int,int>
using namespace std;
constexpr int MN=3e5+15;
struct Node{
    int j,cst;

    bool operator<(const Node &x)const{
        return cst<x.cst;
    }

};
int n,m,dep[MN],ans,tag[MN];
bool flag=1;
vector<int> adj[MN];
vector<pir> path[MN];
multiset<Node> st[MN];

void merge(int x,int y){
    if(st[x].size()<st[y].size()) swap(st[x],st[y]),swap(tag[x],tag[y]); 
    int mnx=(!st[x].size())?0:(*st[x].begin()).cst,mny=(!st[y].size())?0:(*st[y].begin()).cst;
    while(!st[y].empty()){
        auto tp=(*st[y].begin());
        st[y].erase(st[y].begin());
        tp.cst+=mnx-mny;
        st[x].insert(tp);
    }
    tag[x]+=mny+tag[y];
}

void dfs1(int u,int pre){
    dep[u]=dep[pre]+1;
    for(auto v:adj[u]){
        if(v==pre||!flag) continue;
        dfs1(v,u);
        merge(u,v);
    }
    if(!flag) return;
    int minn=0;
    if(!st[u].empty()) minn=(*st[u].begin()).cst;
    for(auto p:path[u]){
        st[u].insert({p.first,p.second+minn});
    }
    if(st[u].empty()){
        flag=0;
        return;
    }
    if(u!=1){
        while(!st[u].empty()&&dep[(*st[u].begin()).j]>=dep[u]) st[u].erase(st[u].begin());
        if(st[u].empty()){
            flag=0;
            return;
        }
    }
}

signed main(){
    cin>>n>>m;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i=1;i<=m;i++){
        int u,v,w;
        cin>>u>>v>>w;
        path[u].push_back(pir(v,w));
    }
    dfs1(1,0);
    if(!flag){
        cout<<-1;
        return 0;
    }
    else if(m==1) cout<<0;
    else cout<<st[1].begin()->cst+tag[1];

    return 0;
}

差分化维护—P6847

今天没有什么太好的性质,考虑 DP,设 f(i,j) 表示 i 子树内断边在 \le j 的时间断开,转移:

更 nb 的,因为 dp 值单调所以可以考虑差分,那么第一种转移就能直接启发式合并,第二种转移是增加 w_{u} 的差分值,直接插入 set 中,这其实是取 max 操作,所以要从后面删除一些差分标记,时间复杂度 O(n\log^2 n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=1e5+15;
int n,m,K,rt[MN],d[MN],w[MN];
vector<int> adj[MN];

struct Segment{
    #define ls t[p].lson
    #define rs t[p].rson
    struct Node{
        int lson=0,rson=0;
        int mx=0,mn=0,add=0;
    }t[MN*50];
    int tot=0;

    bool isleaf(int p){
        if(!p) return true;
        return !t[p].lson && !t[p].rson;
    }

    void doadd(int p,int k){
        if(!p) return;
        t[p].mn += k;
        t[p].mx += k;
        t[p].add += k;
    }

    void pushup(int p){
        int L=t[p].lson, R=t[p].rson;
        int mxL = L? t[L].mx : INT_MIN/2;
        int mxR = R? t[R].mx : INT_MIN/2;
        int mnL = L? t[L].mn : INT_MAX/2;
        int mnR = R? t[R].mn : INT_MAX/2;
        t[p].mx = max(mxL,mxR);
        t[p].mn = min(mnL,mnR);
        if(t[p].mx==t[p].mn){
            t[p].lson = t[p].rson = 0;
            t[p].add = 0;
        }
    }

    void pushdown(int p){
        if(!p) return;
        if(t[p].add){
            doadd(t[p].lson,t[p].add);
            doadd(t[p].rson,t[p].add);
            t[p].add=0;
        }
    }

    int merge(int x,int y){
        if(!x || !y) return x?x:y;
        if(isleaf(y)){
            doadd(x,t[y].mx);
            return x;
        }
        if(isleaf(x)){
            doadd(y,t[x].mx);
            return y;
        }
        pushdown(x);
        pushdown(y);
        t[x].lson = merge(t[x].lson, t[y].lson);
        t[x].rson = merge(t[x].rson, t[y].rson);
        pushup(x);
        return x;
    }

    int query(int p,int l,int r,int k){
        if(!p) return 0;
        if(isleaf(p)) return t[p].mx;
        pushdown(p);
        int mid=(l+r)>>1;
        if(k<=mid) return query(t[p].lson,l,mid,k);
        else return query(t[p].rson,mid+1,r,k);
    }

    void modify(int &p,int l,int r,int fl,int fr,int k){
        // non-overlap
        if(l>fr || r<fl || !p) return;
        // if current interval's min already >= k, nothing to do
        if(t[p].mn >= k) return;
        // fully covered and max <= k -> set to k and clear children
        if(l>=fl && r<=fr && t[p].mx <= k){
            t[p].mx = t[p].mn = k;
            t[p].lson = t[p].rson = t[p].add = 0;
            return;
        }
        pushdown(p);
        int mid=(l+r)>>1;
        if(isleaf(p)){
            t[p].lson = ++tot;
            t[p].rson = ++tot;
            t[t[p].lson].mx = t[t[p].lson].mn = t[t[p].rson].mx = t[t[p].rson].mn = t[p].mx;
            // add/add initialized to 0 by Node default
        }
        if(fl <= mid) modify(t[p].lson, l, mid, fl, fr, k);
        if(fr > mid) modify(t[p].rson, mid+1, r, fl, fr, k);
        pushup(p);
    }

}sg;

void dfs(int u,int pre){
    rt[u] = ++sg.tot;
    for(auto v:adj[u]){
        if(v==pre) continue;
        dfs(v,u);
        rt[u] = sg.merge(rt[u], rt[v]);
    }
    if(d[u]){
        int cur = sg.query(rt[u],1,K,d[u]);
        int val = w[u] + cur;
        sg.modify(rt[u],1,K,d[u],K,val);
    }
}

signed main(){
    cin>>n>>m>>K;
    for(int i=2;i<=n;i++){
        int fa;
        cin>>fa;
        adj[fa].push_back(i);
        adj[i].push_back(fa);
    }
    for(int i=1;i<=m;i++){
        int x;
        cin>>x>>d[x]>>w[x];
    }
    dfs(1,0);
    cout<<sg.t[rt[1]].mx<<"\n";
    return 0;
}

1.3 习题与反思

一开始所提到的转移整体化的核心,在于把原本需要在 DP 中逐个状态枚举的转移,用批量可维护的方式交给数据结构来完成,从而让复杂度从 O(n^2)O(nV) 这类指数或平方级,下降到 O(n\log n) 甚至更优。对于一类有大量相同或者说相似转移的 dp,把 dp 的一维换成数据结构,用数据结构批量处理相同的转移。

来点习题练练手!

2. 多次转移合并

2.1 概述

多次转移合并这个名字怎么这么奇怪呢?

回忆我们一开始所提到的:

多次转移合并:将多次重复的转移,转化为不再对每个问题单独求解,而是把所有子问题作为整体嵌入同一个 DP 过程。

换句话说,如果每个问题的转移规则是固定的,那么我们可以把每个问题对答案的贡献统一考虑,通过一次 DP 或一次数据结构操作,计算出全部结果,而不必重复遍历。

这种整体化思想有两种展开(也是本蒟蒻所能够遇见的):

接下来我们会通过例题详细展开:

2.2 例题

转移贡献可叠加性

DP 的转移是固定的、局部的贡献可线性累加,就可以用一次整体 DP 代替多次重复计算。这一类常见的就是子区间问题的统计。

CF1603C Extreme Extension

首先考虑序列给定的话如何计算,显然最后一个值不可能动不然根据调整法不难证明不优。

那么也就是说最后一个值肯定是最大值,然后我们从后往前考虑,不妨记当前最大值为 x 每次拆我们尽量往大的拆,那么最少需要拆 k=\lceil \dfrac{a_{i}}{x} \rceil 次,而拆值后新的开头 x'=\lfloor \dfrac{a_{i}}{k} \rfloor

显然对于每个 ikx' 的取值最多只有 \sqrt{n} 种,并且计算只和当前开头有关,故设 f(i,j) 表示第 i 个位置开头数字为 j 的方案数。直接做的话是 O(n\sqrt{n}) 的,但是问题在于子区间统计做的话是 O(n^2\sqrt{n}) 的,注意到我们每次转移都是一致的系数完全一致,考虑整体 DP,在所有右端点处初始化,在左端点处统计答案,由于这个贡献是子区间完全可以累加贡献,只需要做一次 dp 就可以解决问题。时间复杂度 O(n\sqrt{n})

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=5e5+15,MOD=998244353;
int n,mx,a[MN],f[2][MN],ans;
vector<int> vct[2];

void initdp(){
    ans=0;
    vct[0].clear();
    vct[1].clear();
    mx=0;
    for(int i=1;i<=n;i++){
        mx=max(mx,a[i]);
    }
    for(int i=0;i<=mx;i++){
        f[0][i]=f[1][i]=0;
    }
}

void solve(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    initdp();
    int now=0,bef=1;
    for(int i=n;i>=1;i--){
        now^=1,bef^=1;
        int lst=a[i];
        vct[now].push_back(a[i]);
        f[now][a[i]]=1;
        for(auto p:vct[bef]){
            int cnt=ceil(1.0*a[i]/p),val=a[i]/cnt;
            f[now][val]=(f[now][val]+f[bef][p])%MOD;
            ans=(ans+(cnt-1)*i%MOD*f[bef][p]%MOD)%MOD;           
            if(lst!=val){
                vct[now].push_back(val);
                lst=val;
            }
        }
        for(auto p:vct[bef]) f[bef][p]=0;
        vct[bef].clear();
    }
    cout<<ans<<'\n';
}

signed main(){
    int T;
    cin>>T;
    while(T--){
        solve();
    }
    return 0;
}

CF1142D Foreigner

首先这个玩意过于抽象。但是我们不难发现每一层判断决策只和当前排名个数和上一位是什么有关,而且这个数列是根据生成顺序递增的。考虑增量法,我们考虑计算往数列中第 i 个数后面接一位 c,得到的数字的排名是多少?我们可以考虑计算比它小的合法数量,这里给出式子:

9+\sum\limits_{j=1}^{i-1} k\bmod 11 +c+1

但是我们只需要判断一个数是否是合法数字即可,考虑把上面式子放在 \bmod 11 的意义下进行就可以得到,有:

10+\dfrac{i(i-1)}{2}+c \pmod{11}

不难发现这个值只和 \bmod 11 后的值有关,对于原来的计数问题,可以考虑设计一个自动机来转移,说人话,设 f(i,j) 表示前 i 位中有多少排名 \bmod 11j 的不充分数字。枚举起点然后做时 O(n^2) 的,但是注意到这是子区间,显然满足贡献可叠加,直接在左端点初始化,走到右端点时上拿答案即可,时间复杂度 O(n)

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=1e6+15;
int n,f[MN][15],ans;
string st;

int nxt(int x,int c){
    return (x*(x-1)/2+c+10)%11;
}

signed main(){
    cin>>st;
    n=st.length();
    st=" "+st;
    for(int i=1;i<=n;i++){
        int ch=st[i]-'0';
        if(ch>0) f[i][ch]++;
        for(int j=ch+1;j<=10;j++){
            f[i][nxt(j,ch)]+=f[i-1][j];
        }
        for(int j=0;j<=10;j++) ans+=f[i][j];
    }
    cout<<ans;
    return 0;
}

P3352 线段树

需要把这些序列大小要整合到一起,不知道用什么处理?笛卡尔树?bro 有点难。咱们还是考虑 01 序列怎么做吧。

首先没有概率,就是纯纯的求方案数乘上权值。考虑值域为 \{0,1\} 的时候怎么做,不难发现对答案造成贡献必定是 0 段和 1 段合并,并且发现 0 段必定会被 1 段给包夹(边界位置设置为 1)。考虑到每次操作 1 段大小单调不降,0 段大小单调不升,我们考虑 DP 主体应该为 0 段,有状态 f(i,l,r) 表示 i 操作后 0 段缩小到 [l,r] 的方案数,有转移:

dp[i][l][r]\leftarrow\begin{cases}dp[i-1][l][r]\cdot\frac{l(l+1)+(n-r+1)(n-r+2)+(r-l)(r-l-1)}{2} & \text{QwQ}\\dp[i-1][l'][r]\cdot l' & l'<l \\ dp[i-1][l][r']\cdot (n-r'+1) & r'> r\end{cases}

不难发现可以用前缀和优化转移,时间复杂度 O(n^2q),拓展到一般序列上我们把 w 的贡献拆为 \max a_{i}-\max_{i=0}^{\max a_{i}} [w<i],即把所有 \ge i 的位置标为 1,把所有 <i 的位置标为 0。此时我们就可以算出每一个位置 <i 的方案数,时间复杂度 O(n^3q),数据随机可过。

显然太吃运气了,考虑优化!发现所有的 dp 值的转移柿子都完全一致而且系数完全固定,只有转移是不同的!所以我们可以把所有初始值放到同一个 dp 数组里面,然后通过合理初始化进行一次整体的 dp,就可以求出答案。

#include<bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MN=520,MOD=1e9+7,INV2=500000004;
struct Node{
    int v,id;
}a[MN];
int n,q,ans[MN],f[MN][MN],s1[MN][MN],s2[MN][MN],v[MN];

int ksm(int a,int b){
    int ret=1;
    while(b){
        if(b&1) ret=ret*a%MOD;
        a=a*a%MOD;
        b>>=1;
    }
    return ret;
}

int w(int l,int r){
    return (l*(l+1)%MOD+(n-r+1)*(n-r+2)%MOD+(r-l-1)*(r-l)%MOD)*INV2%MOD;
}

bool cmp(Node x,Node y){
    return x.v<y.v;
}

signed main(){
    cin>>n>>q;
    for(int i=1;i<=n;i++){
        cin>>a[i].v;
        a[i].id=i;
    }
    sort(a+1,a+1+n,cmp);
    v[n+1]=1;
    for(int i=n;i>=1;i--){
        int lst=0;
        v[a[i].id]=1;
        for(int j=1;j<=n+1;j++){
            if(v[j]){
                f[lst][j]+=a[i].v-a[i-1].v;
                lst=j;
            }
        }
    }
    ans[1]=a[n].v*ksm(n*(n+1)%MOD*INV2%MOD,q);
    for(int i=2;i<=n;i++) ans[i]=ans[1];
    for(int i=1;i<=q;i++){
        for(int j=0;j<=n;j++){
            for(int k=n+1;k>j+1;k--){
                if(j){
                    s1[j][k]=(s1[j-1][k]+f[j][k]*j%MOD)%MOD;
                }else s1[j][k]=f[j][k]*j%MOD;
                if(k<=n){
                    s2[j][k]=(s2[j][k+1]+f[j][k]*(n-k+1)%MOD)%MOD;
                }else s2[j][k]=(f[j][k]*(n-k+1)%MOD)%MOD;
            }
        }
        for(int j=0;j<=n;j++){
            for(int k=j+2;k<=n+1;k++){
                f[j][k]=f[j][k]*w(j,k)%MOD;
                if(j){
                    f[j][k]=(f[j][k]+s1[j-1][k])%MOD;
                }
                if(k<=n) f[j][k]=(f[j][k]+s2[j][k+1])%MOD;
            }
        }
    }
    for(int i=0;i<=n;i++){
        for(int j=i+2;j<=n+1;j++){
            for(int k=i+1;k<j;k++){
                ans[k]=(ans[k]-f[i][j]+MOD)%MOD;
            }
        }
    }
    for(int i=1;i<=n;i++) cout<<ans[i]<<' ';
    return 0;
}

多询问整体处理

每次修改只影响少量节点,但转移规则固定。这就是核心,对于一些只有状态上的值不同而转移方程完全相同的 DP 可以考虑使用这个 trick。接下来我们会以保卫王国这道经典题来作为叙述。

P5024 保卫王国

显然可以写出朴素转移方程,设 f(i,0/1) 表示 i 子树中,钦定 i 不选或选的最小代价,有转移:

\begin{aligned} f(u,0) &=\sum\limits_{v\in son(u)} f(v,1) \\ f(u,1) &= \sum\limits_{v\in son(u)} \min\{f(v,0),f(v,1)\} \end{aligned}

显然转移都是 O(n) 的,如果对于每个询问直接强制钦定然后重新 DP 的话时间复杂度是 O(n^2) 的,无法通过。

我们发现尽管强制钦定之后转移仍是固定的,虽然在每次询问中,某些节点被强制钦定,但转移方程本身从来没有变过,变的只是某些节点的取值。例如对于每个节点 u,正常的 DP 状态就是 f(u,*),这是在没有任何强制约束时的最优值。当一条询问出现,比如强制 u 不选,这其实就是把 f(u,1) 赋为 \infty;同理强制 u 必须选,就是把 f(u,0) 赋为 \infty

如果我们对于每个询问都从根重新跑一次树形 DP,复杂度就是 O(nm)。但是发现每次询问都只是修改了少量个别节点的 DP 初值;而且合并答案的规则是固定的。

我们可以将 DP 方程改写为矩阵形式,这里矩阵运算为 (\min,+) 广义矩阵乘法,这里不再列举。此时,强制钦定相当于在某个叶子节点上乘一个特殊矩阵,例如强制不选:[f(0), f(1)] \times \begin{bmatrix}\infty & \infty \\ \infty & 0\end{bmatrix};这样我们就能通过矩阵的形式表达对于某个节点的 DP 修改,并且这种修改是可以自下而上合并的。

如果对每个询问单独做矩阵 DP,复杂度仍高。但是我们发现只有少量的值会被修改,而且每次询问吧修改至多是 2 次单点修改,既然修改只发生在极少数节点,我们能不能把询问当作一堆修改,整体放到同一套 DP 里?正难则反!我们不考虑对每个询问求 DP 的值,而是在 DP 过程中维护每个询问情况下所对应的值!我们可以批量地对所有需要进行这一转移的 DP 进行转移,从而加快速度。

那如果我们现在拿着所有询问的 dp 值组成的一个数组,那么发现如果一个儿子没有成为过特殊点,它的转移显然用固定的转移矩阵维护就可以了;如果它成为过特殊点,我们可以单独转移它成为特殊点的那几次,就像上面特殊的用矩阵乘上。

暴力转移是 O(n^2) 的,如何更快地维护呢,例如 O(\log n)?我们有很多次修改,每次询问对应修改几个节点的 DP 初值;还需要快速把这些修改合并到全局结果里,这不线段树合并吗!

具体的,我们对于树上每个节点开一个线段树,让线段树上第 i 个叶子表示第 i 次询问的 DP 值,即 [f(0), f(1)]。注意这里的线段树仅仅是一个分治结构,而不是什么维护区间结合律信息的数据结构,也就是没有 pushup 这种玩意。它存在的意义就是利用结构一致和深度为 O(\log n) 来支持快速的单点修改和合并答案两个操作。

合并的时候,因为我们只关心叶子的信息,也就是 DP 值,考虑如果两棵线段树上有一棵有某个叶子 u,而另一棵没有,那么这个 u 的值应该没有变化,所以合并的时候不用管它如果两棵上都有 u,那么我们应该进行一个转移,我们记录这个 DP 值是属于哪一种转移,然后由于合并算法确实会走到这里,我们走到这里再执行转移就好了。

这个就是多询问整体处理的一个思想,其本质就是把多次 DP 叠加,转化为一次带修改的数据结构维护问题。在固定转移下,把多组独立询问抽象为初值修改,然后用线段树把它们整体维护起来,实现一次 DP 覆盖所有询问。

代码如下:

#include<bits/stdc++.h>
#define int long long
#define pir pair<int,int>
using namespace std;
constexpr int MN=5e5+15;
const int INF=1e18;
int ans[MN],a[MN],n,m,rt[MN];
vector<int> adj[MN];
vector<pir> qry[MN];

struct Matrix{
    int mat[2][2];

    Matrix(int x=0){
        mat[0][0]=mat[0][1]=mat[1][0]=mat[1][1]=x;
    }

    Matrix(int x1,int y1,int x2,int y2){
        mat[0][0]=x1;mat[0][1]=y1;mat[1][0]=x2;mat[1][1]=y2;
    }

    Matrix(int x,int y){
        mat[0][0]=x;mat[1][1]=y;mat[0][1]=mat[1][0]=INF;
    }

    friend bool operator==(const Matrix &x,const Matrix &y){
        for(int i=0;i<2;i++) for(int j=0;j<2;j++) if(x.mat[i][j]!=y.mat[i][j]) return 0;
        return 1;
    }

    friend Matrix operator*(const Matrix &x,const Matrix &y){
        Matrix ret(INF);
        for(int i=0;i<2;i++) for(int j=0;j<2;j++) for(int k=0;k<2;k++)
            ret.mat[i][j]=min(ret.mat[i][j],x.mat[i][k]+y.mat[k][j]);
        return ret;
    }
};
const Matrix MINF=Matrix(0,INF,INF,0);

struct Segment{
    #define ls t[p].lson
    #define rs t[p].rson
    struct Node{
        int lson,rson;
        Matrix val;
    }t[MN*30];
    int tot;
    void init(int x){t[x].lson=t[x].rson=0;t[x].val=MINF;}

    void pushdown(int p){
        if(t[p].val==MINF) return;
        if(!ls) ls=++tot,init(ls);
        if(!rs) rs=++tot,init(rs);
        t[ls].val=t[ls].val*t[p].val;
        t[rs].val=t[rs].val*t[p].val;
        t[p].val=MINF;
    }

    void modify(int &p,int l,int r,int pos,const Matrix &k){
        if(!p) p=++tot,init(p);
        if(l==r){
            t[p].val=t[p].val*k;
            return;
        }
        pushdown(p);
        int mid=(l+r)>>1;
        if(pos<=mid){
            modify(ls,l,mid,pos,k);
        }else{
            modify(rs,mid+1,r,pos,k);
        }
    }
    int merge(int x,int y){
        if(!x||!y) return x|y;
        if(!t[x].lson&&!t[x].rson) swap(x,y);
        if(!t[y].lson&&!t[y].rson){
            t[x].val=t[x].val*Matrix(t[y].val.mat[0][0],t[y].val.mat[0][1]);
            return x;
        }
        pushdown(x);
        pushdown(y);
        t[x].lson=merge(t[x].lson,t[y].lson);
        t[x].rson=merge(t[x].rson,t[y].rson);
        return x;
    }
    void solve(int p,int l,int r){
        if(l==r){
            ans[l]=t[p].val.mat[0][1];
            return;
        }
        pushdown(p);
        int mid=(l+r)>>1;
        if(ls) solve(ls,l,mid);
        if(rs) solve(rs,mid+1,r);
    }
}sg;

void dfs(int u,int pre){
    rt[u]=++sg.tot;
    sg.init(rt[u]);
    sg.t[rt[u]].val=Matrix(0,a[u],0,0);
    for(auto v:adj[u]) if(v!=pre){
        dfs(v,u);
        rt[u]=sg.merge(rt[u],rt[v]);
    }
    for(auto pr:qry[u]){
        if(pr.first==0) sg.modify(rt[u],1,m,pr.second,Matrix(0,INF));
        else sg.modify(rt[u],1,m,pr.second,Matrix(INF,0));
    }
    sg.t[rt[u]].val=sg.t[rt[u]].val*Matrix(INF,0,0,0);
}

signed main(){
    string tmp;
    cin>>n>>m>>tmp;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i=1;i<=m;i++){
        int x1,v1,x2,v2;
        cin>>x1>>v1>>x2>>v2;
        qry[x1].push_back({v1,i});
        qry[x2].push_back({v2,i});
    }
    sg.init(0);
    dfs(1,0);
    if(rt[1]) sg.solve(rt[1],1,m);
    for(int i=1;i<=m;i++) cout<<(ans[i]>=INF?-1:ans[i])<<'\n';
    return 0;
}

P2495 消耗战

虚树?不不不这是整体 DP!

先考虑一次询问怎么做,设 f(u) 表示子树 u 内所有关键点到 u 的路径切断的最小代价,转移方程显然:

可以发现,这个 DP 的过程非常简洁,由几个简单的操作组成:求和,取 min。不难可以写成 O(1) 转移式子或者矩阵形式,然后用上面的技巧维护,取 min 就是全局取 min,求和就是线段树合并的时候对应位置相加。时间复杂度 O(m\log m)

代码咕咕咕,如果想要可以参考 Fuyuki 的文章。

3. 参考与后言

上面说这么多,归根结底,这种方法的精髓在于将重复工作转化为一次统一处理。整体思想,就是在于整体处理,整体转移。种性质允许我们将原本需要多次重复计算的子问题,整合到一次整体 DP 中处理,从而大幅提升效率,避免重复操作。

这个博客提到的整体 DP 感觉可以单开一章,不过有点看不懂了,大家可以作为饭后作业 www。

求赞 QwQ。