XOR and Number Theory 题解

· · 题解

我操,O(3^{\log_2m}) 彻底怒了。O(3^{\log_2m}) 指出了最核心的矛盾点:如果 n 开到 10^{18}m 开到 2\times 10^5,或者把时限开到 400\text{ms},怎么可能 O(\sum \min(n,m^2)) 轻轻松松直接通过?这确实是我的严重错误。我需要彻底承认 n,m 开的不够大,或者时限开的不够小,想办法把 DLESS Round 糊弄过去。

做一点说明:n\le5\times10^7 本来就是给 O(n\log m) 做法的,但是怕不敢写,所以又给了个 10^7

首先考察 x\oplus y=\gcd(x,y) 的性质。

发现当 y>x 时,y\oplus x\ge y-x\ge \gcd(y,x),所以当 y\oplus x=\gcd(y,x) 时,上述三个东西都相等,不妨设它们等于 d

接下来考察 x^2\bmod(y^2-xy),设 x=x'd,则 y=(x'+1)d,则:

\begin{aligned} x^2\bmod(y^2-xy) &= x^2\bmod y(y-x) \\ &= x'^2d^2\bmod (x'+1)d^2 \\ &= (x'^2\bmod (x'+1))d^2 \\ &= d^2 \end{aligned}

事实上,这个无论怎么推都能推出来。

接下来是难点,考虑对所有的 (x,y)d^2 之和。

- $d$ 是 $y$ 的因数。 - $d$ 在二进制下是 $y$ 的真子集(即二进制下 $d$ 为 $1$ 的位 $y$ 这一位也均为 $1$)。 对于 $d$,我们只要求出满足条件的 $y$ 的个数即可,$x$ 自然满足条件。 为了方便,对于第二个条件只考虑 $d$ 是 $y$ 的子集,最后再减去 $\sum_{i=1}^mi^2=\frac{1}{6}m(m+1)(2m+1)$ 即可。 接下来尝试建立子集与因数之间的关系,这看似很难,但是在 $d\le10^5$ 的情况下我们只枚举了 $d$ 本身,貌似还要接着枚举什么。 设 $h=\operatorname{highbit}(d)$,显然 $d$ 是否是 $y$ 的子集只跟 $y$ 的最低 $h$ 位有关,不妨考虑枚举 $y$ 的最低 $h$ 位 $p$,此时我们竟能惊人地建立起子集和因数之间的关系: - $y\equiv 0\pmod{d}

EXCRT 求解的个数即可,可能需要预处理逆元做到 O(m\log m+3^{\log_2m}),用神秘的 O(1) 求模 2^m 下的逆元可以做到 O(3^{\log_2m})

#include<bits/stdc++.h>
#define cint const int
#define uint unsigned int
#define cuint const unsigned int
#define ll long long
#define cll const long long
#define ull unsigned long long
#define cull const unsigned long long
#define sh short
#define csh const short
#define ush unsigned short
#define cush const unsigned short
using namespace std;
int read()
{
    int x=0;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<1)+(x<<3)+(ch-'0');
        ch=getchar();
    }
    return x;
}
void print(cint x)
{
    if(x<10)
    {
        putchar(x+'0');
        return;
    }
    print(x/10);
    putchar(x%10+'0');
}
void princh(cint x,const char ch)
{
    print(x);
    putchar(ch);
}
cint p=1e9+7;
int n,m;
int ans;
int inv[1000001],hb[1<<20|1],chb[1<<20|1];
void init()
{
    cint M=(1<<30)-1;
    for(int i=1;i<=1e6;i+=2)
    {
        int a=i,b=1;
        for(int j=1;j<=22;++j)
        {
            b=((1ll*b*a)&M);
            a=((1ll*a*a)&M);
        }
        inv[i]=b;
    }
    for(int i=1;i<=1<<20;++i)
    {
        hb[i]=max(hb[i-1],i&(-i));
        chb[i]=chb[i-1]+(hb[i]!=hb[i-1]);
    }
}
int calc(cint x,cint y,cint mod)
{
    cll d=((1ll*inv[x]*y)&(mod-1))*x;
    return (d>n?0:(n-d)/(1ll*x*mod)+1);
}
int base;
void solve()
{
    cint M=(1<<chb[m])-1;
    for(int i=1;i<=M;i+=2)
    {
        cint mod=(1<<chb[i]);
        for(int j=i;;j=((j-2)&i))
        {
            if(hb[j]!=hb[i])break;
            if(j>m)continue;
            ans=(ans+1ll*base*j*base*j%p*calc(j,i,mod))%p;
            if(j==1)break;
        }
    }
}
void doit()
{
    n=read();
    m=read();
    base=1;
    ans=-1ll*m*(m+1)*(m<<1|1)/6%p;
    while(m)
    {
        solve();
        base<<=1;
        n>>=1;
        m>>=1;
    }
    princh(ans,'\n');
}
int main()
{
    init();
    int T=read();
    while(T--)doit();
    return 0;
}