题解 P3373 【【模板】线段树 2】

· · 题解

竟然没有矩阵乘法的题解

考虑线段树

对于线段树的每一个节点,我们维护一个向量(x,y),其中x为区间和,y为区间大小。

操作一

对于一个向量(a,b),我们要让它成为(a*x,b),相当于让这个向量乘以如下矩阵。

A=\begin{bmatrix}x&0\\0&1\end{bmatrix}

操作二

对于一个向量(a,b),我们要让它成为(a+x*b,b),相当于让这个向量乘以如下矩阵。

B=\begin{bmatrix}1&0\\x&1\end{bmatrix}

这样修改操作就等价于让一个区间乘以一个矩阵。

于是每个节点的Lazytag就维护一个矩阵即可。

代码

#include<bits/stdc++.h>
#define ll long long
#define N 100010
using namespace std;
template<typename T> void read(T &x){
    x=0;char c=getchar();T sig=1;
    for (;!isdigit(c);c=getchar()) if (c=='-') sig=-1;
    for (; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+c-'0';
    x*=sig;
}
ll Mod;
struct matrix{
    ll a[2][2];
    void crt(){
        a[0][0]=a[1][1]=1;
        a[0][1]=a[1][0]=0;
        return;
    }
    void clear(){
        a[0][0]=a[0][1]=a[1][0]=a[1][1]=0;
        return;
    }
    void crt1(ll x){
        clear();
        a[0][0]=x;
        a[1][1]=1;
        return;
    }
    void crt2(ll x){
        clear();
        a[0][0]=1;
        a[1][0]=x;
        a[1][1]=1;
        return;
    }
};
struct tree{
    int l,r;
    ll w[2];
    matrix tag;
    bool flag;
};
tree tr[N<<2];
ll a[N];
int n,k;
void timesA(tree &x,matrix y){
    tree res;
    res.w[0]=(x.w[0]*y.a[0][0]+x.w[1]*y.a[1][0])%Mod;
    res.w[1]=(x.w[0]*y.a[0][1]+x.w[1]*y.a[1][1])%Mod;
    x.w[0]=res.w[0];
    x.w[1]=res.w[1];
    return;
}
void timesB(matrix &x,matrix y){
    matrix z;
    z.a[0][0]=(x.a[0][0]*y.a[0][0]+x.a[0][1]*y.a[1][0])%Mod;
    z.a[0][1]=(x.a[0][0]*y.a[0][1]+x.a[0][1]*y.a[1][1])%Mod;
    z.a[1][0]=(x.a[1][0]*y.a[0][0]+x.a[1][1]*y.a[1][0])%Mod;
    z.a[1][1]=(x.a[1][0]*y.a[0][1]+x.a[1][1]*y.a[1][1])%Mod;
    x=z;
    return;
}
void pushup(int i){
    tr[i].w[0]=(tr[i<<1].w[0]+tr[i<<1|1].w[0])%Mod;
    tr[i].w[1]=(tr[i<<1].w[1]+tr[i<<1|1].w[1])%Mod;
    return;
}
void build(int i,int L,int R){
    tr[i].l=L;
    tr[i].r=R;
    tr[i].tag.crt();
    tr[i].flag=false;
    if (L==R){
        tr[i].w[0]=a[L];
        tr[i].w[1]=1;
        return;
    }
    int mid=(L+R)>>1;
    build(i<<1,L,mid);
    build(i<<1|1,mid+1,R);
    pushup(i);
    return;
}
void pushdown(int i){
    if (tr[i].flag){
        tr[i].flag=false;
        tr[i<<1].flag=true;
        tr[i<<1|1].flag=true;
        timesB(tr[i<<1].tag,tr[i].tag);
        timesB(tr[i<<1|1].tag,tr[i].tag);
        timesA(tr[i<<1],tr[i].tag);
        timesA(tr[i<<1|1],tr[i].tag);
        tr[i].tag.crt();
    }
    return;
}
void add(int i,int L,int R,matrix s){
    if (L<=tr[i].l&&tr[i].r<=R){
        timesA(tr[i],s);
        timesB(tr[i].tag,s);
        tr[i].flag=true;
        return;
    }
    pushdown(i);
    if (L<=tr[i<<1].r) add(i<<1,L,R,s);
    if (R>=tr[i<<1|1].l) add(i<<1|1,L,R,s);
    pushup(i);
    return;
}
ll query(int i,int L,int R){
    if (L<=tr[i].l&&tr[i].r<=R) return tr[i].w[0];
    pushdown(i);
    ll res=0;
    if (L<=tr[i<<1].r) res=query(i<<1,L,R);
    if (R>=tr[i<<1|1].l) res+=query(i<<1|1,L,R);
    res%=Mod;
    return res;
}
int main(){
    read(n);read(k);read(Mod);
    for (int i=1;i<=n;i++) read(a[i]);
    build(1,1,n);
    while (k--){
        int opt,L,R;
        read(opt);read(L);read(R);
        if (opt==3) printf("%lld\n",query(1,L,R));
        else{
            ll x;
            read(x);
            matrix type;
            if (opt==1) type.crt1(x);else type.crt2(x);
            add(1,L,R,type);
        }
    }
    return 0;
}