题解 P3373 【【模板】线段树 2】
竟然没有矩阵乘法的题解
考虑线段树
对于线段树的每一个节点,我们维护一个向量(x,y),其中x为区间和,y为区间大小。
操作一
对于一个向量(a,b),我们要让它成为(a*x,b),相当于让这个向量乘以如下矩阵。
操作二
对于一个向量(a,b),我们要让它成为(a+x*b,b),相当于让这个向量乘以如下矩阵。
这样修改操作就等价于让一个区间乘以一个矩阵。
于是每个节点的Lazytag就维护一个矩阵即可。
代码
#include<bits/stdc++.h>
#define ll long long
#define N 100010
using namespace std;
template<typename T> void read(T &x){
x=0;char c=getchar();T sig=1;
for (;!isdigit(c);c=getchar()) if (c=='-') sig=-1;
for (; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+c-'0';
x*=sig;
}
ll Mod;
struct matrix{
ll a[2][2];
void crt(){
a[0][0]=a[1][1]=1;
a[0][1]=a[1][0]=0;
return;
}
void clear(){
a[0][0]=a[0][1]=a[1][0]=a[1][1]=0;
return;
}
void crt1(ll x){
clear();
a[0][0]=x;
a[1][1]=1;
return;
}
void crt2(ll x){
clear();
a[0][0]=1;
a[1][0]=x;
a[1][1]=1;
return;
}
};
struct tree{
int l,r;
ll w[2];
matrix tag;
bool flag;
};
tree tr[N<<2];
ll a[N];
int n,k;
void timesA(tree &x,matrix y){
tree res;
res.w[0]=(x.w[0]*y.a[0][0]+x.w[1]*y.a[1][0])%Mod;
res.w[1]=(x.w[0]*y.a[0][1]+x.w[1]*y.a[1][1])%Mod;
x.w[0]=res.w[0];
x.w[1]=res.w[1];
return;
}
void timesB(matrix &x,matrix y){
matrix z;
z.a[0][0]=(x.a[0][0]*y.a[0][0]+x.a[0][1]*y.a[1][0])%Mod;
z.a[0][1]=(x.a[0][0]*y.a[0][1]+x.a[0][1]*y.a[1][1])%Mod;
z.a[1][0]=(x.a[1][0]*y.a[0][0]+x.a[1][1]*y.a[1][0])%Mod;
z.a[1][1]=(x.a[1][0]*y.a[0][1]+x.a[1][1]*y.a[1][1])%Mod;
x=z;
return;
}
void pushup(int i){
tr[i].w[0]=(tr[i<<1].w[0]+tr[i<<1|1].w[0])%Mod;
tr[i].w[1]=(tr[i<<1].w[1]+tr[i<<1|1].w[1])%Mod;
return;
}
void build(int i,int L,int R){
tr[i].l=L;
tr[i].r=R;
tr[i].tag.crt();
tr[i].flag=false;
if (L==R){
tr[i].w[0]=a[L];
tr[i].w[1]=1;
return;
}
int mid=(L+R)>>1;
build(i<<1,L,mid);
build(i<<1|1,mid+1,R);
pushup(i);
return;
}
void pushdown(int i){
if (tr[i].flag){
tr[i].flag=false;
tr[i<<1].flag=true;
tr[i<<1|1].flag=true;
timesB(tr[i<<1].tag,tr[i].tag);
timesB(tr[i<<1|1].tag,tr[i].tag);
timesA(tr[i<<1],tr[i].tag);
timesA(tr[i<<1|1],tr[i].tag);
tr[i].tag.crt();
}
return;
}
void add(int i,int L,int R,matrix s){
if (L<=tr[i].l&&tr[i].r<=R){
timesA(tr[i],s);
timesB(tr[i].tag,s);
tr[i].flag=true;
return;
}
pushdown(i);
if (L<=tr[i<<1].r) add(i<<1,L,R,s);
if (R>=tr[i<<1|1].l) add(i<<1|1,L,R,s);
pushup(i);
return;
}
ll query(int i,int L,int R){
if (L<=tr[i].l&&tr[i].r<=R) return tr[i].w[0];
pushdown(i);
ll res=0;
if (L<=tr[i<<1].r) res=query(i<<1,L,R);
if (R>=tr[i<<1|1].l) res+=query(i<<1|1,L,R);
res%=Mod;
return res;
}
int main(){
read(n);read(k);read(Mod);
for (int i=1;i<=n;i++) read(a[i]);
build(1,1,n);
while (k--){
int opt,L,R;
read(opt);read(L);read(R);
if (opt==3) printf("%lld\n",query(1,L,R));
else{
ll x;
read(x);
matrix type;
if (opt==1) type.crt1(x);else type.crt2(x);
add(1,L,R,type);
}
}
return 0;
}