题解:P4751 【模板】"动态DP"&动态树分治(加强版)

· · 题解

最近开始学习 动态 DP,如果有不对的地方请在讨论区指出,我当感激不尽。(如果已经会动态 DP 请去下面直接从 全局平衡二叉树维护动态 DP 开始看起)。

进入正题

本题需要一个科技 动态 DP,如果还不会的话,这里大概讲一下。

$f_{u,0} = \sum \max(f_{v,0},f_{v,1}) f_{u,1} = \sum f_{v,0} + a_u

然后我们考虑修改怎么做,你会发现其实他只会修改一条从修改点到根节点的 f 值,但是如果暴力修改的话肯定会超时,时间复杂度 O(nm)

然后,我们引入 动态 DP

动态 DP 问题是猫锟在 WC2018 讲的黑科技,一般用来
解决树上的带有点权(边权)修改操作的 DP 问题。--- oiwiki

具体的,我们考虑将原树 重链剖分,计算出不含有 重儿子f 值为 g

状态转移方程如下,

f_{x,0}=g_{x,0}+\max(f_{son_x,1},f_{son_x,0}) f_{x,1}=g_{x,1}+f_{son_x,0}

然后把这个方程用矩阵写出来,就是:

\begin{bmatrix}f_{x,0}&f_{x,1} \end{bmatrix}= \begin{bmatrix}f_{son_x,0}&f_{son_x,1} \end{bmatrix} \begin{bmatrix} g_{x,0} & g_{x,1} \\ g_{x,0} & -\infty \\ \end{bmatrix}

额,怎么感觉这个矩阵乘法不太一样呢?哦,这是 广义矩阵乘法

其实就是在原先的矩阵乘法上把计算符号组合从 (+,\times) 改为了 (\max,+)

代码如下:

mat operator * (const mat &A)const{
        mat res=mat();
        for(int i=0;i<2;i++){
            for(int j=0;j<2;j++){
                for(int k=0;k<2;k++){
                    res.a[i][j]=max(res.a[i][j],a[i][k]+A.a[k][j]);
                }
            }
        }
        return res;
    }

题外话:并不是所有的运算都能 广义矩阵乘法 运算,具体见 此。

然后你就能理解上面的矩阵乘法了,实际上就是转移式子。

这样的话,我们就可以在每一个点上面放一个含 g 值的转移矩阵,特别的,我们需要特判叶子节点,在叶子上放一个初始矩阵(但是本题并不需要,因为你可以在叶子上 gf 数组是一样的)。

修改的时候我们,相当于只需要修改每一个链头的父亲的 g 数组即可(因为只有那些点的 g 值受到了影响)。因为重链只有 O(\log n),所以一次修改最多只会修改 \log n 个点的矩阵。

然后如果想求一个点的 f 值,只需要求当前点到重链深度最高的点的矩阵乘积,取出对应位置的值即可。对于矩阵乘积,我们可以用 树剖+线段树 维护。

但是我们发现我们的初始矩阵在叶子上,所以相乘的时候顺序应该是从下到上的,但是当前的我们刚才推的矩阵顺序是初始矩阵在左,转移矩阵在右的,而矩阵乘法不具有交换律,正常线段树都是从左往右的算的,而 dfs 序会导致是上面的数算的时候在左边,会出错,解决办法有两个:

1.最常见的是改变矩阵的位置,变成,

\begin{bmatrix} f_{x,0} \\f_{x,1} \end{bmatrix}= \begin{bmatrix} g_{x,0} & g_{x,0} \\ g_{x,1} & -\infty \\ \end{bmatrix} \begin{bmatrix} f_{son_x,0} \\f_{son_x,1} \end{bmatrix}

这样从左往右乘就是对的了,我用的就是这种办法。

