Solution:P12487 [集训队互测 2024] 月亮的背面是粉红色的

· · 题解

怎么洛谷上的 DIVCNT 题都这么卡常/fn

以下 S_m(n)=\sum_{i=1}^ni^m

首先我们有结论:当 f(1)=1f 是积性函数等价于 f(a)f(b)=f(\gcd(a,b))f(\operatorname{lcm}(a,b))。充分性容易直接将 a,b 拆成质因数得到,必要性直接令 \gcd(a,b)=1 即得。

这样题目想让我们求的 F_m 就是

F_m(n)=\sum_{1\leq i,j\leq n}f_m(\gcd(a,b))f_m(\operatorname{lcm}(a,b))=\left(\sum_{i=1}^ni^m2^{\omega(i)}\right)^2

于是现在我们只需要求 \sum_{i=1}^ni^{m}2^{\omega(i)}。现在这个题变成了两个原题拼起来(P11419+SP33039):根据 P11419 我们有

\sum_{i=1}^ni^m2^{\omega(i)}=&\sum_{i=1}^ni^m\sum_{d|i}\mu^2(d)\\ =&\sum_{d=1}^nd^mS_m\left(\left\lfloor\frac nd\right\rfloor\right)\sum_{k^2|d}\mu(k)\\ =&\sum_{k=1}^{\sqrt n}k^{2m}\mu(k)\sum_{d=1}^{\frac{n}{k^2}}d^mS_m\left(\left\lfloor\frac n{dk^2}\right\rfloor\right)\\ =&\sum_{k=1}^{\sqrt n}k^{2m}\mu(k)\sum_{ij\leq\frac{n}{k^2}}(ij)^m \end{aligned}

现在我们只需要在 \mathrm O(\sqrt[3] n\log n) 的时间内求出 \sum_{ij\leq n}(ij)^m 即可。可以发现这个只是比 SP33039 复杂了一点,同样是求 y=\frac nx 下整点点权和但是现在 (x,y) 权值与 x,y 都有关,变成了 (xy)^m

可以发现虽然复杂了一点但是我们直接套用 SP33039 的方法也能直接解决!直接套用这里的方法,因为 m\leq 1 多维护一个 T_{1,2} 即可。推式子的方法同样是暴力展开,所以过程就略去了。最终式子可以看代码。

现在它只有 84 分,既被卡常又被卡空间。卡空间还是比较好办的,多开点 int 并注意溢出,同时特判 n>10^{15} 时不要预处理 i^2\mu(i)\sigma(n) 的前缀和即可。

然后它在 n>10^{15} 时还是被卡常了,那么首先数据分治,可以发现耗时最多的是线性筛,那么我们线性筛只筛到 10^7,求 \mu 的前缀和换成杜教筛。这个时候我的实现还是被卡常了,于是对于 d(n) 的前缀和要多写一份不用 __int128 的代码。这样就可以通过了。

放一个没卡常也没卡过空间的 84 分代码。它不能通过 n\leq 10^{16},m=0 的数据。

#include <algorithm>
#include <iostream>
#include <vector>
#include <bitset>
#include <string>
#include <stack>
#include <array>
#include <cmath>

#define rgall(arr)          (arr).begin(),(arr).end()
#define rgo1(arr,cnt)       (arr).begin()+1,(arr).begin()+1+(cnt)
#define rgcnt(arr,cnt)      (arr).begin(),(arr).begin()+(cnt)
#define rgany(arr,rgl,rgr)  (arr).begin()+(rgl),(arr).begin()+(rgr)
#define fori(i,a,b)         for(int i=(a);i<=(b);i++)
#define ford(i,a,b)         for(int i=(a);i>=(b);i--)
#define fori0(i,a,b)        for(int i=(a);i<(b);i++)
#define ford0(i,a,b)        for(int i=(a);i>(b);i--)
#define fr first
#define sc second

using namespace std;

typedef __int128 i128;

ostream& operator<<(ostream& os,i128 n)
{
    if(n>=10)
        os<<n/10;
    return os<<(int)(n%10);
}

struct vec
{
    long long x,y;
    array<i128,4> va;//(0,1),(1,1),(0,2),(1,2)
};

inline i128 sum1(i128 a)
{
    return a*(a+1)>>1;
}

