[题解] P9118 [春季测试 2023] 幂次

· · 题解

这个题一看就很普及。我们分情况考虑:

k=1

显然直接输出 n 即可。

k\ge 3

发现 \sqrt[k]{n} \le 10^6,那么我们可以枚举底数。对于每个底数 a,看一下 a^k 是否 \le n。如果不是,那么就可以直接停止枚举。如果是,那么我们暴力乘找到最小的 k' 满足 a^{k'}>n。那么,\sum k'-k 就是答案?

错了。注意到不同底数的不同次幂可能相同,所以需要判重。这里我使用了 map 判重。令 c=\sqrt[k]{n},时间复杂度 O(c \log^2 n)。事实上远远跑不满。

看起来很暴力,但可以通过,并且并不需要用容斥什么的动脑子。

k=2

为什么要单独讲 k=2?因为 \sqrt[2]{n} \le 10^9,用上面的方法妥妥的超时。

思考我们枚举的本质是什么?统计答案?不是,如果只是统计答案的话,那么二分或者其他方法肯定更快。那为什么要一个个枚举?是因为要判重。那么,如果我们对完全平方数能找到一个更好的性质用于判重,就可以避免对于指数为 2 的数的枚举了。

这个时候就要思考完全平方数的性质了,最常用的一点就是,分解质因数后所有指数都是偶数。那么,我们先把 n 以内的完全平方数个数计入答案(\sqrt{n} 个),然后,从 2\sqrt[3]{n} 枚举 a。如果 a 为完全平方数,那么直接跳过(因为这样的话,a 的任意次幂肯定是完全平方数),否则,看 a^k 是否 \le n。如果不是,停止枚举;如果是,暴力乘找到最小的 k' 满足 k' \ge n。对于每个 k \le j \le k',如果 a^j 已经出现过,或者 j \bmod 2=0,那么不计入答案;否则计入答案。时间复杂度 O(c \log^2 n)

好想好写,适合做一道普及组的题。

Code:

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, k, ans=1, pf[1000010];
map <int, int> mp;
int sqrt(int x){
    int l = 0, r = 1e9;
    while (l <= r){
        int mid = l + r >> 1;
        if (mid * mid <= x) l = mid + 1;
        else r = mid - 1; 
    }
    return l - 1;
}
signed main(){
    scanf ("%lld%lld", &n, &k);
    if (k == 1) return printf ("%lld\n", n), 0;
    if (k == 2){
        ans = sqrt(n);
        for (int i=1; i<=1000; i++) pf[i*i] = 1;
        if (k == 2) k = 3;
        for (int i=2; i; i++){
            if (mp.find(i) != mp.end() || pf[i]) continue;
            int now = 1, flag = 0;
            for (int j=1; j<=k; j++){
                if (now > n / i){
                    flag = 1;
                    break;
                }
                now *= i;
            }
            if (flag) break;
            for (int j=k; j; j++){
                if (mp.find(now) == mp.end() && j % 2 == 1) ans ++;
                mp[now] = 1;
                if (now > n / i) break;
                now *= i;
            }
        }
    }
    if (k >= 3){
        for (int i=2; i; i++){
            if (mp.find(i) != mp.end()) continue;
            int now = 1, flag = 0;
            for (int j=1; j<=k; j++){
                if (now > n / i){
                    flag = 1;
                    break;
                }
                now *= i;
            }
            if (flag) break;
            for (;;){
                if (mp.find(now) == mp.end()) ans ++;
                mp[now] = 1;
                if (now > n / i) break;
                now *= i;
            }
        }
    }
    printf ("%lld\n", ans);
    return 0;
}