用矩阵表示变化&&线段树维护矩阵

· · 题解

引入问题

洛谷P3373 (只有后面三种操作)

矩阵入门

可以看看我写的这篇博客

正文

线段树维护的是带有结合律的东西

假如维护的东西不存在结合律

那么每次一个新的标记就会把旧的标记挤下去

每一个标记都要遍历整棵树,复杂度变成 O(nm)

线段树维护矩阵

众所周知,矩阵乘法是有结合率

即 :

A(B C) = (AB)C

那么我们可以在线段树的每一个节点维护一个矩阵来表示一些信息

那么我们的每一次操作都想方设法把它变成"乘上一个矩阵"

那么我们就可以线段树维护区间矩阵乘法,来表示一些结合律不是那么明显的信息

注意 : 由于矩阵乘法不具有交换律,所以我们每次只能左乘,或只能右乘一个转移矩阵

左乘右乘 分别看个人实现的是一个列矩阵,还是一个行矩阵

矩阵模拟变化

每一个结点用矩阵可以表示成这样

\begin{bmatrix}v\\s\end{bmatrix} *** 按题目中的意思模拟 把原矩阵**左乘**一个矩阵来得到我们要的新矩阵 #### 区间加 $x \begin{bmatrix}1&x\\0&1\end{bmatrix}\times\begin{bmatrix}v\\s\end{bmatrix}=\begin{bmatrix}v + s \times x\\s\end{bmatrix}

区间乘 x

\begin{bmatrix}x&0\\0&1\end{bmatrix}\times\begin{bmatrix}v\\s\end{bmatrix}=\begin{bmatrix}v \times x\\s\end{bmatrix}

区间赋值为 x

\begin{bmatrix}0&x\\0&1\end{bmatrix}\times\begin{bmatrix}v\\s\end{bmatrix}=\begin{bmatrix}s \times x\\s\end{bmatrix}

区间求和

即按普通的线段树那样

把每一个结点的矩阵加在一起即可

对了,最好写一个矩阵类,弄个矩阵乘法加法之类的

由于矩阵运算有一个 8 的小常数,我吸氧后才勉强通过

\color {DeepSkyBlue} Code

#include <bits/stdc++.h>
#define N 100005

using namespace std;

int n, m, mod;

struct matrix {
    int r, c;
    long long a[2][2];
    inline matrix(int _r = 0, int _c = 0) {
        r = _r, c = _c;
        memset(a, 0, sizeof a);
    }
    inline bool operator != (const matrix &rhs) const {
        if(r != rhs.r || c != rhs.c) return true;
        for(int i = 0; i < r; ++i)
          for(int j = 0; j < c; ++j)
            if(a[i][j] != rhs.a[i][j])
              return true;
        return false;
    }
    /*
    for debug
    inline void print() {
        for(int i = 0; i < r; ++i) {
          for(int j = 0; j < c; ++j)
            printf("%lld ", a[i][j]);
            putchar('\n');
        }
    }
    */
}e(2, 2), delta(2, 2);
inline matrix mul(matrix A, matrix B) {
    matrix C(A.r, B.c);
    for(int i = 0; i < A.r; ++i)
      for(int j = 0; j < B.c; ++j)
        for(int k = 0; k < A.c; ++k)
          C.a[i][j] += A.a[i][k] * B.a[k][j], C.a[i][j] %= mod;
    return C;
}
inline matrix add(matrix A, matrix B) {
    matrix C(A.r, A.c);
    for(int i = 0; i < A.r; ++i)
      for(int j = 0; j < A.c; ++j)
        C.a[i][j] = A.a[i][j] + B.a[i][j], C.a[i][j] %= mod;
    return C;
}

template <typename T> inline void read(T &x) {
    x = 0; char ch = getchar();
    while(!isdigit(ch)) ch = getchar();
    while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
}

struct SMT {
#define ls(x) (x << 1)
#define rs(x) (x << 1 | 1)
    matrix sum[N << 2], tag[N << 2];
    inline void update(int u) { sum[u] = add(sum[ls(u)], sum[rs(u)]); }
    inline void pushdown(int u) {
        if(tag[u] != e) {
            int l = ls(u), r = rs(u);
            sum[l] = mul(tag[u], sum[l]), tag[l] = mul(tag[u], tag[l]);
            sum[r] = mul(tag[u], sum[r]), tag[r] = mul(tag[u], tag[r]);
            tag[u] = e;
        }
    }
    inline void build(int l, int r, int u) {
        tag[u] = e;
        sum[u].r = 2, sum[u].c = 1;
        if(l == r) {
            read(sum[u].a[0][0]);
            sum[u].a[1][0] = 1;
            return;
        }
        int mid = l + r >> 1;
        build(l, mid, ls(u)), build(mid + 1, r, rs(u));
        update(u);
        return;
    }
    inline void modify(int ll, int rr, int l, int r, int u) {
        if(ll <= l && r <= rr) {
            sum[u] = mul(delta, sum[u]);
            tag[u] = mul(delta, tag[u]);
            return;
        }
        pushdown(u);
        int mid = l + r >> 1;
        if(ll <= mid) modify(ll, rr, l, mid, ls(u));
        if(mid < rr) modify(ll, rr, mid + 1, r, rs(u));
        update(u);
        return;
    }
    inline matrix query(int ll, int rr, int l, int r, int u) {
        if(ll <= l && r <= rr) return sum[u];
        pushdown(u);
        int mid = l + r >> 1;
        if(rr <= mid) return query(ll, rr, l, mid, ls(u));
        else if(mid < ll) return query(ll, rr, mid + 1, r, rs(u));
        else return add(query(ll, rr, l, mid, ls(u)), query(ll, rr, mid + 1, r, rs(u)));
    }
#undef ls
#undef rs
}Lsk;

int main() {
    e.a[0][0] = e.a[1][1] = 1;
    read(n), read(m), read(mod);
    Lsk.build(1, n, 1);

    int opt, l, r;
    while(m--) {
        read(opt), read(l), read(r);
        if(opt == 3) {
            printf("%lld\n", Lsk.query(l, r, 1, n, 1).a[0][0]);
        } else {
            delta = e;
            if(opt == 1) {
                read(delta.a[0][0]);
                /*
                    mul
                    |x  0| |v|   |x * v|
                    |    |*| | = |     |
                    |0  1| |1|   |  1  |
                */
            } else if(opt == 2) {
                read(delta.a[0][1]);
                /*
                    add
                    |1  x| |v|   |v + x|
                    |    |*| | = |     |
                    |0  1| |1|   |  1  |
                */
            }
            Lsk.modify(l, r, 1, n, 1);
        }
    }
    return 0;
}