【题解】P8362
题解
考虑如何统计答案,从
考虑如何计算组合数,不难发现组合数是关于各数之和的 $k-1$ 次多项式,因此只需在数位 DP 中维护 $0 \sim k-1$ 次幂和即可。
时间复杂度 $O(n k^3 |\Sigma|^2)$,其中 $n$ 为数位长度,$|\Sigma| = 10$ 为进制。
如果采用前缀和优化或改为记录
代码
#include <bits/stdc++.h>
// Prime modulus under which every count in this file is reduced.
constexpr int mod = 998244353;
// Fixed-width decimal big integer: a[i] holds the digit of weight 10^i.
//
// All arithmetic and comparison routines visit digit indices 0..kTop only;
// index kTop + 1 merely absorbs the final carry/borrow. Digits stored above
// kTop are therefore ignored by every operator.
struct Int
{
    // Number of ints in the digit buffer.
    static constexpr int kSize = 1010;
    // Highest digit index the operators look at.
    static constexpr int kTop = 1002;

    // Digits, least significant first. Value-initialised so that an Int with
    // automatic storage starts at zero instead of holding indeterminate data.
    int a[kSize] = {};

    // Unchecked digit access.
    int & operator [] (int x) {return a[x];}
    int operator [] (int x) const {return a[x];}

    // Lexicographic comparison from the most significant digit down.
    friend bool operator < (const Int &a, const Int &b)
    {
        for (int i = kTop; ~i; i--)
            if (a[i] != b[i]) return a[i] < b[i];
        return false;
    }

    // Adds a small int to the lowest digit and ripples carries upward.
    // Only a carry of one per digit is propagated and no borrow is performed,
    // so a negative b that exceeds the lowest digit leaves that digit negative
    // (e.g. "...0" + (-1) yields a lowest digit of -1).
    friend Int operator + (Int a, int b)
    {
        a[0] += b;
        for (int i = 0; i <= kTop; i++)
        {
            const int carry = a[i] > 9;
            a[i + 1] += carry;
            a[i] -= 10 * carry;
        }
        return a;
    }

    // Digit-wise addition with single-step carry propagation.
    friend Int operator + (Int a, const Int &b)
    {
        for (int i = 0; i <= kTop; i++)
        {
            a[i] += b[i];
            const int carry = a[i] > 9;
            a[i + 1] += carry;
            a[i] -= 10 * carry;
        }
        return a;
    }

    // Digit-wise subtraction with single-step borrow propagation.
    friend Int operator - (Int a, const Int &b)
    {
        for (int i = 0; i <= kTop; i++)
        {
            a[i] -= b[i];
            const int borrow = a[i] < 0;
            a[i + 1] -= borrow;
            a[i] += 10 * borrow;
        }
        return a;
    }

    // Multiplies by a small int: each digit is scaled, then the overflow of the
    // digit below is folded in before that lower digit is reduced modulo 10.
    friend Int operator * (Int a, int b)
    {
        a[0] *= b;
        for (int i = 1; i <= kTop; i++)
        {
            a[i] = a[i] * b + a[i - 1] / 10;
            a[i - 1] %= 10;
        }
        return a;
    }

    // Remainder modulo p, evaluated by Horner's rule from the top digit down.
    friend int operator % (const Int &a, int p)
    {
        int res = 0;
        for (int i = kTop; ~i; i--) res = (10LL * res + a[i]) % p;
        return res;
    }

