题解:P10986 [蓝桥杯 2023 国 Python A] 2023

· · 题解

前情提要:这是一篇瞎搞题解,想看组合数学/二项式反演的解法还是看别的叭。

前置知识:KMP、NTT

题目要求我们计算,给定 n 位十进制数中,恰好包含 𝑚 次子串 2023 的数的个数。由于结果可能很大,我们需要对答案取模 998244353

那我们有一个很逆天自然的解法——计算数字序列中不含 2023 的方案数。

我们可以利用 KMP 来避免自己在构造不含 2023

然后进行动态规划。

定义 dp_{len,s} 表示长度为 len 且处于状态 s 的合法序列数,其中 s 表示在有限状态自动机中的状态,有转移:

dp_{len,s}=\sum_{d=0}^{9}dp_{len-1,t}

其中:

t=nxt_{s,d}

定义生成函数 F(x) 表示长度为 x 的数字序列中不包含 2023 的方案数,记为:

F(x)=\sum_{k=0}^{\infty}x^k\sum_{s=0}^{3}dp_{k,s}

定义 R 为选去 2023 后的剩余位置,易见:

R=n-4m

然后,剩余数字分布在 m+1 个区间中,每个区间内的数字序列均不能包含 2023,其方案数由生成函数 F(x) 给出,其 F(x) 表示长度为 x 的合法数字序列数。

对于 R<0 的情况直接输出 0 即可。对于 m=0 直接输出预处理后的 F(n) 即可。

接下来着重讲对于 m>0 的情况:

由于整体方案数是 m+1 个区间的数字序列组合,即等价于多项式 F^{m+1},答案为该多项式中 x^R 项的系数。由于题目仅要求我们求第 R 项的系数,在实现中可以仅保留 R 项。

考虑多项式快速幂,由于 n\leq 10^5 较大。考虑使用 NTT 优化多项式乘法。

然后多项式快速幂调用即可。能够通过此题。

代码见下:

#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
int qpow (int a, int b) {
    int res = 1;
    while(b > 0){
        if(b & 1)
            res = (int)((1LL * res * a) % mod);
        a = (int)((1LL * a * a) % mod);
        b >>= 1;
    }
    return res;
}
void ntt(vector<int> & a, bool inv) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j -= bit;
        j += bit;
        if (i < j)
            swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        int wlen = qpow(3, (mod - 1) / len);
        if (inv)
            wlen = qpow(wlen, mod - 2); 
        for (int i = 0; i < n; i += len) {
            int w = 1;
            for (int j = 0; j < len/2; j++) {
                int u = a[i+j];
                int v = (int)((1LL * a[i+j+len/2] * w) % mod);
                a[i+j] = u + v < mod ? u + v : u + v - mod;
                a[i+j+len/2] = u - v >= 0 ? u - v : u - v + mod;
                w = (int)((1LL * w * wlen) % mod);
            }
        }
    }
    if (inv) {
        int n_inv = qpow(n, mod - 2);
        for (int & x : a)
            x = (int)((1LL * x * n_inv) % mod);
    }
}
vector<int> poly_mul(const vector<int>& a, const vector<int>& b, int rLimit) {
    int n = a.size(), m = b.size();
    int sz = 1;
    while (sz < n + m - 1)
        sz *= 2;
    vector<int> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    fa.resize(sz);
    fb.resize(sz);
    ntt(fa, false);
    ntt(fb, false);
    for (int i = 0; i < sz; i++) {
        fa[i] = (int)((1LL * fa[i] * fb[i]) % mod);
    }
    ntt(fa, true);
    int newSize = min((int)(n + m - 1), rLimit + 1);
    fa.resize(newSize);
    return fa;
}
vector<int> poly_pow(vector<int> poly, int b, int rLimit) {
    vector<int> result(1, 1); 
    while(b > 0) {
        if(b & 1)
            result = poly_mul(result, poly, rLimit);
        poly = poly_mul(poly, poly, rLimit);
        b >>= 1;
    }
    return result;
}
void kmp(vector<vector<int>> &nxt) {
    string pat = "2023";
    int M = pat.size();
    vector<int> pi(M, 0);
    for (int i = 1; i < M; i++) {
        int j = pi[i-1];
        while(j > 0 && pat[i] != pat[j])
            j = pi[j-1];
        if(pat[i] == pat[j])
            j++;
        pi[i] = j;
    }
    nxt.assign(M, vector<int>(10, 0));
    for (int s = 0; s < M; s++) {
        for (int d = 0; d < 10; d++) {
            char c = '0' + d;
            int x = s;
            while(x > 0 && pat[x] != c)
                x = pi[x-1];
            if(pat[x] == c)
                x++;
            if(x == M) {
                nxt[s][d] = -1;
            } else {
                nxt[s][d] = x;
            }
        }
    }
}
vector<int> computeF(int L_max) {
    int M = 4;
    vector<vector<int>> nxt;
    kmp(nxt);
    vector<vector<int>> dp(L_max+1, vector<int>(M, 0));
    dp[0][0] = 1;
    vector<int> F(L_max+1, 0);
    F[0] = 1;
    for (int len = 0; len < L_max; len++) {
        for (int s = 0; s < M; s++) {
            if (dp[len][s] == 0) continue;
            for (int d = 0; d < 10; d++) {
                int ns = nxt[s][d];
                if (ns == -1)
                    continue;
                dp[len+1][ns] = (dp[len+1][ns] + dp[len][s]) % mod;
            }
        }
        int sum = 0;
        for (int s = 0; s < M; s++)
            sum = (sum + dp[len+1][s]) % mod;
        F[len+1] = sum;
    }
    return F;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    if(4 * m > n) {
        cout << 0 << "\n";
        return 0;
    }
    if(m == 0) {
        vector<int> F = computeF(n);
        cout << F[n] % mod << "\n";
        return 0;
    }
    int R = n - 4 * m;
    vector<int> F = computeF(R);
    vector<int> polyPow = poly_pow(F, m + 1, R);
    int ans = 0;
    ans = polyPow[R];
    cout << ans % mod << "\n";
    return 0;
}