题解 P10502 Matrix Power Series

· · 题解

不是很难的矩阵快速幂,老师说用等比数列公式求逆矩阵做,然而这题模数并不是质数。

题意

给定一个 n×n 矩阵 A 和一个正整数 k,求 S=\sum\limits_{i=1}^k A^i

分析

模数不为质数就不要想等比数列求和公式求逆矩阵做了,考虑一个比较朴素的实现。

直接求肯定会超时,又因矩阵乘法存在分配律,所以考虑分治。

k 为奇数时

x=\dfrac {k+1} 2 ,则可以推出:

\sum\limits_{i=1}^k A^i = \left(\sum\limits_{i=1}^{x-1} A^i\right) + A^x + A^x \sum\limits_{i=1}^{x-1} A^i

k 为偶数时

x=\dfrac k 2,则可以推出:

\sum\limits_{i=1}^k A^i = \left(\sum\limits_{i=1}^{x} A^i\right) + A^x \sum\limits_{i=1}^{x} A^i

这个可以通过分治求解,时间复杂度 T(k) = T(\dfrac k 2) + O(n^3 \log k),所以总体时间复杂度为 O(n^3 \log^2 k)

代码

//the code is from chenjh
#include<cstdio>
#include<cstring>
#include<cassert>
using namespace std;
int n,k,mod;
template<typename T>
struct Mat{
    int n,m;
    T **a;
    Mat(int _n=0,int _m=0):n(_n),m(_m){
        a=new T*[n];
        for(int i=0;i<n;i++) a[i]=new T[m],memset(a[i],0,sizeof(T)*m);
    }
    Mat(const Mat&B){
        n=B.n,m=B.m;
        a=new T*[n];
        for(int i=0;i<n;i++) a[i]=new T[m],memcpy(a[i],B.a[i],sizeof(T)*m);
    }
    ~Mat(){delete[] a;}
    Mat&operator = (const Mat&B){
        delete[] a;
        n=B.n,m=B.m;
        a=new T*[n];
        for(int i=0;i<n;i++) a[i]=new T[m],memcpy(a[i],B.a[i],sizeof(T)*m);
        return *this;
    }
    Mat operator + (const Mat&B)const{//矩阵加法。
        assert(n==B.n&&m==B.m);
        Mat ret(n,m);
        for(int i=0;i<n;i++)for(int j=0;j<m;j++) ret.a[i][j]=(a[i][j]+B.a[i][j])%mod;
        return ret;
    }
    Mat&operator += (const Mat&B){return *this=*this+B;}
    Mat operator * (const Mat&B)const{//矩阵乘法。
        Mat ret(n,B.m);
        for(int i=0;i<n;++i)for(int j=0;j<B.m;ret.a[i][j++]%=mod)for(int k=0;k<m;++k)ret.a[i][j]+=a[i][k]*B.a[k][j]%mod;
        return ret;
    }
    Mat&operator *= (const Mat&B){return *this=*this*B;}
};
Mat<int> qpow(Mat<int> A,int b){//矩阵快速幂。
    Mat<int> ret(A);
    for(--b;b;b>>=1,A*=A)if(b&1)ret*=A;
    return ret;
}
Mat<int>dfs(const Mat<int>&A,const int x){//递归求解。
    if(x==1) return A;
    Mat<int> ret(dfs(A,x>>1));
    if(x&1) ret+=qpow(A,(x+1)>>1)+qpow(A,(x+1)>>1)*ret;//分类讨论。
    else ret+=qpow(A,x>>1)*ret;
    return ret;
}
int main(){
    scanf("%d%d%d",&n,&k,&mod);
    Mat<int> A(n,n);
    for(int i=0;i<n;i++)for(int j=0;j<n;j++) scanf("%d",&A.a[i][j]),A.a[i][j]%=mod;
    Mat<int> ans(dfs(A,k));
    for(int i=0;i<n;i++,putchar('\n'))for(int j=0;j<n;j++) printf("%d ",ans.a[i][j]);
    return 0;
}