浅谈矩阵乘法在线段树标记下传的运用

· · 算法·理论

前言

对于复杂的区间操作,我们自然会想到使用线段树。但是,这也意味着会出现复杂的标记下传与维护,许多初学者也因此开始打退堂鼓。本文将介绍另一种用矩阵乘法来解决复杂标记下传的问题。

前置知识:简单线段树、线性代数基础(矩阵乘法)。

一、多标记下传

(一)P3373 线段树 2

由于有区间乘法和区间加法两种操作,所以用普通的标记则需要分别维护乘法标记和加法标记,并且需要进行复杂的分类讨论。而使用矩阵乘法就不一样了,只需要一个标记,自然也没有了分类讨论。

我们让线段树的每个节点存储一个向量 \begin{bmatrix}x\\len\end{bmatrix} ,其中 x 表示当前的区间和,len 表示区间长度。现对其进行加 k ,也就是使这个向量变为 \begin{bmatrix}x+len\times k\\len\end{bmatrix}len 不能变)。不难得到,这就是对原向量\begin{bmatrix}1&k\\0&1\end{bmatrix} ,即 \begin{bmatrix}1&k\\0&1\end{bmatrix}\begin{bmatrix}x\\len\end{bmatrix}=\begin{bmatrix}x+len\times k\\len\end{bmatrix} 。同理可得,对区间乘 k 就是对原向量\begin{bmatrix}k&0\\0&1\end{bmatrix}

:::info[注意]{open} 若想用右乘,只需将列向量变为行向量,再将原来用于左乘的矩阵进行转置后拿去右乘,但本文中的矩阵乘法均使用左乘。 :::

此时,我们就将区间加与区间乘两个完全不同的操作,统一为了矩阵乘法。于是我们就只用维护一个矩阵乘法的标记。因为矩阵乘法满足结合律,所以可以放心打标记。

对于 pushup ,只需要对两个子节点的向量进行矩阵加法后放到父节点上就行了,这一点很容易理解。

:::success[参考代码]

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,mod=571373;
struct Matrix{
    int n,m,a[3][3];
    void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
    void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
    void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
    Matrix friend operator+(const Matrix&A,const Matrix&B){
        if(A.n!=B.n||A.m!=B.m)cout<<"Error:add",exit(0);
        Matrix C;C.n=A.n,C.m=A.m;
        for(int i=1;i<=A.n;i++)for(int j=1;j<=A.m;j++)
            C.a[i][j]=(A.a[i][j]+B.a[i][j])%mod;
        return C;
    }
    Matrix friend operator*(const Matrix&A,const Matrix&B){
        if(A.m!=B.n)cout<<"Error:mul",exit(0);
        Matrix C;C.init(A.n,B.m,0);
        for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
            for(int k=1;k<=A.m;k++)
                C.a[i][j]=(1ll*A.a[i][k]*B.a[k][j]%mod+C.a[i][j])%mod;
        return C;
    }
};
struct segment{
    #define mid (l+r>>1)
    int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
    void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
    void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
    void pushup(int u){tr[u]=tr[u<<1]+tr[u<<1|1];}
    void update(int u,int l,int r,int L,int R,int v,int op){
        if(L<=l&&r<=R){
            Matrix T;T.init(2,2,1);
            if(op==1)T.a[1][1]=v;
            else T.a[1][2]=v;
            return push(u,T);
        }
        pushdown(u);
        if(L<=mid)update(u<<1,l,mid,L,R,v,op);
        if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
        pushup(u);
    }void update(int L,int R,int v,int op){
        update(1,1,n,L,R,v,op);}
    Matrix query(int u,int l,int r,int L,int R){
        if(L<=l&&r<=R)return tr[u];
        pushdown(u);Matrix ans;ans.init(2,1,0);
        if(L<=mid)ans=ans+query(u<<1,l,mid,L,R);
        if(R>mid)ans=ans+query(u<<1|1,mid+1,r,L,R);
        return ans;
    }int query(int L,int R){
        return query(1,1,n,L,R).a[1][1];}
    void build(int u,int l,int r,int*a){
        tr[u].init(2,1,0),tag[u].init(2,2,1);
        if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
        build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
    }void build(int n_,int*a){n=n_,build(1,1,n,a);}
    #undef mid
}seg;
int n,q,m,a[N];
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>q>>m;
    for(int i=1;i<=n;i++)cin>>a[i];
    seg.build(n,a);
    int op,l,r,v;
    for(;q--;){
        cin>>op>>l>>r;
        if(op<3)cin>>v,seg.update(l,r,v,op);
        else cout<<seg.query(l,r)<<"\n";
    }
    return 0;
}

