洛谷 P10167 [DTCPC 2024] 小方学乘法 题解

· · 题解

同步发表于 个人博客。

想法比较直接的优化 dp + 插值。

首先由于 x 的插入会导致原来的数的错位,因此一个显然的想法是对于位数不同的 x 分别处理。

然后注意到当 x 的位数固定时,对于任意一种 x 和乘号的填法,得到的乘法算式中的每一个数的值都是一个常数或者关于 x 的一次函数。于是将它们乘起来就会得到一个 xO(n) 次多项式。对所有填入乘号和 x 的方法求和并不改变多项式的次数。至此我们可以考虑对若干个 x 单独求值并通过多项式插值快速求和。

我们需要对一个固定的 x(记这个 x 的位数为 l)在 O(n) 时间内计算答案。考虑一个 dp,设 f_i 为考虑到第 i 个问号之前的部分,所有填法的计算结果的和。暴力的转移是枚举最后一个乘号的位置(设为 j),并将 f_j 乘上后面所有位置填 x 构成的数转移到 f_i。这样是 O(n^2) 的,考虑优化。

注意到 j 转移到的位置从 i 变为 i+1 时,f_j 乘上的值的变化一定是先乘上「x 的位数」个 10,加上 x,再乘上「第 i 个问号和第 i+1 个问号之间的数字串长度」个 10,再加上这串数字构成的值。所以我们在转移过程中维护 f_j 的和以及 f_j 乘上后面的值的和两个值,就可以 O(n) 完成上述 dp 了。具体可以看代码。

对每一个位数取出 O(n) 个点值,求前缀和之后插值即可。暴力插值或线性插值均可。复杂度 O(n^2\log R)

#include <bits/stdc++.h>
using namespace std;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
    char c = getchar();
    int x = 0, f = 1;
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return x * f;
}

const long long mod = 1000000007;
const int N = 2077, L = 18;

int n;
char str[N];
vector <pair <long long, int> > val;
long long f[N], pw10[N], p10r[N], sl, sr, x[N], y[N], inv[N];

inline void Prefix() {
    long long cur = 0, len = 0;
    for (int i = 1;i <= n;i++) {
        if (str[i] == '?') {
            val.push_back(make_pair(cur, len));
            cur = 0; len = 0;
        } else {
            cur = (cur * 10 + str[i] - '0') % mod;
            len++;
        }
    }
    val.push_back(make_pair(cur, len));
    pw10[0] = 1;
    for (int i = 1;i <= n + L;i++) pw10[i] = pw10[i - 1] * 10 % mod;
    p10r[0] = 1;
    for (int i = 1;i <= L;i++) p10r[i] = p10r[i - 1] * 10;
    inv[1] = 1;
    for (int i = 2;i <= n + 10;i++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}

inline long long getAns(long long x, long long xln) {
    x %= mod;
    long long fsm = 0, fpsm = 0;
    f[0] = val[0].first;
    fsm = 1;
    fpsm = val[0].first;
    for (int i = 1;i < val.size();i++) {
        fpsm = fpsm * pw10[xln] % mod;
        fpsm = (fpsm + x * fsm) % mod;
        fpsm = fpsm * pw10[val[i].second] % mod;
        fpsm = (fpsm + val[i].first * fsm) % mod;
        f[i] = (fpsm + f[i - 1] * val[i].first) % mod;
        fsm = (fsm + f[i - 1]) % mod;
        fpsm = (fpsm + f[i - 1] * val[i].first) % mod;
    }
    return f[val.size() - 1];
}

inline long long Inv(long long x) {
    if (x < 0) return mod - inv[-x];
    else return inv[x];
}

inline void Solve() {
    long long ans = 0;
    for (int i = 1;i <= L;i++) {
        long long vl = p10r[i - 1], vr = p10r[i] - 1;
        vl = max(vl, sl); vr = min(vr, sr);
        if (vl > vr) continue;
        long long cnt = val.size() + 5;
        if (vr - vl + 1 <= cnt) {
            for (long long j = vl;j <= vr;j++) ans = (ans + getAns(j % mod, i)) % mod;
        } else {
            y[0] = 0;
            for (long long j = vl;j <= vl + cnt - 1;j++) {
                x[j - vl + 1] = j;
                y[j - vl + 1] = getAns(j % mod, i);
            }
            for (int j = 1;j <= cnt;j++) y[j] = (y[j] + y[j - 1]) % mod;
            for (int j = 1;j <= cnt;j++) {
                long long cur = y[j];
                for (int k = 1;k <= cnt;k++) {
                    if (k == j) continue;
                    cur = cur * (vr % mod - x[k] % mod) % mod * Inv(x[j] - x[k]) % mod;
                }
                ans = (ans + cur) % mod;
            }
        }
    }
    cout << (ans % mod + mod) % mod << endl;
}

int main() {
    cin >> str + 1; n = strlen(str + 1);
    cin >> sl >> sr;
    Prefix();
    Solve();
    return 0;
}