[SNOI2022] 数位 题解

· · 题解

看到还没有题解,赶紧偷跑一篇。

yysy,看到出题人在知乎吐槽考场上没有什么人尝试这道题,就来看了一眼。但是我做完之后,也确实感觉,如果我在考场上也不会去尝试写这道题,因为真的太难写了!!!!!!

但是总体来说作为一道省选的压轴题,代码难一点也无可厚非。

思路上来说,感觉没人特别难想的地方,作为省队水平选手大概 1020 分钟应该就能想到大致思路。

解题思路

首先我们先看看如果确定了 \sum{a_i}=S 是什么,怎么计算答案。

众所周知,如果把限制改成 a_i \in [L,+\infty) 的做法是直接使用隔板法,答案为 {S-k(L-1)-1}\choose{k-1}

但是现在限制是 a_i \in [L,R],直接做显然是不好做的。那我们考虑容斥。枚举有多少 a_i 突破了 R 的限制,此时答案为:

\sum_{i=0}^{k}{(-1)^i{{k}\choose{i}}{{S-k(l-1)-i(R-L+1)-1}\choose{k-1}}}

既然有了这种做法,我们的思路就可以转变为枚举 i,用数位 dp 计算 {{S-k(l-1)-i(R-L+1)-1}\choose{k-1}} 的和。

这个东西显然也是不好算的,但是如果我们将 S-k(l-1)-i(R-L+1)-1 设为 x。我们知道的是 x \choose k-1 可以写成 \sum_{i=0}^{k-1}{a_ix^i} 的形式,那么我们只需要计算出所有 x 的所有 k 次幂和即可求得答案。考虑到 k 次幂和之间的递推可以直接使用二项式定理,所以我们的数位 dp 可以直接记录,考虑到第几位,是否需要借位,上一个填的数是什么,当前维护的是几次幂和。这样的时间复杂度大概是 O(10k^3|R|),感觉很卡,但是实测是可过的。(其实是我本地没跑过感觉非常慌,但是交上去就过了)

下面是代码。QwQ

#include <stdio.h>
#include <bits/stdc++.h>

using namespace std;

const int N = 1110, M = 1010;
const int P = 998244353;
char s[N];

inline int add(int x, int y) {return x + y >= P ? x + y - P : x + y;}
inline int sub(int x, int y) {return x < y ? x - y + P : x - y;}
inline int mul(int x, int y) {return 1ll * x * y % P;}
inline void upd(int &x, int y) {x = add(x, y);}
inline int Pow(int x, int y) {int r = 1; for (; y; y >>= 1, x = mul(x, x)) if (y & 1) r = mul(r, x); return r;}

struct BigInt {
    int a[N], len;
    BigInt() {
        len = 0;
        memset(a, 0, sizeof a);
    }
    void input() {
        scanf("%s", &s);
        len = strlen(s);
        for (int i = 0; i < len; i++)
            a[i] = s[len - i - 1] - '0';
    }
};

BigInt operator + (BigInt x, BigInt y) {
    BigInt ret; ret.len = max(x.len, y.len);
    for (int i = 0; i < ret.len; i++)
        ret.a[i] = x.a[i] + y.a[i];
    for (int i = 0; i < ret.len; i++)
        if (ret.a[i] > 9) ret.a[i] -= 10, ret.a[i + 1]++;
    while (ret.a[ret.len]) ret.len++;
    return ret;
}

BigInt operator - (BigInt x, BigInt y) {
    BigInt ret; ret.len = x.len;
    for (int i = 0; i < ret.len; i++)
        ret.a[i] = x.a[i] - y.a[i];
    for (int i = 0; i < ret.len; i++)
        if (ret.a[i] < 0) ret.a[i] += 10, ret.a[i + 1]--;
    while (ret.len && !ret.a[ret.len - 1]) ret.len--;
    return ret;
}

BigInt L, R, AL;
int K, C[52][52], ans, dp[N][10][2][52], sdp[N][10][2][52], pw[10][52], pw10[N * 52], coe[52];