:::

(二)P1253 扶苏的问题

对于这道题,我们让线段树的每个节点存储一个向量 \begin{bmatrix}x\\1\end{bmatrix} ,其中 x 表示当前的区间最大值。根据上道题的思路,我们可以将区间覆盖操作变为左乘 \begin{bmatrix}0&k\\0&1\end{bmatrix} ,区间加法变为左乘 \begin{bmatrix}1&k\\0&1\end{bmatrix}(想想区间加为什么和上一题一样)。pushup 则取左右两个儿子中向量 x 值更大的。所以只需要将上一题的代码改一下就是这题的代码了。

:::success[参考代码]

//因为常数问题,该代码仅能得到90tps
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5;
struct Matrix{
    int n,m;ll a[3][3];
    void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
    void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
    void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
    Matrix friend operator*(const Matrix&A,const Matrix&B){
        if(A.m!=B.n)cout<<"Error:mul",exit(0);
        Matrix C;C.init(A.n,B.m,0);
        for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
            for(int k=1;k<=A.m;k++)
                C.a[i][j]+=A.a[i][k]*B.a[k][j];
        return C;
    }
};
struct segment{
    #define mid (l+r>>1)
    int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
    void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
    void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
    void pushup(int u){tr[u].a[1][1]=max(tr[u<<1].a[1][1],tr[u<<1|1].a[1][1]);}
    void update(int u,int l,int r,int L,int R,int v,int op){
        if(L<=l&&r<=R){
            Matrix T;T.init(2,2,1);
            T.a[1][1]=op-1,T.a[1][2]=v;
            return push(u,T);
        }
        pushdown(u);
        if(L<=mid)update(u<<1,l,mid,L,R,v,op);
        if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
        pushup(u);
    }void update(int L,int R,int v,int op){update(1,1,n,L,R,v,op);}
    ll query(int u,int l,int r,int L,int R){
        if(L<=l&&r<=R)return tr[u].a[1][1];
        pushdown(u);ll ans=-1e18;
        if(L<=mid)ans=max(ans,query(u<<1,l,mid,L,R));
        if(R>mid)ans=max(ans,query(u<<1|1,mid+1,r,L,R));
        return ans;
    }ll query(int L,int R){return query(1,1,n,L,R);}
    void build(int u,int l,int r,int*a){
        tr[u].init(2,1,0),tag[u].init(2,2,1),tr[u].a[2][1]=1;
        if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
        build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
    }void build(int n_,int*a){n=n_,build(1,1,n,a);}
    #undef mid
}seg;
int n,q,a[N];
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>q;
    for(int i=1;i<=n;i++)cin>>a[i];
    seg.build(n,a);
    int op,l,r,v;
    for(;q--;){
        cin>>op>>l>>r;
        if(op<3)cin>>v,seg.update(l,r,v,op);
        else cout<<seg.query(l,r)<<"\n";
    }
    return 0;
}

:::

二、区间历史最值

:::info[前置知识1:区间历史最值的概念]

历史最大值

简单地说,一个位置的历史最大值就是当前位置下曾经出现过的数的最大值。形式化地定义,我们定义一个辅助数组 B ,一开始与 A 完全相同。在 A 的每次操作后,我们对整个数组取 \max

\forall i\in[1,n],\ B_i=\max(B_i,A_i)

这时,我们将 B_i 称作这个位置的历史最大值。

历史最小值

定义与历史最大值类似,在 A 的每次操作后,我们对整个数组取 \min 。这时,我们将 B_i 称作这个位置的历史最小值。

