题解:P12847 [蓝桥杯 2025 国 A] 斐波那契数列

· · 题解

题目传送门

思路

看到 n=10^{18},我会矩阵加速!很遗憾,递推式中包含乘号,我们无法加速它。但它求的答案是乘积?所以也许我们能从指数上找规律。直接打表:

n G ans
1 G_1=2 ans_1=2^1·3^0
2 G_2=3 ans_2=2^1·3^1
3 G_3=2·3 ans_3=2^2·3^2
4 G_4=2·3^2 ans_4=2^3·3^4
5 G_5=2^2·3^3 ans_5=2^5·3^7

不难看出,左边的指数为斐波那契数列,满足 F_i=F_{i-1}+F_{i-2},F_1=F_2=1。右边的指数满足 F^{'}_i=F^{'}_{i-1}+F^{'}_{i-2}+1,F^{'}_0=F^{'}_1=0。于是我们就可以矩阵加速计算这两个递推式。关键是指数需要取模,怎么办?有欧拉定理:

a^{n\mod \varphi(m)}\equiv a^n(mod\ m),\gcd(a,m)=1

这样我们对递推式计算的结果取模 \varphi(998244353)=998244352 即可。时间复杂度 O(M^3\log n),其中 M 为矩阵边长,本题可以理解为 M=3

Code

#include <bits/stdc++.h>
using namespace std;
#define il inline
typedef long long ll;
const int N = 3, mod = 998244353;
int len;
struct node
{
    ll a[N][N];
    friend node operator * (const node &n1, const node &n2)
    {
        node n3 = (node){{{0}}};
        for(int i = 0;i < len;++i)
            for(int j = 0;j < len;++j)
                for(int k = 0;k < len;++k)
                    n3.a[i][j] = (n3.a[i][j] + n1.a[i][k] * n2.a[k][j] % (mod - 1)) % (mod - 1);
        return n3;
    }
};
il node matpow(node A, node T, ll n)
{
    while(n)
    {
        if(n & 1) A = A * T;
        T = T * T;
        n >>= 1;
    }
    return A;
}
il ll f1(ll n)
{
    if(n <= 0) return 0;
    if(n <= 2) return 1;
    node T = (node){{
        {1, 1},
        {1, 0}
    }};
    node A = (node){{
        {1, 0},
        {0, 1}
    }};
    len = 2;
    node ans = matpow(A, T, n - 2);
    return (ans.a[0][0] + ans.a[0][1]) % (mod - 1);
}
il ll f2(ll n)
{
    if(n == 1) return 0;
    node T = (node){{
        {1, 1, 1},
        {1, 0, 0},
        {0, 0, 1}
    }};
    node A = (node){{
        {0, 0, 0},
        {0, 0, 0},
        {1, 0, 0}
    }};
    node B = (node){{
        {1, 0, 0},
        {0, 1, 0},
        {0, 0, 1}
    }};
    len = 3;
    node ans = matpow(B, T, n - 1) * A;
    return ans.a[0][0];
}
il ll fastpow(ll a, ll b)
{
    ll ans = 1;
    for(;b > 0;b >>= 1, a = a * a % mod)
        if(b & 1) ans = ans * a % mod;
    return ans;
}
int main()
{
    ll n;cin >> n;
//  cout << f1(n) << " " << f2(n) << "\n";
    if(n == 1) cout << 2;
    else cout << fastpow(2, f1(n)) * fastpow(3, f2(n)) % mod;
    return 0;
}