浅谈矩阵乘法在线段树标记下传的运用
Little_Zyl · · 算法·理论
前言
对于复杂的区间操作,我们自然会想到使用线段树。但是,这也意味着会出现复杂的标记下传与维护,许多初学者也因此开始打退堂鼓。本文将介绍另一种用矩阵乘法来解决复杂标记下传的问题。
前置知识:简单线段树、线性代数基础(矩阵乘法)。
一、多标记下传
(一)P3373 线段树 2
由于有区间乘法和区间加法两种操作,所以用普通的标记则需要分别维护乘法标记和加法标记,并且需要进行复杂的分类讨论。而使用矩阵乘法就不一样了,只需要一个标记,自然也没有了分类讨论。
我们让线段树的每个节点存储一个向量
:::info[注意]{open} 若想用右乘,只需将列向量变为行向量,再将原来用于左乘的矩阵进行转置后拿去右乘,但本文中的矩阵乘法均使用左乘。 :::
此时,我们就将区间加与区间乘两个完全不同的操作,统一为了矩阵乘法。于是我们就只用维护一个矩阵乘法的标记。因为矩阵乘法满足结合律,所以可以放心打标记。
对于 pushup ,只需要对两个子节点的向量进行矩阵加法后放到父节点上就行了,这一点很容易理解。
:::success[参考代码]
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,mod=571373;
struct Matrix{
int n,m,a[3][3];
void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
Matrix friend operator+(const Matrix&A,const Matrix&B){
if(A.n!=B.n||A.m!=B.m)cout<<"Error:add",exit(0);
Matrix C;C.n=A.n,C.m=A.m;
for(int i=1;i<=A.n;i++)for(int j=1;j<=A.m;j++)
C.a[i][j]=(A.a[i][j]+B.a[i][j])%mod;
return C;
}
Matrix friend operator*(const Matrix&A,const Matrix&B){
if(A.m!=B.n)cout<<"Error:mul",exit(0);
Matrix C;C.init(A.n,B.m,0);
for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
for(int k=1;k<=A.m;k++)
C.a[i][j]=(1ll*A.a[i][k]*B.a[k][j]%mod+C.a[i][j])%mod;
return C;
}
};
struct segment{
#define mid (l+r>>1)
int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
void pushup(int u){tr[u]=tr[u<<1]+tr[u<<1|1];}
void update(int u,int l,int r,int L,int R,int v,int op){
if(L<=l&&r<=R){
Matrix T;T.init(2,2,1);
if(op==1)T.a[1][1]=v;
else T.a[1][2]=v;
return push(u,T);
}
pushdown(u);
if(L<=mid)update(u<<1,l,mid,L,R,v,op);
if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
pushup(u);
}void update(int L,int R,int v,int op){
update(1,1,n,L,R,v,op);}
Matrix query(int u,int l,int r,int L,int R){
if(L<=l&&r<=R)return tr[u];
pushdown(u);Matrix ans;ans.init(2,1,0);
if(L<=mid)ans=ans+query(u<<1,l,mid,L,R);
if(R>mid)ans=ans+query(u<<1|1,mid+1,r,L,R);
return ans;
}int query(int L,int R){
return query(1,1,n,L,R).a[1][1];}
void build(int u,int l,int r,int*a){
tr[u].init(2,1,0),tag[u].init(2,2,1);
if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
}void build(int n_,int*a){n=n_,build(1,1,n,a);}
#undef mid
}seg;
int n,q,m,a[N];
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n>>q>>m;
for(int i=1;i<=n;i++)cin>>a[i];
seg.build(n,a);
int op,l,r,v;
for(;q--;){
cin>>op>>l>>r;
if(op<3)cin>>v,seg.update(l,r,v,op);
else cout<<seg.query(l,r)<<"\n";
}
return 0;
}
:::
(二)P1253 扶苏的问题
对于这道题,我们让线段树的每个节点存储一个向量 pushup 则取左右两个儿子中向量
:::success[参考代码]
//因为常数问题,该代码仅能得到90tps
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5;
struct Matrix{
int n,m;ll a[3][3];
void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
Matrix friend operator*(const Matrix&A,const Matrix&B){
if(A.m!=B.n)cout<<"Error:mul",exit(0);
Matrix C;C.init(A.n,B.m,0);
for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
for(int k=1;k<=A.m;k++)
C.a[i][j]+=A.a[i][k]*B.a[k][j];
return C;
}
};
struct segment{
#define mid (l+r>>1)
int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
void pushup(int u){tr[u].a[1][1]=max(tr[u<<1].a[1][1],tr[u<<1|1].a[1][1]);}
void update(int u,int l,int r,int L,int R,int v,int op){
if(L<=l&&r<=R){
Matrix T;T.init(2,2,1);
T.a[1][1]=op-1,T.a[1][2]=v;
return push(u,T);
}
pushdown(u);
if(L<=mid)update(u<<1,l,mid,L,R,v,op);
if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
pushup(u);
}void update(int L,int R,int v,int op){update(1,1,n,L,R,v,op);}
ll query(int u,int l,int r,int L,int R){
if(L<=l&&r<=R)return tr[u].a[1][1];
pushdown(u);ll ans=-1e18;
if(L<=mid)ans=max(ans,query(u<<1,l,mid,L,R));
if(R>mid)ans=max(ans,query(u<<1|1,mid+1,r,L,R));
return ans;
}ll query(int L,int R){return query(1,1,n,L,R);}
void build(int u,int l,int r,int*a){
tr[u].init(2,1,0),tag[u].init(2,2,1),tr[u].a[2][1]=1;
if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
}void build(int n_,int*a){n=n_,build(1,1,n,a);}
#undef mid
}seg;
int n,q,a[N];
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n>>q;
for(int i=1;i<=n;i++)cin>>a[i];
seg.build(n,a);
int op,l,r,v;
for(;q--;){
cin>>op>>l>>r;
if(op<3)cin>>v,seg.update(l,r,v,op);
else cout<<seg.query(l,r)<<"\n";
}
return 0;
}
:::
二、区间历史最值
:::info[前置知识1:区间历史最值的概念]
历史最大值
简单地说,一个位置的历史最大值就是当前位置下曾经出现过的数的最大值。形式化地定义,我们定义一个辅助数组
这时,我们将
历史最小值
定义与历史最大值类似,在
历史版本和
辅助数组
我们称
※ 以上内容摘自 OI Wiki :::
:::info[前置知识2:广义矩阵乘]
一、定义
设两个
-
-
- 运算范围需一致(比如实数集、整数集、有限集合)。
在此基础上还需保证设计出来的广义矩阵乘满足结合律。
二、判断结合律:4 条半环公理
广义矩阵乘
-
-
-
-
::: 我们知道这一类问题的标准解法为吉司机线段树,但吉司机线段树也需要复杂的标记下传。而这里将介绍如何用矩阵乘法来解决,或者说用矩阵乘法来理解吉司机线段树。 ### (一)区间历史最大/最小
先来看区间历史最大。这次我们让线段树的每个节点存储的向量为 pushup ,也就变为了将左右儿子的向量用此时的广义加法(取
对于区间历史最小,就是将广义加法定义为取
(二)区间历史和
这里使用普通矩阵乘。让线段树的每个节点存储的向量为
总结
如你所见,矩阵乘法对线段树的标记下传的简化效果是很大的。但是,也会带来一些复杂度上的问题。假设矩阵大小为
最后,也希望这篇博客能为你带来帮助,感谢你的浏览!
参考资料
- OI-Wiki
- https://www.luogu.com.cn/article/ypgkm4vg
- https://www.luogu.com.cn/article/d04azg6j
- https://www.luogu.com.cn/article/tafs5gxk