P9129 [USACO23FEB] Piling Papers G

· · 题解

先考虑一次询问怎么做。首先容斥成 \leq B 的答案减去 \leq A-1 的答案,因此我们只考虑 \leq t 的答案怎么求。

一个朴素的想法是设 f_{i,j,0/1/2} 表示前 i 个数,从低位开始填了 j 位,此时和 t 的最低 j 位的大小关系是小于,等于或大于的方案数。但你发现,虽然在做 x\rightarrow\overline{x+a_i} 的时候我们可以 \mathcal{O}(1) 地得到大小关系,但是在 x\rightarrow\overline{a_i+x} 时并没有办法靠谱地确定,除非我们直接记录下当前填出的 x,但这样状态就爆炸了。

这给我们一些启示:在设计状态的时候,我们不能让 xt 上对应的区间移动。自然想到设 f_{i,l,r,0/1/2} 表示前 i 个数,形成的 r - l +1 位的数字 x,和 t 从高到低的第 l 位到第 r 位形成的数字 t' 的大小关系是小于,等于或大于的方案数,转移只需要考虑往哪个方向扩展。此时我们就可以只通过当前位确定出整个区间的大小关系了,于是可以 \mathcal{O}(1) 完成转移。

t 的位数为 d,那么最后的答案是 f_{n,1,d,0} + f_{n,1,d,1} + \sum \limits_{j=2}^d \sum \limits_{k=0}^2 f_{n,j,d,k},其中最后一项计算的是位数 <d 的方案数。

于是我们得到了一个 \mathcal{O}(qn \log^2 B) 的做法。但我们发现,每次 DP 的时候我们其实都计算出了每个前缀的答案,因此我们对于每个后缀做一次 DP,查询的时候用前缀和 \mathcal{O}(1) 查询,总时间复杂度就变成了 \mathcal{O}(n^2 \log^2 B + q),足以通过本题。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 3e2 + 5, M = 20, mod = 1e9 + 7;
int n, m, q, a[N], lim[M], ansL[N][N], ansR[N][N], f[M][M][3]; LL L, R;
void getlim(LL x) {
    m = 0;
    while (x > 0) lim[++m] = x % 10, x = x / 10;
    reverse(lim + 1, lim + m + 1);
}
int chk(int a, int b) {
    if (a < b) return 0;
    else if (a == b) return 1;
    else return 2;
}
int main() { 
    ios :: sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> L >> R;
    for (int i = 1; i <= n; i++) cin >> a[i];
    getlim(L - 1);
    for (int i = 1; i <= n; i++) {
        memset(f, 0, sizeof(f));
        for (int j = i; j <= n; j++) {
            for (int x = 1; x <= m; x++) {
                for (int y = m; y > x; y--) {
                    if (a[j] > lim[x]) {
                        for (int k = 0; k <= 2; k++) f[x][y][2] = (f[x][y][2] + f[x + 1][y][k]) % mod;
                    } else if (a[j] == lim[x]) {
                        for (int k = 0; k <= 2; k++) f[x][y][k] = (f[x][y][k] + f[x + 1][y][k]) % mod;
                    } else {
                        for (int k = 0; k <= 2; k++) f[x][y][0] = (f[x][y][0] + f[x + 1][y][k]) % mod;
                    }
                    f[x][y][2] = (f[x][y][2] + f[x][y - 1][2]) % mod;
                    f[x][y][chk(a[j], lim[y])] = (f[x][y][chk(a[j], lim[y])] + f[x][y - 1][1]) % mod;
                    f[x][y][0] = (f[x][y][0] + f[x][y - 1][0]) % mod;
                }
            }
            for (int x = 1; x <= m; x++) f[x][x][chk(a[j], lim[x])] = (f[x][x][chk(a[j], lim[x])] + 2) % mod;
            for (int x = 1; x <= m; x++) {
                ansL[i][j] = (ansL[i][j] + f[x][m][1]) % mod;
                ansL[i][j] = (ansL[i][j] + f[x][m][0]) % mod;
                if (x > 1) ansL[i][j] = (ansL[i][j] + f[x][m][2]) % mod; 
            }
        }
    }
    getlim(R);
    for (int i = 1; i <= n; i++) {
        memset(f, 0, sizeof(f));
        for (int j = i; j <= n; j++) {
            for (int x = 1; x <= m; x++) {
                for (int y = m; y > x; y--) {
                    if (a[j] > lim[x]) {
                        for (int k = 0; k <= 2; k++) f[x][y][2] = (f[x][y][2] + f[x + 1][y][k]) % mod;
                    } else if (a[j] == lim[x]) {
                        for (int k = 0; k <= 2; k++) f[x][y][k] = (f[x][y][k] + f[x + 1][y][k]) % mod;
                    } else {
                        for (int k = 0; k <= 2; k++) f[x][y][0] = (f[x][y][0] + f[x + 1][y][k]) % mod;
                    }
                    f[x][y][2] = (f[x][y][2] + f[x][y - 1][2]) % mod;
                    f[x][y][chk(a[j], lim[y])] = (f[x][y][chk(a[j], lim[y])] + f[x][y - 1][1]) % mod;
                    f[x][y][0] = (f[x][y][0] + f[x][y - 1][0]) % mod;
                }
            }
            for (int x = 1; x <= m; x++) f[x][x][chk(a[j], lim[x])] = (f[x][x][chk(a[j], lim[x])] + 2) % mod;
            for (int x = 1; x <= m; x++) {
                ansR[i][j] = (ansR[i][j] + f[x][m][1]) % mod;
                ansR[i][j] = (ansR[i][j] + f[x][m][0]) % mod;
                if (x > 1) ansR[i][j] = (ansR[i][j] + f[x][m][2]) % mod; 
            }
        }
    }
    cin >> q;
    while (q--) {
        int l, r; cin >> l >> r;
        int ans = (ansR[l][r] - ansL[l][r] + mod) % mod;
        cout << ans << "\n";
    }
    return 0;
}