【题解】Blinker 的仰慕者

· · 题解

(感谢洛谷第一篇题解)

暴力 dp 是 f(i,j) 表示 ni 位(设最高位为第 n 位)剩下的数乘积需要为 j 的方案数,g(i,j) 表示 ni 位剩下的数乘积需要为 j 的所有数和。但是每一位的数都不会大于 9,所以乘积的质因数一定只有 2,3,5,7,所以可以把所有合法状态都预处理出来,然后状态里就只用开合法状态个数的大小(6\text e 4 个左右)。正常记忆化深搜即可。转移如下(转移方程是暴力 dp 的转移,自己把它转化成 map 离散化的即可):

\begin{aligned} &f(i,j)=\sum_{x\in[1..9]}f\left(i-1,\frac{j}{x}\right)\\ &g(i,j)=\sum_{x\in[1..9]}f\left(i-1,\frac{j}{x}\right)\times x\times 10^{i-1}+g\left(i-1,\frac{j}{x}\right) \end{aligned}

有几个需要注意的点:

  1. 记忆化深搜里,除了要记 lim 表示是否卡到上界,还要记 st 表示现在还是不是前导零阶段;如果是,x 就可以继续取 0,反之则不行(如果数中间取 0,乘积就变成 0 了);
  2. 因为本题有点卡时间,所以要预处理一些东西和剪枝:预处理 to[i][x] 表示 map 的第 i 项如果除以 xx\in[0..9])会变成 map 的哪一项,和 stp[i] 表示 map 的第 i 项最少需要多少步才能除到 1(这个东西可以枚举这一步除以什么,然后 dp 转移);在深搜的时候,如果发现剩下能走的次数小于 stp[s],就直接剪枝;
  3. 本题中有 k=0 的情况很恶心,需要单独再写一个数位 DP。f(i,j) 表示前 i 位有 j0 的方案数,g(i,j) 表示前 i 位有 j0 的所有数的和,转移跟本来的 DP 类似。
