矩阵加速学习笔记

· · 题解

矩阵简介及矩阵加速

简介

在数学中,矩阵是一个按照长方阵列排列的复数或实数集合,最早来自于方程组的系数及常数所构成的方阵——百度百科

通俗的来讲,把集合里的一些数填入到一个矩形中即得到一个矩阵

定义

m\times n个数a_{i,j}排成的数表称为mn列的矩阵简称m\times n矩阵。

A=\begin{bmatrix}a_{1,1}&a_{1,2}&...&a_{1,n}\\a_{2,1}&a_{2,2}&...&a_{2,n}\\...&...&...&...\\a_{m,1}&a_{m,2}&...&a_{m,n}\end{bmatrix}

n\times m个数称为矩阵A的元素,简称为元。。。(剩下的都是百度百科的废话

mn列的矩阵也记作A_{mn}

特别的,两个n,m都相同的矩阵称为同型矩阵

#### 基本运算 ##### 加法 $$\begin{bmatrix}a_{1,1}&...&a_{1,n}\\...&...&...\\a_{m,1}&...&a_{m,n}\end{bmatrix}+\begin{bmatrix}b_{1,1}&...&b_{1,n}\\...&...&...\\b_{m,1}&...&b_{m,n}\end{bmatrix}=\begin{bmatrix}a_{1,1}+b_{1,1}&...&a_{1,n}+b_{1,n}\\...&...&...\\a_{m,1}+b_{m,1}&...&a_{m,n}+b_{m,n}\end{bmatrix}$$ 同时,矩阵的加法和实数的加法一样,满足交换律和结合律 即 $$A+B=B+A$$ $$(A+B)+C=A+(B+C)$$ 应当注意,只有同型矩阵之间才可以进行加法 ##### 减法 在数集中,减法作为加法的逆运算 在矩阵中也是一样的 $$\begin{bmatrix}a_{1,1}&...&a_{1,n}\\...&...&...\\a_{m,1}&...&a_{m,n}\end{bmatrix}-\begin{bmatrix}b_{1,1}&...&b_{1,n}\\...&...&...\\b_{m,1}&...&b_{m,n}\end{bmatrix}=\begin{bmatrix}a_{1,1}-b_{1,1}&...&a_{1,n}-b_{1,n}\\...&...&...\\a_{m,1}-b_{m,1}&...&a_{m,n}-b_{m,n}\end{bmatrix}$$ ##### 数乘 在矩阵中引入了数乘的概念,即为一个数乘一个矩阵 $$\mu·\begin{bmatrix}a_{1,1}&...&a_{1,n}\\...&...&...\\a_{m,1}&...&a_{m,n}\end{bmatrix}=\begin{bmatrix}\mu· a_{1,1}&...&\mu· a_{1,n}\\...&...&...\\\mu· a_{m,1}&...&\mu· a_{m,n}\end{bmatrix}$$ 矩阵的数乘满足以下规律: $$\lambda(\mu A)=\mu(\lambda A)$$ $$\lambda(\mu A)=(\lambda\mu A)$$ $$(\lambda +\mu)A=\lambda A+\mu A$$ $$\lambda(A+B)=\lambda A+\lambda B$$ ##### 乘法 两个矩阵的乘法仅当第一个矩阵$A$的列数和另一个矩阵$B$的行数相等时才能定义。 记作 $$A_{mn} B_{np}=C_{mp}$$ 也作 $$AB=C$$ $C$的一个元素$c_{i,j}$的值为 $$c_{i,j}=\sum_{k=1}^na_{i,k}\times b_{k,j}$$ 例如 $$\begin{bmatrix}1 &0 &2\\-1&3&1\end{bmatrix}\times\begin{bmatrix}3&1\\2&1\\1&0\end{bmatrix}=\begin{bmatrix}(1\times 3+0\times2+2\times1)&(1\times1+0\times1+2\times0)\\(-1\times3+3\times2+1\times1)&(-1\times1+3\times1+1\times0)\end{bmatrix}=\begin{bmatrix}5&1\\4&2\end{bmatrix}$$ 同时,矩阵的乘法满足以下运算律: 结合律: $$(AB)C=A(BC)$$ 左分配律: $$(A+B)C=AC+BC$$ 右分配律: $$C(A+B)=CA+CB$$ #### 矩阵加速 ##### 快速幂 这个还需要多说么? ``` ksm int base=a,ans=1; while(b>0){ if(b&1){ ans*=base ans%=c; } base*=base base%=c; b=b>>1; } return ans; ``` 就是很简单的按位运算 所以我们对于一个矩阵的幂也可以这么来运算,于是时间复杂度从$O(n)$降到了$O(log_2n)

这就很舒服

但是这个有什么用呢?

例题一

a[1]=a[2]=a[3]=1

a[x]=a[x-3]+a[x-1] (x>3)

求a数列的第n项对1000000007(10^9+7)取余的值。

第一眼看到这个题

?

这不是来送分的么

我一项一项去推不就好了???

对于100%的数据 T<=100,n<=2*10^9;

笑容逐渐凝固

所以现在就体现出矩阵加速的用处了

对于一个数列f(n)=f(n-x)+f(n-y)+...

我们可以构造出一个k\times1的矩阵表示当前的状态,即用f(x),f(y)...来表示我们当前已经知道的

然后我们再构造一个转移矩阵,使我们的原始矩阵乘上这个转移矩阵后可以变成一个新的矩阵,得到全新的f(x),从而一步步的向前推进,得到答案

显而易见,由于我们需要用到f(n-1)f(n-3),所以我们需要保留住这两个状态,f(n-3)是由上个状态的f(n-2)继承过来的,所以我们也需要保留,所以原始矩阵要构造成这个样子

\begin{bmatrix}f(3)\\f(2)\\f(1)\end{bmatrix}

即为

\begin{bmatrix}1\\1\\1\end{bmatrix}

然后思考我们怎么由这个矩阵得到下一个矩阵

设我们当前矩阵为

\begin{bmatrix}f(n-1)\\f(n-2)\\f(n-3)\end{bmatrix}

我们希望得到的矩阵为

\begin{bmatrix}f(n)\\f(n-1)\\f(n-2)\end{bmatrix}

即为

\begin{bmatrix}f(n-1)\times1+f(n-2)\times 0+f(n-3)\times1\\f(n-1)\times1+f(n-2)\times0+f(n-3)\times0\\f(n-1)\times0+f(n-2)\times1+f(n-3)\times0\end{bmatrix}

显而易见我们的转移矩阵为

\begin{bmatrix}1&0&1\\1&0&0\\0&1&0\end{bmatrix}

由于我们原始矩阵的第一项为f(3),所以我们只需要乘n-3次这个转移矩阵就可以得到了!

但是问题来了,这样不还是一个O(n)的算法么?而且还慢了好多

当然不是!想一下我们矩阵乘法的结合律

我们设原始矩阵为A,转移矩阵为B,那么我们最终答案的矩阵就是

\underbrace{B(B(B(B(B(B(B...(B}_{n-3}A))))))

根据矩阵乘法的结合律,这个式子可以简化为

B^{n-3}A

到了这里,我们就可以用矩阵快速幂来求出最终答案了!

贴上我的巨丑无比的代码: ```c++ #include<cstdio> #include<cstring> #include<iostream> #define mod 1000000007 using namespace std; struct Ju{ long long p[5][5]; Ju operator *(const Ju &a)const{ Ju c; for(long long i=1;i<=3;i++){ for(long long j=1;j<=3;j++){ c.p[i][j]=0; } } for(long long i=1;i<=3;i++){ for(long long j=1;j<=3;j++){ for(long long k=1;k<=3;k++){ c.p[i][j]+=(a.p[i][k]*p[k][j])%mod; c.p[i][j]%=mod; } } } return c; } }ans,base; void build(){ for(long long i=1;i<=3;i++){ for(long long j=1;j<=3;j++){ base.p[i][j]=0; } } base.p[1][1]=base.p[1][3]=base.p[2][1]=base.p[3][2]=1; ans.p[1][1]=ans.p[2][1]=ans.p[3][1]=1; } void qsm(long long p){ while(p!=0){ if(p&1){ ans=ans*base; } base=base*base; p>>=1; } } long long T,n; int main(){ cin>>T; for(long long i=1;i<=T;i++){ cin>>n; if(n<=3){ printf("1\n"); continue; } build(); qsm(n-3); cout<<ans.p[1][1]<<endl; } return 0; } ``` 同样的题还有[这一道](https://www.luogu.org/problem/P1962) 写作不易,点个赞再走呗QAQ