AT_abc349_f [ABC349F] Subsequence LCM

题解

这里介绍一种非常魔怔的做法,莫比乌斯反演 + bitset

首先,莫比乌斯反演是什么呢?不会的可以 点这里 。

设 $f(x)$ 为所选(非空)子序列满足 $\operatorname{lcm} = x$ 的选取方案数,$g(x)$ 为满足 $\operatorname{lcm} \mid x$ 的选取方案数。

则有:

$$g(x)=\sum_{d \mid x} f(d)$$

(子序列的 $\operatorname{lcm}$ 整除 $x$,当且仅当它恰好等于 $x$ 的某个因数 $d$。)

根据莫比乌斯反演公式,可得:

$$f(x)=\sum_{d \mid x} \mu\left(\frac{x}{d}\right) \times g(d)$$

$g(x)$ 其实就是统计序列中有多少个 $x$ 的因数:将这个个数记为 $p$,则 $g(x)=2^p-1$,这里的减一是为了排除空集。

于是,就可以洋洋洒洒写下如下代码:

//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// Capacity of the 1-indexed element array (sized for up to 2e5 elements).
const int N = 2e5 + 5;
// Answers are reported modulo this prime.
const int mod = 998244353;
// n: number of elements.
int n;
// a[1..n]: the elements; m: the target lcm; cnt: scratch counter written by mobius().
long long a[N],m,cnt;
inline long long g(long long x){
    // Every a[i] that divides x may be taken or skipped independently:
    // 2^(number of such a[i]) subsets, minus the empty one.
    long long ways = 1;
    for (int i = n;i >= 1;--i)
        if (x % a[i] == 0) ways = (ways + ways) % mod;
    return ways == 0 ? mod - 1 : ways - 1;
}
// Mobius function mu(x) by trial division.
// Returns 0 if some prime square divides x; otherwise (-1)^(number of distinct prime factors).
// Uses only locals: the previous version wrote the file-scope `cnt` as scratch,
// a hidden side effect that made the function non-reentrant.
inline long long mobius(long long x){
    long long f = 1;
    for (long long i = 2;i * i <= x;++i){
        if (x % i != 0) continue;
        x /= i;
        // A second factor of i means i^2 divides the original x.
        if (x % i == 0) return 0;
        f = -f;
    }
    // Whatever remains above 1 is a single prime factor larger than sqrt of the rest.
    if (x != 1) f = -f;
    return f;
}
// Number of nonempty selections whose lcm is exactly x (mod `mod`), via
// f(x) = sum over d | x of mu(x / d) * g(d). Divisors are walked in pairs (d, x / d).
inline long long f(long long x){
    long long total = 0;
    for (long long d = 1;d * d <= x;++d){
        if (x % d != 0) continue;
        const long long e = x / d;
        total = (total + ((mobius(d) * g(e) + mod) % mod)) % mod;
        // The partner divisor contributes separately unless d == e (x is a perfect square).
        if (d != e) total = (total + ((mobius(e) * g(d) + mod) % mod)) % mod;
    }
    return total;
}
int main(){
    // Input: n, the target m, then the n elements (stored 1-indexed).
    cin >> n >> m;
    for (int idx = 1;idx <= n;idx++)
        cin >> a[idx];
    // Count of nonempty subsequences whose lcm equals m, modulo 998244353.
    cout << f(m);
}

之后你就会发现过不去,只能过 25 个点:仅枚举因数就要 $O(\sqrt{m})$,并且对每个因数还要花 $O(\sqrt{x})$ 求 $\mu$、花 $O(n)$ 求 $g$,总复杂度太高。

那咋办?换做法。

让我们进行神秘优化吧。

首先,从莫比乌斯函数入手。

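可以考虑先质因数分解 $m$,获取所有不同的质因数后,dfs 枚举每种质数选或不选。由于 $\mu(d)$ 只在 $d$ 无平方因子时非零,且此时 $\mu(d)=(-1)^{\text{所选质数个数}}$,这样恰好枚举了所有对答案有贡献的 $d$(即把公式写成 $f(m)=\sum_{d \mid m}\mu(d)\,g\left(\frac{m}{d}\right)$)。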

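为什么这里可以将 dfs 枚举质数选不选作为优化手段呢?因为 $m \le 10^{16}$ 时,不同质因数的个数 $k$ 最多只有 13 个($2\times3\times\cdots\times41 \approx 3.04\times10^{14}$,再乘上 $43$ 就超过 $10^{16}$ 了)。于是 dfs 只需枚举 $O(2^k)$ 个因数,明显优于一遍遍重新求莫比乌斯函数以及枚举全部因数。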

之后正常做就好了,再补上一点特判,代码如下:

//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// Capacity of the 1-indexed element array.
const int N = 2e5 + 5;
// Answers are reported modulo this prime.
const int mod = 998244353;
// n: number of elements.
int n;
// a[1..n]: elements; m: target lcm; tms/cnt: trial-division scratch;
// seq[1..tot]: distinct prime factors of m; res: answer accumulator;
// c: declared but not used in this listing.
long long a[N],tms,m,cnt,seq[N],c[N],tot,res;
inline long long g(long long x){
    // Number of nonempty subsets built only from elements that divide x.
    long long subsets = 1;
    for (int i = 1;i <= n;++i){
        if (x % a[i] != 0) continue;
        subsets = subsets * 2 % mod;
    }
    return (subsets + mod - 1) % mod;
}
// Visits every square-free divisor num of m built from seq[x..tot] and adds
// mu(num) * g(m / num) to res; `sign` carries mu(num) (+1 or -1).
inline void dfs(int x,long long num,long long sign){
    long long term = sign * g(m / num) % mod;
    if (term < 0) term += mod;
    res = (res + term) % mod;
    // Extend with each later prime so every subset of seq[1..tot] is visited exactly once.
    for (int next = x;next <= tot;++next) dfs(next + 1,num * seq[next],-sign);
}
int main(){
    cin >> n >> m;
    // Edge case: with one element the only nonempty subsequence is {a[1]}.
    if (n == 1){
        cin >> a[1];
        cout << (a[1] == m ? 1 : 0);
        return 0;
    }
    // Edge case m = 1: any nonempty subset of the elements equal to 1 works.
    if (m == 1){
        long long ans = 1;
        for (int i = 1;i <= n;++i){
            cin >> a[i];
            ans = a[i] == 1 ? ans * 2 % mod : ans;
        }
        cout << ans - 1;
        return 0;
    }
    // Trial division: collect the distinct prime factors of m into seq[1..tot].
    tms = m;
    for (long long p = 2;p * p <= tms;++p){
        for (cnt = 0;tms % p == 0;++cnt) tms /= p;
        if (cnt != 0) seq[++tot] = p;
    }
    if (tms != 1) seq[++tot] = tms;
    for (int i = 1;i <= n;++i) cin >> a[i];
    // Sum mu(d) * g(m / d) over the square-free divisors d of m.
    dfs(1,1,1);
    cout << res;
}

但是你会发现还是过不去,只能过 32 个点,这时候就需要更加神秘的优化了。

思考一下这时候的瓶颈在哪里?发现每次求 $g(x)$ 都是一个 $O(n)$ 的过程,而总共要求 $O(2^k)$ 次,如何优化它呢?

考虑对每个 $a_i$ 分解质因数。但是请注意,这里并不是暴力分解,而是只用 $m$ 分解出来的质因数去试除:若 $a_i \nmid m$(例如 $a_i$ 含有 $m$ 没有的质因数),它一定不会被选中,可以直接忽略。

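分解质因数后如何快速求个数呢?使用 bitset 就可以了。对 $m$ 的第 $j$ 个质因数 $p_j$ 和每个 $k$,用一个 bitset 记录"$p_j$ 的指数不小于 $k$"的元素集合。$a_i \mid x$ 当且仅当对每个 $j$,$a_i$ 中 $p_j$ 的指数都不超过 $x$ 中 $p_j$ 的指数 $e_j$。因此把所有"指数 $\ge e_j+1$"的集合或起来,就得到不能选的元素,$g(x)=2^{\text{剩余个数}}-1$。代码如下: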

//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// Capacity of the 1-indexed element array and of the bitsets below.
const int N = 2e5 + 5;
// Answers are reported modulo this prime.
const int mod = 998244353;
// n: element count; a[1..tor]: elements kept after filtering; m: target lcm;
// seq[1..tot]: distinct prime factors of m; res: answer accumulator;
// cnt/tms/tmper/pr/ps: scratch used during factorization and filtering;
// tag: assigned in main but never read in this listing.
long long n,a[N],tms,m,cnt,seq[N],tag,tot,res,tor,tmper,pr,ps;
// c[j][k]: bit i is set when kept element i has exponent >= k of prime seq[j + 1].
// The 16 x 56 dimensions presumably assume m <= 1e16 (at most 13 distinct primes,
// exponents at most 53) — TODO confirm against the problem constraints.
// banned: scratch set rebuilt by g().
bitset <N> c[16][56],banned;
// Buffered character reader: refills buf from stdin in 100000-byte chunks and
// yields EOF once the input is exhausted.
#define nc() (p1==p2 && (p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
// p1/p2: read cursor and end of buffered data; ch: last character fetched by read().
char *p1,*p2,buf[100000],ch;
// Reads the next non-negative decimal integer from the buffered stdin into x,
// skipping any non-digit characters before it.
inline void read(long long &x){
    x = 0;
    do ch = nc(); while (ch < '0' || ch > '9');
    for (;ch >= '0' && ch <= '9';ch = nc()) x = x * 10 + (ch - '0');
}
// Binary exponentiation: returns a^b modulo `mod`.
inline long long fastpow(long long a,long long b){
    long long acc = 1;
    for (long long base = a;b != 0;b >>= 1,base = base * base % mod)
        if (b & 1) acc = acc * base % mod;
    return acc;
}
// Number of nonempty subsets of the kept elements whose members all divide x,
// for x a divisor of m. An element is excluded ("banned") as soon as its exponent
// of some prime seq[k + 1] exceeds that prime's exponent in x.
inline long long g(long long x){
    banned.reset();
    for (int k = 0;k < tot;++k){
        const long long p = seq[k + 1];
        for (cnt = 0;x % p == 0;++cnt) x /= p;
        // Elements needing p^(cnt + 1) or more cannot divide x.
        banned |= c[k][cnt + 1];
    }
    const long long usable = tor - static_cast<long long>(banned.count());
    return (fastpow(2,usable) - 1 + mod) % mod;
}
// Visits every square-free divisor num of m built from seq[x..tot] and adds
// mu(num) * g(m / num) to res; `sign` carries mu(num) (+1 or -1).
inline void dfs(int x,long long num,long long sign){
    long long term = sign * g(m / num) % mod;
    if (term < 0) term += mod;
    res = (res + term) % mod;
    for (int next = x;next <= tot;++next) dfs(next + 1,num * seq[next],-sign);
}
int main(){
    read(n),read(m);
    // Trial division: collect the distinct prime factors of m into seq[1..tot].
    tms = m;
    for (long long i = 2;i * i <= tms;++i){
        cnt = 0;
        while (tms % i == 0) tms /= i,++cnt;
        if (cnt) seq[++tot] = i;
    }
    if (tms != 1) seq[++tot] = tms;
    for (int i = 1;i <= n;++i){
        read(a[++tor]);
        // Only divisors of m can appear in a subsequence whose lcm is m. Divisibility
        // already guarantees every prime factor of a[tor] is in seq, so the former
        // separate "foreign prime" gcd check was redundant and has been dropped.
        if (m % a[tor] != 0){
            --tor;
            continue;
        }
        // For each prime seq[j] with exponent e in a[tor], mark element tor in
        // c[j - 1][1..e] ("exponent >= k" for k = 1..e).
        for (int j = 1;j <= tot;++j){
            cnt = 0;
            while (a[tor] % seq[j] == 0) a[tor] /= seq[j],c[j - 1][++cnt][tor] = true;
        }
    }
    // No usable element: no subsequence can have lcm m.
    if (!tor){
        cout << 0;
        return 0;
    }
    // Mobius inversion over the square-free divisors of m.
    dfs(1,1,1);
    cout << res;
}

至此,可以得到满分。

但这跑的还是太慢了,如何优化?

观察到现在的瓶颈其实是在对 $m$ 的质因数分解上:试除法需要 $O(\sqrt{m})$,约 $10^8$ 次。

使用 Pollard-Rho(配合 Miller-Rabin 判素),以期望 $O(m^{\frac{1}{4}})$ 的时间复杂度分解质因数。

代码如下:

//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// Pollard-Rho step x -> x^2 + c (mod n); relies on `n` and `c` being in scope where it is used.
#define f(x) (((((x) % n) * ((x) % n) % n) + c) % n)
// Absolute value usable with __int128; evaluates its argument more than once.
#define abs(x) ((x) > 0?(x):-(x))
mt19937 rnd(random_device{}());
// Capacity of the 1-indexed element array and of the bitsets below.
const int N = 2e5 + 5;
// Answers are reported modulo this prime.
const int mod = 998244353;
// n: element count; a[1..tor]: kept elements; m: target lcm; seq[1..tot]: distinct primes of m;
// res: answer accumulator; cnt/tmper/pr/ps/tag: scratch.
long long n,a[N],m,cnt,seq[N],tag,tot,res,tor,tmper,pr,ps;
// c[j][k]: bit i set when kept element i has exponent >= k of prime seq[j + 1]; banned: scratch for g().
bitset <N> c[16][56],banned;
// Buffered character reader over stdin (100000-byte chunks); yields EOF at end of input.
#define nc() (p1==p2 && (p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
// lcm of two integers. Every argument use is parenthesized so that expressions such as
// lcm(a, b + c) expand correctly (previously `* y` bound only to the first token of y).
#define lcm(x,y) ((x) / __gcd((x),(y)) * (y))
char *p1,*p2,buf[100000],ch;
// Reads the next non-negative decimal integer from the buffered stdin into x,
// skipping any non-digit characters before it.
inline void read(long long &x){
    x = 0;
    do ch = nc(); while (ch < '0' || ch > '9');
    for (;ch >= '0' && ch <= '9';ch = nc()) x = x * 10 + (ch - '0');
}
// Binary exponentiation: returns a^b modulo `mod`.
inline long long fastpow(long long a,long long b){
    long long acc = 1;
    for (long long base = a;b != 0;b >>= 1,base = base * base % mod)
        if (b & 1) acc = acc * base % mod;
    return acc;
}
// Number of nonempty subsets of the kept elements whose members all divide x,
// for x a divisor of m. An element is excluded ("banned") as soon as its exponent
// of some prime seq[k + 1] exceeds that prime's exponent in x.
inline long long g(long long x){
    banned.reset();
    for (int k = 0;k < tot;++k){
        const long long p = seq[k + 1];
        for (cnt = 0;x % p == 0;++cnt) x /= p;
        // Elements needing p^(cnt + 1) or more cannot divide x.
        banned |= c[k][cnt + 1];
    }
    const long long usable = tor - static_cast<long long>(banned.count());
    return (fastpow(2,usable) - 1 + mod) % mod;
}
// Visits every square-free divisor num of m built from seq[x..tot] and adds
// mu(num) * g(m / num) to res; `sign` carries mu(num) (+1 or -1).
// (Named `sign` rather than `f` to stay clear of the function-like macro f(x).)
inline void dfs(int x,long long num,long long sign){
    long long term = sign * g(m / num) % mod;
    if (term < 0) term += mod;
    res = (res + term) % mod;
    for (int next = x;next <= tot;++next) dfs(next + 1,num * seq[next],-sign);
}
// Draws a value in [l, r] from the global mt19937 by modulo reduction.
// Because mt19937 yields 32-bit values, the result is biased, and when r - l + 1
// exceeds 2^32 only the lowest 2^32 values of the range can be produced.
inline __int128 random(__int128 l,__int128 r){
    const __int128 width = r - l + 1;
    const __int128 draw = rnd();
    return l + draw % width;
}
// Binary exponentiation in 128-bit arithmetic: a^b modulo `moder`.
// Note: returns 1 when b == 0 regardless of moder.
inline __int128 fastpows(__int128 a,__int128 b,__int128 moder){
    __int128 acc = 1;
    for (__int128 base = a;b;b >>= 1){
        if (b & 1) acc = acc * base % moder;
        base = base * base % moder;
    }
    return acc;
}
// Deterministic Miller-Rabin primality test for 64-bit n.
// Uses the 7-base set {2, 325, 9375, 28178, 450775, 9780504, 1795265022}, known to be
// exact for all n < 2^64, instead of random bases. The earlier random bases came from
// a 32-bit generator (so they were never uniform over [2, n - 2]) and could in
// principle accept a composite.
inline bool checker(long long n){
    if (n < 3 || !(n & 1)) return n == 2;
    if (n % 3 == 0) return n == 3;
    // Modular exponentiation; 128-bit intermediates keep base * base from overflowing.
    auto power = [n](__int128 base,long long e){
        __int128 result = 1;
        base %= n;
        while (e){
            if (e & 1) result = result * base % n;
            base = base * base % n;
            e >>= 1;
        }
        return result;
    };
    // n - 1 = u * 2^t with u odd.
    long long u = n - 1;
    int t = 0;
    while (!(u & 1)) u >>= 1,++t;
    static const long long bases[] = {2,325,9375,28178,450775,9780504,1795265022};
    for (long long b : bases){
        const long long w = b % n;
        // A base that is a multiple of n carries no information; skip it.
        if (w == 0) continue;
        __int128 v = power(w,u);
        if (v == 1 || v == n - 1) continue;
        bool composite = true;
        for (int s = 1;s < t;++s){
            v = v * v % n;
            if (v == n - 1){
                composite = false;
                break;
            }
        }
        if (composite) return false;
    }
    return true;
}
// Pollard-Rho factor search on composite n, using the walk x -> x^2 + c (mod n)
// (macro f) with tortoise/hare stepping: x advances one step, y two steps.
// The differences |x - y| are multiplied together for up to `lim` steps (lim doubles
// from 1 up to 128) so a single gcd with n covers the whole batch.
// Returns a nontrivial factor of n when one is found, or n itself when the walk
// closes its cycle (x == y) without finding one; the caller must then retry.
inline long long prs(__int128 n){
    // For n == 4, random(3, n - 1) below could only give c = 3, and x^2 + 3 (mod 4)
    // only takes the values 0 and 3, so the walk would never expose the factor 2.
    if (n == 4) return 2;
    __int128 x = random(0,n - 1),y = x,c = random(3,n - 1),d = 1,cnt,tmp;
    x = f(x),y = f(f(y));
    for (int lim = 1;x != y;lim = min(128,lim << 1)){
        // cnt accumulates the product of |x - y| over this batch (mod n).
        cnt = 1;
        for (int i = 1;i <= lim;++i){
            tmp = cnt * abs(x - y) % n;
            // If the product hits 0 mod n, stop before advancing so cnt keeps the
            // last nonzero product and the gcd below can still isolate a factor.
            if (!tmp) break;
            cnt = tmp;
            x = f(x),y = f(f(y));
        }
        // 0 < cnt < n, so any d != 1 here is a proper factor of n.
        d = __gcd(cnt,n);
        if (d != 1) return d;
    }
    return n;
}
// Appends every distinct prime factor of m to seq[1..tot] (each once) and multiplies
// it into tmper.
inline void primes(long long m){
    if (m == 1) return;
    if (checker(m)){
        seq[++tot] = m,tmper *= m;
        return;
    }
    // prs() returns m itself when an attempt fails. Retry right here instead of
    // recursing on m, which would redo the primality test and deepen the stack on
    // every failed attempt.
    long long x;
    do x = prs(m); while (x == m);
    long long y = m / x;
    // Strip from y every prime it shares with x, so each prime is recorded only once.
    for (long long common;(common = __gcd(x,y)) != 1;) y /= common;
    primes(x),primes(y);
}
int main(){
    read(n),read(m);
    // Distinct prime factors of m via Pollard-Rho, stored in seq[1..tot].
    tmper = 1,primes(m);
    for (int i = 1;i <= n;++i){
        read(a[++tor]);
        // Only divisors of m can appear in a subsequence whose lcm is m. If a[tor] | m,
        // every prime factor of a[tor] is already in seq, so the former separate gcd
        // "foreign prime" check was redundant and has been dropped.
        if (m % a[tor] != 0){
            --tor;
            continue;
        }
        // For each prime seq[j] with exponent e in a[tor], mark element tor in
        // c[j - 1][1..e] ("exponent >= k" for k = 1..e).
        for (int j = 1;j <= tot;++j){
            cnt = 0;
            while (a[tor] % seq[j] == 0) a[tor] /= seq[j],c[j - 1][++cnt][tor] = true;
        }
    }
    // No usable element: no subsequence can have lcm m.
    if (!tor){
        cout << 0;
        return 0;
    }
    // Mobius inversion over the square-free divisors of m.
    dfs(1,1,1);
    cout << res;
}