#include <cstdio>
#include <map>
#include <cstring>
#include <cassert>
#include <algorithm>
typedef long long lld;
const lld mod = 20120427;
const lld ths = 1e18;
const int maxl = 18 + 1;
const int maxs = 1e5 + 1;
const int maxc = 1e7 + 1;
struct Result { lld f, g; };
struct Node {
    int p1, p2, p3, p4;
    bool operator <(const Node &x) const {
        return p1 != x.p1 ? p1 < x.p1 : (p2 != x.p2 ? p2 < x.p2 : (p3 != x.p3 ? p3 < x.p3 : p4 < x.p4));
    }
};
int t, k1, k2, k3, k4;
lld l, r, k, pw[maxl];
Node mp[maxs];
int cnt = 0, n, a[maxl], to[maxs][10], stp[maxs], pstk = 0;
Result dp[maxl][maxs], dp0[maxl][maxl], stk[maxc];
int find(int p1, int p2, int p3, int p4) {
    return std::lower_bound(mp + 1, mp + cnt + 1, Node{ p1, p2, p3, p4 }) - mp;
}
Result dfs(int id, bool lim, bool st, int s, lld sum) {
    if (!id) return { !st && sum == 1, 0 };
    if (!lim && !st && ~dp[id][s].f) return dp[id][s];
    if (stp[s] > id) return { 0, 0 };
    int mxv = lim ? a[id] : 9;
    Result res = { 0, 0 }, tmp;
    for (int x = !st; x <= mxv; x++) {
        if (!x) res = dfs(id - 1, false, true, s, sum);
        else if (!(sum % x)) {
            tmp = dfs(id - 1, lim && x == mxv, false, to[s][x], sum / x);
            res.f = res.f + tmp.f;
            res.g = res.g + tmp.f * pw[id - 1] * x + tmp.g;
        }
    }
    res.f %= mod, res.g %= mod;
    if (!lim && !st) { dp[id][s] = res; stk[++pstk] = { id, s }; }
    return res;
}
Result dfs0(int id, bool lim, bool st, int cnt) {
    if (!id) return { !st && cnt, 0 };
    if (!lim && !st && ~dp0[id][cnt].f) return dp0[id][cnt];
    int mxv = lim ? a[id] : 9;
    Result res = { 0, 0 }, tmp;
    for (int x = 0; x <= mxv; x++) {
        tmp = dfs0(id - 1, lim && x == mxv, st && !x, cnt + (!st && !x));
        res.f = res.f + tmp.f;
        res.g = res.g + tmp.f * pw[id - 1] * x + tmp.g;
    }
    res.f %= mod, res.g %= mod;
    if (!lim && !st) dp0[id][cnt] = res;
    return res;
}
lld solve(lld x) {
    n = 0;
    while (x) a[++n] = x % 10ll, x /= 10ll;
    if (k) return dfs(n, true, true, find(k1, k2, k3, k4), k).g;
    else return dfs0(n, true, true, 0).g;
}
signed main() {
    int p1, p2, p3, p4; lld s1, s2, s3, s4;
    for (p1 = 0, s1 = 1; s1 <= ths; p1++, s1 <<= 1ll)
        for (p2 = 0, s2 = s1; s2 <= ths; p2++, s2 *= 3ll)
            for (p3 = 0, s3 = s2; s3 <= ths; p3++, s3 *= 5ll)
                for (p4 = 0, s4 = s3; s4 <= ths; p4++, s4 *= 7ll)
                    mp[++cnt] = { p1, p2, p3, p4 }, to[cnt][1] = cnt;
    std::sort(mp + 1, mp + cnt + 1);
    for (int i = 1; i <= cnt; i++) {
        if (mp[i].p1) to[i][2] = find(mp[i].p1 - 1, mp[i].p2, mp[i].p3, mp[i].p4);
        if (mp[i].p2) to[i][3] = find(mp[i].p1, mp[i].p2 - 1, mp[i].p3, mp[i].p4);
        to[i][4] = to[to[i][2]][2];
        if (mp[i].p3) to[i][5] = find(mp[i].p1, mp[i].p2, mp[i].p3 - 1, mp[i].p4);
        to[i][6] = to[to[i][2]][3];
        if (mp[i].p4) to[i][7] = find(mp[i].p1, mp[i].p2, mp[i].p3, mp[i].p4 - 1);
        to[i][8] = to[to[i][2]][4];
        to[i][9] = to[to[i][3]][3];
        stp[i] = maxs;
        for (int x = 2; x < 10; x++)
            if (to[i][x]) stp[i] = std::min(stp[i], stp[to[i][x]] + 1);
        if (stp[i] == maxs) stp[i] = 0;
    }
    pw[0] = 1;
    for (int i = 1; i < maxl; i++) pw[i] = pw[i - 1] * 10 % mod;
    for (int i = 1; i < maxl; i++)
        for (int j = 0; j < maxl; j++) dp0[i][j].f = -1;
    for (int i = 1; i < maxl; i++)
        for (int j = 1; j <= cnt; j++) dp[i][j].f = -1;
    scanf("%d", &t);
    while (t--) {
        scanf("%lld%lld%lld", &l, &r, &k);
        lld tmp = k;
        if (k) {
            k1 = k2 = k3 = k4 = 0;
            while (!(tmp & 1)) k1++, tmp >>= 1;
            while (!(tmp % 3)) k2++, tmp /= 3;
            while (!(tmp % 5)) k3++, tmp /= 5;
            while (!(tmp % 7)) k4++, tmp /= 7;
        }
        if (tmp > 1) puts("0");
        else printf("%lld\n", (solve(r) - solve(l - 1) + mod) % mod);
        Result tp;
        while (pstk) {
            tp = stk[pstk--];
            dp[tp.f][tp.g].f = -1;
        }
    }
    return 0;
}