T4-简单题

· · 题解

算法 1

枚举并求出值即可。时间复杂度 O(n^2logk+n^2\sqrt n)

期望得分:10

算法 2:

在算法 1 的基础上预处理函数的值 O(n^2)

期望得分:20

算法 3:

开始反演了。

可以轻易得知,f(i)= \mu^2(i)

于是原式可变为:

\sum\limits_{i=1}^n\sum\limits_{j=1}^m(i+j)^k\mu^2(\gcd(i,j))\gcd(i,j)

=\sum\limits_{d}d^k\sum\limits_{i=1}^{n/d}\sum\limits_{j=1}^{n/d}(i+j)^k\mu^2(d)d[\gcd(i,j)==1]

=\sum\limits_{d}d^{k+1}\mu^2(d)\sum\limits_{i=1}^{n/d}\sum\limits_{j=1}^{n/d}(i+j)^k \sum\limits_{p|i,p|j}\mu(p)

=\sum\limits_p\mu(p)p^k\sum\limits_{d}d^{k+1}\mu^2(d)\sum\limits_{i=1}^{n/dp}\sum\limits_{j=1}^{n/dp}(i+j)^k

=\sum\limits_TT^k\sum\limits_{d|T}d\mu^2(d)\mu(T/d)g(n/dp)

其中 g(x)=\sum\limits_{i=1}^x\sum\limits_{j=1}^x(i+j)^k

前面这一块又是一个积性函数,可以线性筛。

构造 h(T)=\sum\limits_TT^k\sum\limits_{d|T}d\mu^2(d)\mu(T/d)

再求前缀和,之后用整除分块。

注意,本题的 n 很大,i^k 也要线性筛,由于质数不多,可以近似于 O(n)

解决完 h,再来看看 g

直接求值是 n^2 的,我们可以改为枚举 (i+j) 的值,之前线性筛可以求出 k 次方的值,单次询问为 O(n)

使用整除分块,需要求 g(n/1),g(n/2),g(n/3)...,总和为 n \ln n

所以总复杂度为O(n+\sqrt n+n\ln n)(预处理 + 整除分块 +g)。

如果上述三个任有一个没做到就只能得到部分分。

注意,因为算 i+j 时答案可能会达到 2n,线性筛需要筛到 2n。

#pragma GCC optimize(2,3,4,5)
#include<bits/stdc++.h>
#define N 10000010
#define M 998244353
#define re register
using namespace std;
inline long long read()
{
    re long long x=0;
    re char c=getchar();
    while(c<'0'){c=getchar();}
    while(c>='0') {x=(x<<1)+(x<<3)+(c-48);c=getchar();} 
    return x;
}
bool vis[N];
int prim[N],sum[N],cnt,tot,f[N];
int kp[N];
inline int g(re int x){
    re int ans=0;
    for(re int i=1;i<=x;++i)ans=(long long)(ans+1ll*(i-1)*kp[i])%M;
    for(re int i=x+1;i<=2*x;++i)ans=(long long)(ans+1ll*(2*x+1-i)*kp[i])%M;
    return ans;
}
inline int ksm(re int p,re int b){
    re int ans=1;
    while(b){
        if(b&1)ans=1ll*ans*p%M;
        p=1ll*p*p%M;
        b>>=1;
    }
    return ans;
}
inline int solve(re int a){ 
        re int ans;
        ans=0;
        for(re int l=1,r;l<=a;l=r+1)
        {
            r=a/(a/l);
            ans=(long long)(ans+1ll*g(a/l)*(sum[r]-sum[l-1]))%M;
        }
        if(ans<0)ans%=M,ans+=M;
        return ans;
}
inline void write(re int x){
    if(x>9)write(x/10);
    putchar((x%10)+'0');
}
signed main()
{
//  freopen("test1.in.txt","r",stdin);
//  freopen("test1.out","w",stdout);
    re int n;
    re long long k;
    f[1]=1;
    sum[1]=1;kp[1]=1;n=read();k=read()%(M-1);
    for(re int i=2;i<=n*2;++i){
        if(!vis[i])kp[i]=ksm(i,k),f[i]=1ll*(i-1)*kp[i]%M,prim[++cnt]=i;
        for(re int j=1;j<=cnt&&i*prim[j]<=n*2;++j)
        {
            vis[i*prim[j]]=1;
            kp[i*prim[j]]=1ll*kp[i]*kp[prim[j]]%M;
            if(i%prim[j]==0){
                if((i/prim[j])%prim[j]){
                f[i*prim[j]]=1ll*(-prim[j])*f[i/prim[j]]%M*kp[prim[j]*prim[j]]%M;}

                break;
            }
            else f[i*prim[j]]=1ll*f[i]*f[prim[j]]%M;
        }
        sum[i]=f[i]+sum[i-1];
        sum[i]%=M;
    }
    write(solve(n));
    putchar('\n');
}