    // Reads one whitespace-delimited token from std::cin and stores its
    // characters as digits (character XOR 48, i.e. '0'..'9' -> 0..9).
    //
    // The previous value is cleared first so a shorter number never inherits
    // stale high digits. A token longer than the buffer used to write past the
    // end of a[]; it is now rejected with std::length_error.
    // NOTE: tokens of kTop + 2 .. kSize characters still fit the buffer, but
    // their top digits lie above kTop and are ignored by the operators.
    void read (void)
    {
        std::string s;
        std::cin >> s;
        if (s.length() > static_cast<std::size_t>(kSize))
            throw std::length_error("Int::read: number has too many digits");
        std::fill(a, a + kSize, 0);
        const int len = static_cast<int>(s.length());
        for (int i = 0; i < len; i++) a[len - i - 1] = s[i] ^ 48;
    }
} L, R, Q[2];
// k is read from input; wl and wr are L and R reduced modulo mod.
int k;
int wl;
int wr;
// sgn[b] is +1 for b == 0 and -1 (stored as mod - 1) for b == 1.
int sgn[2];
// Modular inverses, later folded into inverse factorials by main.
int inv[60];
// P[b][e] = b^e modulo mod for bases 0..10.
int P[12][60];
// Pascal's triangle modulo mod.
int C[60][60];
// Filled by main with the recurrence S[i][j] = S[i-1][j-1] + (i-1) * S[i-1][j].
int S[60][60];
// tr[to][from][d]: weight dp applies when appending digit d and moving from index `from` to `to`.
int tr[60][60][10];
// Powers of the per-round offset computed in main.
int Pow[60];
// Digit-DP table: f[position][last digit][tight flag][index below k].
int f[1010][10][2][60];
// Per-round accumulator (g) and the totals for the upper bound (h).
int g[60];
int h[60];
// Final answer modulo mod.
int ans;
inline void dp (int *g, Int Q)
{
for (int j = 1002, p = 1; j >= 1; j--) f[j][9][p &= (Q[j] == 0)][0] = 1;
for (int j = 1001; j >= 0; j--) for (int x = 0; x <= 9; x++)
for (int p = 0; p < 2; p++) for (int t = 0; t < k; t++) if (f[j + 1][x][p][t])
for (int y = 0; y <= x and y <= (p ? Q[j] : 9); y++)
for (int r = t, q = p & (y == Q[j]); r < k; r++) f[j][y][q][r] = (f[j][y][q][r] + 1LL * tr[r][t][y] * f[j + 1][x][p][t]) % mod;
for (int x = 0; x <= 9; x++) for (int p = 0; p < 2; p++) for (int r = 0; r < k; r++) g[r] = (g[r] + f[0][x][p][r]) % mod;
for (int j = 1002; ~j; j--) for (int x = 0; x <= 9; x++) for (int p = 0; p < 2; p++) for (int r = 0; r < k; r++) f[j][x][p][r] = 0;
}
signed main ()
{
L.read(); R.read(); std::cin >> k; wl = L % mod; wr = R % mod;
inv[0] = inv[1] = sgn[0] = 1; sgn[1] = mod - 1;
for (int i = 2; i <= k; i++) inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
for (int i = 3; i <= k; i++) inv[i] = 1LL * inv[i - 1] * inv[i] % mod;
for (int i = 0; i <= 10; i++) for (int j = P[i][0] = 1; j <= k; j++) P[i][j] = 1LL * P[i][j - 1] * i % mod;
for (int i = 0; i <= k; i++) for (int j = C[i][0] = 1; j <= i; j++) C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
for (int i = S[0][0] = 1; i <= k; i++) for (int j = 1; j <= i; j++) S[i][j] = (S[i - 1][j - 1] + 1LL * (i - 1) * S[i - 1][j]) % mod;
for (int i = 0; i < k; i++) for (int j = 0; j <= i; j++) for (int k = 0; k <= 9; k++) tr[i][j][k] = 1LL * C[i][j] * P[10][j] % mod * P[k][i - j] % mod;
Q[0] = L * k + (-1); Q[1] = R * k; R = R - L + 1; dp(h, Q[1]);
for (int i = 0; i <= k and Q[0] < Q[1]; i++, Q[0] = Q[0] + R)
{
dp(g, Q[0]); for (int j = 0; j < k; j++) g[j] = (h[j] - g[j] + mod) % mod;
int w = 1LL * sgn[i & 1] * C[k][i] % mod * inv[k - 1] % mod;
Pow[0] = 1; Pow[1] = (mod - (1LL * k * (wl - 1) + 1LL * i * (wr - wl + 1) + 1) % mod) % mod;
for (int j = 2; j < k; j++) Pow[j] = 1LL * Pow[j - 1] * Pow[1] % mod;
for (int j = k - 1; j >= 0; j--) for (int r = 0; r < j; r++) g[j] = (g[j] + 1LL * C[j][r] * g[r] % mod * Pow[j - r]) % mod;
for (int j = 0; j < k; j++) ans = (ans + 1LL * w * sgn[(k - 1 - j) & 1] % mod * S[k - 1][j] % mod * g[j]) % mod, g[j] = 0;
}
return !printf("%d\n", ans);
}