一种暴力O(1)解法

· · 题解

这道题目看到数据都会以为是 O(1) 复杂度的吧……

所以直接数学方法暴力求解。

1.求解答案

首先看题目,对于任意的 n,其三次方都能表示成 n 个连续奇数的和。

令集合 S_{n} =\left \{ x_{n}|f(x_{n})=n \right \} ,目标:求其中最小的元素 x_{n,1} 的值。

由题目可得 S 内所有元素的和等于 n^3,根据高斯求和公式:

\frac{(x_{n,1}+2(n-1))n}{2}=n^3

易解得: x_{n,1}=n^2-n+1

因此可得最大元素:x_{n,n}=n^2+n-1

f(k)=p,然后考虑分区间,将原来 [1,k] 区间分成 S_{1}+S_{2}+S_{3}+...+S_{p-1}+RR 表示分区间后的剩余元素集合。

则答案为:

\sum_{i=1}^{p}\sum_{j=1}^{i} i+\sum_{i=x_{p,1}}^{k}p

化简一下:

=\sum_{i=1}^{p}i^2+\frac{k-x_{p,1}+2}{2}p =\frac{(p-1)p(2p-1)}{6}+\frac{k-(p^2-p+1)+2}{2} p

再考虑求 p,先计算 x_{n,1}=k,也就是:

n^2-n+1=k

得:

n=\frac{\sqrt{4k-3} +1}{2}

显然:x_{p,1} \le k < x_{p+1,1}

也就是:x_{p,1} \le x_{n,1} < x_{p+1,1}

所以:p=\left \lfloor n \right \rfloor

这样,我们就求出了这道题。

2.数据范围与取模

先注意数据范围,k<2^{64},所以 p<2^{32},注意我们上边的求解过程,计算出的答案最大是 p^3 左右,大概为 2^{96},我们可以用 __int128 储存变量,但是会因为常数过大超时,所以我们要在计算时先确定好能被除数整除的数,先对其做除法,便可以完美的求出答案。

但注意另外一个点,求 p 的过程中,计算 \sqrt{4k-3} 时会先计算 4k-3 但是计算时电脑是使用 64 位二进制储存的,但是 __int128 并不支持根号运算,所以在计算时要考虑强转 long double 再存入 p 中,注意 p 不能使用 __int128 储存,不然会导致常数过大导致超时。

取模也是一个问题,为了节省常数复杂度,不能暴力直接全用 __int128,所以当我们除以 2 时,我们得考虑乘上它的乘法逆元,再考虑取余,显然 2 此时的乘法逆元是 500000004 所以我们取余时就可以分开取余了。

3.Code

展示我丑陋的代码

#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const ull md=1e9+7;
const ull ppp=500000004;
namespace Mker {
    ull SA, SB, SC, p = -1;
    int base;
    void init(){scanf("%llu%llu%llu%d", &SA, &SB, &SC, &base); p = p << (65 - base) >> (65 - base);}
    ull rand() {
        ull now = SC; now += !(now & 1);
        SA ^= SA << 32, SA ^= SA >> 13, SA ^= SA << 1;
        ull t = SA;
        return SA = SB, SB = SC, SC ^= t ^ SA, (now & p) + p + 1;
    }
};
void print(__int128 x){
    if(x==0)return;
    print(x/10);
    putchar(x%10+'0');
}
ull k;
int T;
int main(){
    scanf("%d",&T); Mker::init();
    while(T--){
        k = Mker::rand(); 
        ull p=(sqrt((long double)4*k-3)+1)/2;
        __int128 ans;
        ull a=p-1,b=p,c=2*p-1;      
        if(a%3==0)a/=3;
        else if(b%3==0)b/=3;
        else c/=3;
        if(a%2==0)a/=2;
        else if(b%2==0)b/=2;
        else c/=2;
        ans=(__int128)a*b%md*c%md;
        p%=md;
        ull len=(k%md*ppp%md-(((p*p%md-p%md)+md)%md+1)%md*ppp%md+md+1)%md;
        p%=md;
        (ans+=len*p%md)%=md;
        print(ans);
        putchar('\n');
    }
}