题解:SP10649 MYQ10 - Mirror Number

· · 题解

Mirror Number是只有数字 0,1,8 的回文数字,更严谨的定义是:

定义 Mirror Number 是一个自然数 n,满足:设

n=a_0+10a_1+10^2a_2+\cdots+10^ma_m,\ m\in\N,a_i\in\{0,1,2,3,4,5,6,7,8,9\}\ (\forall 0\le i\le m\land i\in\N),

那么对于任意 0\le i\le m\land i\in\N,有 a_i\in\{0,1,8\},且 a_i=a_{m-i}

dp[i] 是位数为 i 的 Mirror Number 的数量(没有前导零,所以这里 0 的位数是 0),可以推出

dp_i=\begin{cases}2\times 3^{\lceil i/2\rceil-1}&\text{for}\ i>0\\1&\text{otherwise}\end{cases}

接下来考虑计算答案,定义一个函数 get(x) 表示区间 [0,x] 内 Mirror Number 的数量,如果 x=0,那么直接返回 1。否则,首先考虑位数小于 x 的位数 \mathrm{len} 的数,它们对返回值的贡献是 \displaystyle\sum_{i=0}^{\mathrm{len}-1}dp_i;然后考虑位数为 \mathrm{len} 的数,从最高位(第 1 高位)开始,枚举到第 \lceil \mathrm{len}/2\rceil 高位,对于枚举的第 i 高位,从小到大填 0,1,8 中的数字,在填的数字小于 x 的第 i 高位数字 x_i 时,将答案增加 3^{\lceil r/2\rceil},其中 r=\mathrm{len}-2(i+1),然后,如果填数字到大于 x_i 时,直接返回答案,否则将这个填的数字拼接到一个新的数字 s(初始为 0)的末尾,然后继续枚举数位,当枚举结束后,如果 len 是奇数,将 s 的除去最后一位的数字的反转拼接到 s 末尾,否则将 s 整体反转得到的数字拼接到 s 末尾,如果发现 s\le x,则将答案增加 1,因为在枚举数位的时候没有统计这个数字的答案。

还需要一个函数 check(x),检查 x 是不是 Mirror Number,这直接根据定义去判断即可。

当输入 a,b 时,最终答案等于 get(b)-get(a)+check(a),由于题目范围中 Mirror Number 的自由度不超过 22,所以 get 返回值不超过 3^{22},可以用 long long 返回。

时间复杂度是 \mathcal{O}(T\lg a+T\lg b)=\mathcal{O}(T\lg b)。(耗时 0.43\operatorname{s}

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int T, nxt[9] = {1, 8, 0, 0, 0, 0, 0, 0, 10};
string a, b;
ll dp[45];
bool check(string x){
    for(int i=0;i*2<=x.length();i++){
        if (x[i] != x[x.length()-i-1] || (x[i] != '0' && x[i] != '1' && x[i] != '8'))
            return 0;
    }
    return 1;
}
ll pow3(ll y){
    return y?pow3(y-1)*3:1;
}
ll get(string x){
    if (x.length() == 1 && x[0] == '0') return 1;
    // [0, x]
    ll len = x.length(), ans = 0;
    // 小于len位
    for (int i=0;i<len;i++) ans += dp[i];
    // 等于len位
    string s;
    for(int i=0;i*2+1<=len;i++){
        int num = x[i] - '0', c = i==0;
        while(c < num){
            // 已用长度:i+1 => ans += 3^(len-2*(i+1) )
            if (2*i+1 == len) ans++;
            else {
                int re = len - 2*(i+1);
                ans += pow3(re/2 + re%2);
            }
            c = nxt[c];
        }
        if (c == num) s += x[i];
        else return ans;
    }
    string t = s.substr(0, s.length() - len%2);
    reverse(t.begin(), t.end());
    s += t;
    ans += (s <= x);
    return ans;
}

int main(){
    cin>>T;
    dp[0] = 1;
    for(int i=1;i<=44;i++){
        dp[i] = 2 * pow3(i/2 + i%2 - 1);
    }
    while(T--){
        cin>>a>>b;
        cout << get(b) - get(a) + check(a) << endl;
    }
}