题解:P13831 【MX-X18-T3】「FAOI-R6」比亚多西

· · 题解

应该能感觉到 f 存在递推关系。打个表可以发现 f(i)=f(i-1) + \lfloor \log_2 i \rfloor + 1

::::info[证明]

考虑题目中不断二分查找的过程可以用一棵 [1,n] 的二叉搜索树来刻画,每个结点的深度就是需要二分查找的次数。

例如对于 n=7,8,9

对于区间 [1, n] 的二分查找树:根节点是 m = \lfloor \frac{1+n}{2} \rfloor,左子树是 [1, m-1] 的二分查找树,右子树是 [m+1, n] 的二分查找树。

当区间从 [1,n-1] 扩展到 [1,n] 时,新插入的结点是 n,同时结点 n 的插入位置一定在树的最右侧路径上,深度为 \lfloor \log_2 n \rfloor+1,即查找次数 c_n = \lfloor \log_2 n \rfloor+1(也是二分查找的最坏次数)。

所以 f(n)=f(n-1)+c_n=f(n-1) + \lfloor \log_2 n \rfloor + 1。证毕。 ::::

题目要求 f 的区间和,直接算不好算,可以考虑转化为前缀和相减的形式。设 s(n) = \sum_{i=1}^n f(i),则题目要求的 \sum_{i=L}^R f(i) = s(R)-s(L-1)

化简 s(n) 可以考虑对每个 \lfloor \log_2 i \rfloor+1 单独计算贡献,再加起来:

s(n) = \sum_{i=1}^n f(i) = \sum_{i=1}^n \bigg((\lfloor \log_2 i \rfloor + 1) \times (n-i+1) \bigg)

发现 \lfloor \log_2 i \rfloor+1 单调不降,且每个值都会重复许多次,所以考虑算出每个值重复出现的次数,从而减少运算(类似整除分块)。

具体地,发现 \log_2 i + 1 = k 对应 i \in [l=2^{k-1}, r=\min\{2^{k}-1,n\}],设 K=\lfloor \log_2 n \rfloor + 1,可得:

s(n) = \sum_{k=1}^{K} \left(k \times \sum_{i=l}^{r} (n-i+1) \right)

发现乘积后面一项是一个明显的等差数列求和:

\begin{align*} \sum_{i=l}^{r} (n-i+1) &= \frac{[(n-l+1)+(n-r+1)] \times (r-l+1)}{2} \\ &= \frac{(2n-l-r+2)(r-l+1)}{2} \end{align*}

所以最终答案即为:

s(n) = \sum_{k=1}^{K} \frac{k \times (2n-l-r+2)(r-l+1)}{2}

时间复杂度为 O(\sum \log R)

这道题的取模和数据大小很烦人,建议答案计数器使用 int128 类型(其他地方不要多用,会超时)。

#include <bits/stdc++.h>
#define int long long
#define int128 __int128_t
using namespace std;
const int mod = 998244353, inv2 = 499122177;

int s(int n){
    if(n == 0) return 0;
    int128 cnt = 0;
    for(int k = 1; 1LL << (k - 1) <= n; k ++){
        int l = 1LL << (k - 1);
        int r = min((1LL << k) - 1, n);
        int128 t = (2 * n - l - r + 2) % mod; 
        t = (t * (r - l + 1)) % mod;
        t = (t * inv2) % mod;   
        t = (t * k) % mod;
        cnt += t;
        if(cnt >= mod) cnt -= mod;
    }
    return cnt;
}

void write(int128 x){
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

void solve()
{
    int L, R; cin >> L >> R;
    write((s(R) - s(L - 1) + mod) % mod); putchar('\n');
}

signed main()
{
    ios :: sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
    int T; cin >> T;
    while(T --) solve();
    return 0;
}