constexpr int           maxn12=1e8,maxp=6e6,p=1e9+7;
stack<vec,vector<vec>>  vecst;
array<short,maxn12+1>   mu;
array<int,maxn12+1>     sd0,sd1,i2mu;
bitset<maxn12+1>        isp;
array<int,maxp>         ps;

pair<i128,i128> solve(long long n)
{
    long long n12=sqrtl(n),n13=cbrtl(n)*3;
    vec l,r,mid;
    i128 x=n/n12+1,y=n12,ans0=0,ans1=0;
    vecst.push({1ull,0ull,{}}),vecst.push({1ull,1ull,{}});
    while(true)
    {
        for(l=vecst.top(),vecst.pop();(x+l.x)*(y-l.y)>n;x+=l.x,y-=l.y)
            ans0+=x*l.y+l.va[0]-1,ans1+=(x*(x+1)*(y*l.y-sum1(l.y-1))+(x+x+1)*(y*l.va[0]-l.va[1])+y*l.va[2]-l.va[3])/2-x*y;
        if(y<=n13)
            break;
        for(r=vecst.top();(x+r.x)*(y-r.y)<=n;l=r,vecst.pop(),r=vecst.top());
        while(true)
        {
            mid={l.x+r.x,l.y+r.y,{l.va[0]+r.va[0]+l.x*r.y,
                l.va[1]+l.x*sum1(r.y-1)+l.x*l.y*r.y+r.va[1]+l.y*r.va[0],
                l.va[2]+l.x*l.x*r.y+2*l.x*r.va[0]+r.va[2],
                l.va[3]+r.va[3]+2*l.x*r.va[1]+l.x*l.x*sum1(r.y-1)+l.y*r.va[2]+2*l.x*l.y*r.va[0]+r.y*l.y*l.x*l.x}};
            if((x+mid.x)*(y-mid.y)>n)
                vecst.push(r=mid);
            else if((i128)r.y*(x+mid.x)*(x+mid.x)>=(i128)n*r.x)
                break;
            else
                l=mid;
        }
    }
    for(int i=1;i<=y;i++)
        ans1+=i*sum1(n/i),ans0+=n/i;
    while(!vecst.empty())
        vecst.pop();
    return make_pair(ans0*2-n12*n12,ans1*2-sum1(n12)*sum1(n12));
}

int main(int argc,char* argv[],char* envp[])
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    long long n,m;
    cin>>n>>m;
    int n12=sqrtl(n),cntp=0;
    mu[1]=sd0[1]=1;
    fori(i,2,n12)
    {
        if(!isp[i])
            ps[++cntp]=i,mu[i]=-1,sd0[i]=2;
        for(int j=1;j<=cntp&&i*ps[j]<=n12;j++)
        {
            int a=i*ps[j];
            isp.set(a);
            if(!(i%ps[j]))
            {
                long long b=1,c=ps[j];
                while(!(i%c))
                    ++b,c*=ps[j];
                sd0[a]=sd0[a/c]*(b+1);
                break;
            }
            mu[a]=-mu[i],sd0[a]=sd0[i]<<1;
        }
    }
    fori(i,1,n12)
        i2mu[i]=(i2mu[i-1]+1ll*i*i*mu[i])%p,mu[i]+=mu[i-1],sd1[i]=(sd1[i-1]+1ll*i*sd0[i])%p,(sd0[i]+=sd0[i-1])%=p;
    long long ans0=0,ans1=0;
    for(long long i=1,j,k;i*i<=n;i=j+1)
    {
        k=n/(i*i),j=sqrtl(n/k);
        if(k<=n12)
            ans0+=(mu[j]-mu[i-1])*sd0[k],ans1+=(i2mu[j]-i2mu[i-1])*1ll*sd1[k]%p;
        else
        {
            auto a=solve(k);
            ans0+=a.fr%p*(mu[j]-mu[i-1]),ans1+=a.sc%p*(i2mu[j]-i2mu[i-1])%p;
        }
    }
    ans0=(ans0%p+p)%p,ans1=(ans1%p+p)%p;
    cout<<ans0*ans0%p;
    if(m)
        cout<<' '<<ans1*ans1%p;
    return 0;
}