历史版本和

辅助数组 B 一开始全部是 0。在每一次操作后,我们把整个 A 数组累加到 B 数组上:

\forall i\in[1,n], \ B_i=B_i+A_i

我们称 B_ii 这个位置上的历史版本和。

※ 以上内容摘自 OI Wiki :::

:::info[前置知识2:广义矩阵乘]

一、定义

设两个 n\times n 的矩阵 A,B ,定义广义矩阵乘 C=A\odot B(符号 \odot 表示广义乘),则其元素 C_{i,j} 满足:C_{i,j}=\bigoplus_{k=1}^{n}(A_{i,k}\otimes B_{k,j})

在此基础上还需保证设计出来的广义矩阵乘满足结合律

二、判断结合律:4 条半环公理

广义矩阵乘 \odot 满足结合律的充要条件是:运算对 (\oplus,\otimes) 满足以下 4 条公理(构成“半环”结构):

  1. ::: 我们知道这一类问题的标准解法为吉司机线段树,但吉司机线段树也需要复杂的标记下传。而这里将介绍如何用矩阵乘法来解决,或者说用矩阵乘法来理解吉司机线段树。 ### (一)区间历史最大/最小

先来看区间历史最大。这次我们让线段树的每个节点存储的向量为 \begin{bmatrix}a\\b\end{bmatrix} ,其中 a 表示当前区间最大值,b 表示区间历史最大值。并定义广义矩阵乘中广义加法为取 \max ,广义乘法为普通加法。此时,对于区间加操作,有 \begin{bmatrix}k&-\infty\\k&0\end{bmatrix}\begin{bmatrix}a\\b\end{bmatrix}=\begin{bmatrix}a+k\\\max(b,a+k)\end{bmatrix}(使用广义矩阵乘,下同)。若还有区间覆盖操作,就需要将向量增加一维到 \begin{bmatrix}a\\b\\0\end{bmatrix} ,于是就有了 \begin{bmatrix}-\infty&-\infty&k\\-\infty&0&k\\-\infty&-\infty&0\end{bmatrix}\begin{bmatrix}a\\b\\0\end{bmatrix}=\begin{bmatrix}k\\\max(b,k)\\0\end{bmatrix}(此时的区间加请下来自己思考)。而 pushup ,也就变为了将左右儿子的向量用此时的广义加法(取 \max )加起来再赋给父节点。

对于区间历史最小,就是将广义加法定义为取 \min ,剩下的与区间历史最大一样。

(二)区间历史和

这里使用普通矩阵乘。让线段树的每个节点存储的向量为 \begin{bmatrix}a\\b\\len\end{bmatrix} ,其中 a 表示当前区间和,b 表示区间历史和,len 表示区间长度。对于区间加,有 \begin{bmatrix}1&0&k\\1&1&k\\0&0&1\end{bmatrix}\begin{bmatrix}a\\b\\len\end{bmatrix}=\begin{bmatrix}a+len\times k\\b+a+len\times k\\len\end{bmatrix} ;对于区间覆盖,有 \begin{bmatrix}0&0&k\\0&1&k\\0&0&1\end{bmatrix}\begin{bmatrix}a\\b\\len\end{bmatrix}=\begin{bmatrix}len\times k\\b+len\times k\\len\end{bmatrix}

总结

如你所见,矩阵乘法对线段树的标记下传的简化效果是很大的。但是,也会带来一些复杂度上的问题。假设矩阵大小为 w\times w ,就会对时间复杂度增加一个大小为 w^3 的复杂度,对空间复杂度增加一个大小为 w/w^2 的复杂度。所以,虽然可以对线段树所维护的向量上灵活地增加 0/1/len ,但最好不要让向量大小超过 4

最后,也希望这篇博客能为你带来帮助,感谢你的浏览!

参考资料

  1. OI-Wiki
  2. https://www.luogu.com.cn/article/ypgkm4vg
  3. https://www.luogu.com.cn/article/d04azg6j
  4. https://www.luogu.com.cn/article/tafs5gxk