洛谷 P10167 [DTCPC 2024] 小方学乘法 题解
同步发表于 个人博客。
想法比较直接的优化 dp + 插值。
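首先由于每个 `?` 只有两种选择——替换为 $x$ 的十进制表示，或者替换为乘号——所以对固定的 $x$ 可以线性 dp：把字符串按 `?` 切成若干段，设 $f_i$ 为前 $i$ 段（连同其间的 `?`）所有方案的乘积之和；转移时维护“最后一个因子还能继续往后接”的方案的权值和，以及“权值 × 该因子当前数值”之和，即可 $O(1)$ 地把 $x$ 和下一段接到该因子后面，或在此处断开相乘。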
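然后注意到当 $x$ 的位数 $d$ 固定时，$10^{d}$ 是常数，每个因子都是关于 $x$ 的一次式，所以对固定 $x$ 算出的答案 $g(x)$ 是关于 $x$ 的多项式，次数不超过 `?` 的个数 $k$。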
我们需要对一个固定的位数 $d$，求出 $\sum g(x)$，其中 $x$ 取遍 $[l, r]$ 中恰好为 $d$ 位的整数，$g(x)$ 为把每个 `?` 替换为 $x$ 或乘号后所有方案的乘积之和。
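注意到 $g$ 是次数不超过 $k$（`?` 的个数）的多项式，所以它在连续整数上的前缀和 $S(t)=\sum_{x \le t} g(x)$ 是次数不超过 $k+1$ 的多项式：暴力算出连续 $k+2$ 个点值（代码里多取了几个）后，用拉格朗日插值即可求出区间右端点处的 $S$。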
对每一个位数取出 $[l, r]$ 中该位数对应的子区间：区间长度不超过插值点数时直接暴力求和，否则按上面的方法插值，最后把各位数的结果相加。复杂度为 $O(k^2 \log_{10} r)$。
#include <bits/stdc++.h>
using namespace std;
// Buffered replacement for getchar(): refills buf from stdin with fread in
// 2 MiB chunks and yields EOF once fread returns nothing. Note that this
// function-like macro shadows the standard getchar() for the rest of the file.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; a '-' among them makes
// the result negative. Returns 0 if input ends before any digit appears.
// The character is kept in an int, so EOF is distinguishable from data.
inline int qread() {
    int c = getchar();
    int x = 0, f = 1;
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;  // previously looped forever at end of input
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Prime modulus applied to every answer and table entry.
const long long mod = 1000000007;
// N bounds every per-character / per-segment array below; L is the largest
// exponent kept in the exact power table p10r (10^18 still fits a long long).
const int N = 2077;
const int L = 18;

int n;                                        // length of the pattern
char str[N];                                  // the pattern, 1-indexed: str[1..n]
std::vector<std::pair<long long, int> > val;  // '?'-separated segments: (value mod p, digit count)
long long f[N];                               // dp table written by getAns
long long pw10[N];                            // 10^i mod p
long long p10r[N];                            // 10^i, exact
long long sl, sr;                             // query range [sl, sr]
long long x[N];                               // interpolation nodes
long long y[N];                               // prefix-sum samples at the nodes
long long inv[N];                             // modular inverses
// Splits the pattern str[1..n] at every '?' and fills the lookup tables.
//   val     : one (value mod p, digit count) pair per '?'-separated segment,
//             in order; there is always one more segment than there are '?'s.
//   pw10[e] : 10^e mod p  for 0 <= e <= n + L
//   p10r[e] : 10^e exactly for 0 <= e <= L
//   inv[k]  : k^{-1} mod p for 1 <= k <= n + 10 (linear-time recurrence)
inline void Prefix() {
    long long segValue = 0;
    long long segDigits = 0;
    for (int pos = 1; pos <= n; ++pos) {
        const char ch = str[pos];
        if (ch != '?') {
            segValue = (segValue * 10 + ch - '0') % mod;
            ++segDigits;
            continue;
        }
        // A '?' closes the current segment and starts a fresh, empty one.
        val.push_back(std::make_pair(segValue, segDigits));
        segValue = 0;
        segDigits = 0;
    }
    val.push_back(std::make_pair(segValue, segDigits));  // trailing segment

    // n + L >= L, so a single pass fills both power tables.
    pw10[0] = 1;
    p10r[0] = 1;
    for (int e = 1; e <= n + L; ++e) {
        pw10[e] = pw10[e - 1] * 10 % mod;
        if (e <= L) p10r[e] = p10r[e - 1] * 10;
    }

    // inv[k] = -(p / k) * inv[p mod k]  (mod p)
    inv[1] = 1;
    for (int k = 2; k <= n + 10; ++k) inv[k] = (mod - mod / k) * inv[mod % k] % mod;
}
// Returns, modulo mod, the sum over every way of replacing each '?' of the
// pattern by either the decimal digits of x or a multiplication sign, of the
// value of the resulting product (e.g. for "1?2" and x = 3: 132 + 1*2).
// NOTE(review): semantics read off the recurrence below and checked by hand
// on small patterns; confirm against the problem statement.
//   x   - the replacement value; only x mod p is used.
//   xln - the number of decimal digits of the unreduced x. It cannot be
//         recovered from x mod p, so the caller must supply it.
// Runs in O(val.size()) and overwrites the global f[0 .. val.size() - 1].
// An empty segment (adjacent '?'s, or a '?' at either end) has value 0 and
// length 0, so a factor consisting only of it counts as 0 unless later
// digits are appended to it.
inline long long getAns(long long x, long long xln) {
x %= mod;
// fsm : sum of the weights of the choices whose last factor is still "open"
//       (may still be extended when a later '?' becomes x). A factor that
//       starts at segment j has weight 1 for j = 0, otherwise f[j - 1].
// fpsm: sum of weight * (current value of that open factor).
long long fsm = 0, fpsm = 0;
f[0] = val[0].first;
fsm = 1;
fpsm = val[0].first;
for (int i = 1;i < val.size();i++) {
// Case "the i-th '?' becomes x": append x's digits, then segment i's
// digits, to every open factor.
fpsm = fpsm * pw10[xln] % mod;
fpsm = (fpsm + x * fsm) % mod;
fpsm = fpsm * pw10[val[i].second] % mod;
fpsm = (fpsm + val[i].first * fsm) % mod;
// Total for the prefix ending at segment i: the concatenation case above,
// plus the case where the '?' was a '*' and segment i alone is the last factor.
f[i] = (fpsm + f[i - 1] * val[i].first) % mod;
// Register that new factor (segment i, weight f[i - 1]) as open.
fsm = (fsm + f[i - 1]) % mod;
fpsm = (fpsm + f[i - 1] * val[i].first) % mod;
}
// Total for the whole pattern.
return f[val.size() - 1];
}
// Modular inverse of a small nonzero integer via the inv[] table built by
// Prefix. A negative argument yields mod - inv[|x|], i.e. -(1/|x|) mod p.
// |x| must lie within the filled part of inv[] (1 .. n + 10).
inline long long Inv(long long x) {
    return x >= 0 ? inv[x] : mod - inv[-x];
}
inline void Solve() {
long long ans = 0;
for (int i = 1;i <= L;i++) {
long long vl = p10r[i - 1], vr = p10r[i] - 1;
vl = max(vl, sl); vr = min(vr, sr);
if (vl > vr) continue;
long long cnt = val.size() + 5;
if (vr - vl + 1 <= cnt) {
for (long long j = vl;j <= vr;j++) ans = (ans + getAns(j % mod, i)) % mod;
} else {
y[0] = 0;
for (long long j = vl;j <= vl + cnt - 1;j++) {
x[j - vl + 1] = j;
y[j - vl + 1] = getAns(j % mod, i);
}
for (int j = 1;j <= cnt;j++) y[j] = (y[j] + y[j - 1]) % mod;
for (int j = 1;j <= cnt;j++) {
long long cur = y[j];
for (int k = 1;k <= cnt;k++) {
if (k == j) continue;
cur = cur * (vr % mod - x[k] % mod) % mod * Inv(x[j] - x[k]) % mod;
}
ans = (ans + cur) % mod;
}
}
}
cout << (ans % mod + mod) % mod << endl;
}
// Reads the pattern (stored 1-indexed in str) and the range [sl, sr], builds
// the lookup tables and prints the answer.
int main() {
    // Read through std::string: `cin >> (char*)` is an unbounded write into
    // the fixed buffer and that overload no longer exists in C++20.
    string pattern;
    cin >> pattern;
    // str needs room for the unused index 0, the pattern and a trailing '\0'.
    // NOTE(review): Prefix also writes pw10[n + L] and inv[n + 10]; the
    // problem's length limit is assumed to keep those within the arrays too.
    if (pattern.size() + 2 > sizeof(str)) {
        cerr << "pattern too long\n";
        return 1;
    }
    n = (int)pattern.size();
    memcpy(str + 1, pattern.c_str(), pattern.size() + 1);  // includes '\0'
    cin >> sl >> sr;
    Prefix();
    Solve();
    return 0;
}