2.直接在线段树中改变乘法的顺序,从右往左乘。(不过这个方法好像没什么人提到过,不确定行不行,所以 狗头保命

到这里其实就可以 AC P4719 了。

代码和其他人差不多,在这里不过多赘述。(不过如果是第一次接触这个内容的最好先做完上面的那道题再来看下面的东西。

时间复杂度 O(2^3 n \log n^2)

但是这个时间复杂度过不了本题,所以我们考虑优化。但是矩阵这个 trick 好像扔不掉,那我们考虑有没有什么更快的方法处理矩阵的乘积。

正题:全局平衡二叉树维护动态 DP

首先我们要知道全局平衡二叉树(下面简称平衡树)是什么。

全局平衡二叉树实际上是一颗二叉树森林,
其中的每颗二叉树维护一条重链。
但是这个森林里的二叉树又互有联系,
其中每个二叉树的根连向这个重链链头的父亲,
就像 LCT 中一样。但全局平衡二叉树是
静态树,区别于 LCT,建成后树的形态不变。--oiwiki

平衡树将每一条重链转化成二叉树,由多条轻边相连,组成一个树高为 O(\log n) 的树。这样就不用二叉树维护,直接求子树内乘积即可。

首先考虑如何建树,首先从一个点到根的轻链条数一定是 \log n 级别的,所以我们只需要考虑如何构造重链的形态即可。

所以我们的构造方法就是把树链构造成一棵二叉树,具体操作过程就是把每一条重链剖出来,然后求加权中点(下面简称中点),以中点为根,继续处理左右两边,具体看代码:

int cbuild(int ql,int qr){
    int l=ql,r=qr;
    while(l+1!=r){
        int mid=(l+r)>>1;
        if(((pre[mid]-pre[ql])<<1)<=pre[qr]-pre[ql]) l=mid;
        else r=mid;
    }//求中点
    int rt=b[l];tree[rt]=val[rt];
    if(l>ql) ls[rt]=cbuild(ql,l),fa[ls[rt]]=rt;
    if(l+1<qr) rs[rt]=cbuild(l+1,qr),fa[rs[rt]]=rt;//递归处理左右子树
    return rt;
}
int build(int x){
    int y=x;
    do{
        val[y].a[0][0]=val[y].a[0][1]=g[y][0];
        val[y].a[1][0]=g[y][1];val[y].a[1][1]=-INF;
        for(auto v:ed[y]){
            if(v==son[y]||v==fat[y]) continue;
            fa[build(v)]=y;//这里注意一下,不要写成 fa[v]=y;
        }
    }while(y=son[y]);//先处理重链旁的链
    do{
        b[y++]=x;pre[y]=pre[y-1]+siz[x]-siz[son[x]];
    }while(x=son[x]);把重链打出来
    return cbuild(0,y);
}

这里小证明一下,因为是求了中点,所以每往上面跳一次都会子树大小都会加倍,所以最多跳 \log n 次就会子树大小都会变成 n所以总树高就是 O(\log n)

然后剩下的就很简单了,每一个点维护子树矩阵乘积即可。

完整代码:(注释在下方)。

小拓展:平衡树的中序遍历就是原树链的序列。所以其实也可以求具体某一个点的值,只需要从当前点出发一直走到所在二叉树的根节点,乘上原树链在当前点的下面的点的矩阵(这个画个图就好理解了),可以尝试做一下 洪水。

#include <bits/stdc++.h>
using namespace std;
const int INF=1e9,N=1e6+10;
int n,m,a[N],fat[N],son[N],siz[N],fa[N],b[N],pre[N],ls[N],rs[N],f[N][2],g[N][2],lans;
vector<int> ed[N];
struct mat{
    int a[2][2];
    mat(){for(int i=0;i<2;i++) for(int j=0;j<2;j++) a[i][j]=-INF;}
    mat operator * (const mat &A)const{
        mat res=mat();
        for(int i=0;i<2;i++){
            for(int j=0;j<2;j++){
                for(int k=0;k<2;k++){
                    res.a[i][j]=max(res.a[i][j],a[i][k]+A.a[k][j]);
                }
            }
        }
        return res;
    }
}val[N],tree[N];//tree[x] 表示子树的矩阵乘积
void dfs(int x,int fa){
    f[x][1]=g[x][1]=a[x];
    fat[x]=fa;siz[x]=1;
    for(auto v:ed[x]){
        if(v==fa) continue;
        dfs(v,x);
        siz[x]+=siz[v];
        if(siz[v]>siz[son[x]]){
            g[x][1]+=f[son[x]][0];
            g[x][0]+=max(f[son[x]][1],f[son[x]][0]);
            son[x]=v;
        }else{
            g[x][1]+=f[v][0];
            g[x][0]+=max(f[v][1],f[v][0]);
        }
        f[x][1]+=f[v][0];
        f[x][0]+=max(f[v][1],f[v][0]);
    }
}
int cbuild(int ql,int qr){
    int l=ql,r=qr;
    while(l+1!=r){
        int mid=(l+r)>>1;
        if(((pre[mid]-pre[ql])<<1)<=pre[qr]-pre[ql]) l=mid;
        else r=mid;
    }
    int rt=b[l];tree[rt]=val[rt];
    if(l>ql) ls[rt]=cbuild(ql,l),fa[ls[rt]]=rt,tree[rt]=tree[ls[rt]]*tree[rt];//注意顺序,应该是先左子树到根最后右子树
    if(l+1<qr) rs[rt]=cbuild(l+1,qr),fa[rs[rt]]=rt,tree[rt]=tree[rt]*tree[rs[rt]];//因为在链上是从深度低的点往深度高的点
    return rt;
}
int build(int x){
    int y=x;
    do{
        val[y].a[0][0]=val[y].a[0][1]=g[y][0];
        val[y].a[1][0]=g[y][1];val[y].a[1][1]=-INF;
        for(auto v:ed[y]){
            if(v==son[y]||v==fat[y]) continue;
            fa[build(v)]=y;
        }
    }while(y=son[y]);
    do{
        b[y++]=x;pre[y]=pre[y-1]+siz[x]-siz[son[x]];
    }while(x=son[x]);
    return cbuild(0,y);
}
void change(int x,int y){
    g[x][1]+=y-a[x];
    a[x]=y;
    val[x].a[1][0]=g[x][1];
    while(x){
        mat pre=tree[x];
        tree[x]=val[x];
        if(ls[x]) tree[x]=tree[ls[x]]*tree[x];
        if(rs[x]) tree[x]=tree[x]*tree[rs[x]];
        if(ls[fa[x]]!=x&&rs[fa[x]]!=x){
            g[fa[x]][0]+=max(tree[x].a[0][0],tree[x].a[1][0])-max(pre.a[0][0],pre.a[1][0]);
            g[fa[x]][1]+=tree[x].a[0][0]-pre.a[0][0];
            val[fa[x]].a[0][0]=val[fa[x]].a[0][1]=g[fa[x]][0];
            val[fa[x]].a[1][0]=g[fa[x]][1];
        }
        x=fa[x];
    }
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        ed[x].push_back(y);
        ed[y].push_back(x);
    }
    dfs(1,0);
    int Root=build(1);
    while(m--){
        int x,y;
        scanf("%d%d",&x,&y);
        x^=lans;
        change(x,y);
        lans=max(tree[Root].a[0][0],tree[Root].a[1][0]);
        printf("%d\n",lans);
    }
    return 0;
}