P3390 题解

· · 题解

传送门:P3390 【模板】矩阵快速幂

更佳的阅读体验:洛谷 P3390 题解

在阅读本文之前,你需要了解矩阵乘法与快速幂。

算法介绍

给你一个 n \times n 的矩阵 A,现在你需要计算 A^k(即 \underbrace{A \times A \times A \times \cdots \times A}_{k \text{ 个 } A})。

k = 2 时,你知道可以直接将两个矩阵 A 相乘得到答案。

k \le 10 时,你知道可以通过手写循环语句来计算最后的答案。

那么,当 k \le 10^{12} 呢?显然朴素的循环已经无法满足我们对程序的效率的要求,我们需要一个更快的算法。

陷入迷茫时,不妨回顾矩阵乘法的基本性质:

此时想到,是否可以使用快速幂来优化算法?答案是肯定的。

与朴素的快速幂相同,对于矩阵快速幂,我们有:

A^k = \begin{cases} (A^{\tfrac{k}{2}})^2, & \text{if } k \text{ is even} \\ (A^{\tfrac{k - 1}{2}})^2 \times A, & \text{if } k \text{ is odd} \end{cases}

此时,问题也就迎刃而解了。

值得注意的是,在进行幂运算之前,我们需要将结果矩阵初始化为单位矩阵,即:

I = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 1 \end{bmatrix}

这样才能保证矩阵快速幂的正确性。

复杂度分析

对于任意两个 n \times n 的矩阵,相乘的时间复杂度为 \Theta(n^3)。快速幂的复杂度为 \Theta(\log k)

因此,矩阵快速幂的时间复杂度为 \Theta(n^3 \log k),可以通过本题。

代码实现

为了提高代码的可读性,我们将矩阵封装到结构体内,并且在结构体内重载乘运算符。

#include <iostream>
using namespace std;
using ll = long long;

const int N = 110, MOD = 1e9 + 7;
int n;
ll k;
struct mat {
    ll m[N][N], h, w;
    void clear() {
        h = w = n;
        for (int i = 1; i <= h; ++i)
            for (int j = 1; j <= w; ++j) m[i][j] = 0;
    } void reset() {
        clear();
        for (int i = 1; i <= n; ++i) m[i][i] = 1;
    } mat operator *(const mat &x) const {
        mat res;
        res.clear();
        for (int i = 1; i <= h; ++i)
            for (int j = 1; j <= x.w; ++j)
                for (int k = 1; k <= w; ++k)
                    res.m[i][j] = (res.m[i][j] + m[i][k] * x.m[k][j]) % MOD;
        return res;
    }
} a;

mat expow(mat a, ll b) {
    mat res;
    res.reset(), res.h = res.w = n;
    while (b) {
        if (b & 1) res = res * a;
        a = a * a, b >>= 1;
    } return res;
}

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    cin >> n >> k, a.h = a.w = n;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j) cin >> a.m[i][j];
    a = expow(a, k);
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j) cout << a.m[i][j] << " \n"[j == n];
    return 0;
}