int main() {
    L.input(), R.input(), scanf("%d", &K);
    L.a[0]--;
    for (int i = 0; ; i++) {
        if (L.a[i] < 0) L.a[i] += 10, L.a[i + 1]--;
        else break;
    }
    while (L.len && !L.a[L.len - 1]) L.len--;
    R = R - L; AL.len = 1, AL.a[0] = 1;
    for (int i = 1; i <= K; i++) AL = AL + L;
    for (int i = 0; i <= K; i++) {
        C[i][0] = C[i][i] = 1;
        for (int j = 1; j < i; j++)
            C[i][j] = add(C[i - 1][j - 1], C[i - 1][j]);
    }
    for (int i = 0; i < 10; i++) {
        pw[i][0] = 1;
        for (int j = 1; j <= K; j++)
            pw[i][j] = mul(pw[i][j - 1], i);
    }
    pw10[0] = 1;
    for (int i = 1; i < N * 52; i++)
        pw10[i] = mul(pw10[i - 1], 10);
    coe[1] = 1;
    for (int i = 1; i < K - 1; i++) {
        for (int j = K; j > 0; j--)
            coe[j] = coe[j - 1];
        coe[0] = 0;
        for (int j = 0; j < K; j++)
            coe[j] = sub(coe[j], mul(i, coe[j + 1]));
    }
    if (K == 1) coe[1] = 0, coe[0] = 1;
    for (int T = 0; T <= K; T++) {
        for (int i = 0; i < K; i++)
            for (int j = 0; j < 10; j++)
                for (int o = 0; o < 2; o++)
                    dp[0][j][o][i] = 0;
        for (int i = 0; i < K; i++)
            for (int j = 0; j < 10; j++) {
                if (AL.a[0] > j) dp[0][j][1][i] = pw[j - AL.a[0] + 10][i];
                else dp[0][j][0][i] = pw[j - AL.a[0]][i];
            }
        int temp = 0;
        for (int i = 0; i < M; i++) {
            if (i + 1 >= AL.len)
                for (int k = 1; k < 10; k++)
                    for (int o = 0; o < K; o++)
                        upd(temp, mul(coe[o], dp[i][k][0][o]));
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 10; k++)
                    for (int o = 0; o < K; o++) {
                        sdp[i][k][j][o] = dp[i][k][j][o];
                        if (k > 0) sdp[i][k][j][o] = add(sdp[i][k][j][o], sdp[i][k - 1][j][o]);
                    }
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 10; k++)
                    for (int o = 0; o < K; o++)
                        dp[i + 1][k][j][o] = 0;
            int num = AL.a[i + 1];
            for (int j = 0; j < 2; j++)
                for (int o = 0; o < K; o++)
                    for (int p = 0; p < 10; p++) {
                        if (!sdp[i][p][j][o]) continue;
                        int tr = sdp[i][p][j][o];
                        for (int m = o; m < K; m++) {
                            if (p - j < num) {
                                int tmp = p - j - num + 10;
                                upd(dp[i + 1][p][1][m], mul(tr, mul(C[m][o], mul(pw[tmp][m - o], pw10[(i + 1) * (m - o)]))));
                            }
                            else {
                                int tmp = p - j - num;
                                upd(dp[i + 1][p][0][m], mul(tr, mul(C[m][o], mul(pw[tmp][m - o], pw10[(i + 1) * (m - o)]))));
                            }
                        }
                }
        }
        for (int k = 1; k < 10; k++)
            for (int o = 0; o < K; o++)
                upd(temp, mul(coe[o], dp[M][k][0][o]));
        for (int i = 2; i < K; i++)
            temp = mul(temp, Pow(i, P - 2));
        temp = mul(temp, C[K][T]);
        if (T & 1) ans = sub(ans, temp);
        else ans = add(ans, temp);
        AL = AL + R;
    }
    printf("%d\n", ans);
    return 0;
}