AT_abc349_f [ABC349F] Subsequence LCM
XuanXuanMeow · 题解
这里介绍一种非常魔怔的做法,莫比乌斯反演 + bitset。
首先,莫比乌斯反演是什么呢?不会的可以点这里。
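令 $f(x)$ 表示 LCM 恰好为 $x$ 的非空子序列个数,$g(x)$ 表示 LCM 为 $x$ 的约数的非空子序列个数。
则有:
$$g(x)=\sum_{d\mid x} f(d)$$
根据莫比乌斯反演公式,可得:
$$f(x)=\sum_{d\mid x}\mu(d)\,g\!\left(\frac{x}{d}\right)$$
而一个子序列的 LCM 整除 $x$ 当且仅当其中每个元素都整除 $x$,所以 $g(x)=2^{k}-1$,其中 $k$ 为满足 $a_i\mid x$ 的下标 $i$ 的个数。答案即为 $f(m)$。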
于是,就可以洋洋洒洒写下如下代码:
//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// N: array capacity (elements are stored 1-indexed, so n is assumed < N);
// mod: modulus for the answer.
const int N = 2e5 + 5;
const int mod = 998244353;
// n: number of elements; a[1..n]: the sequence; m: the target LCM;
// cnt: scratch counter.
int n;
long long a[N],m,cnt;
// g(x): number of non-empty subsequences whose elements all divide x,
// i.e. 2^k - 1 (mod `mod`) where k = #{ i in 1..n : a[i] divides x }.
inline long long g(long long x){
    long long pw = 1;
    int i = n;
    // Each element dividing x doubles the number of usable subsets.
    while (i >= 1){
        if (x % a[i] == 0) pw = pw * 2 % mod;
        --i;
    }
    // Drop the empty subsequence.
    return (pw + mod - 1) % mod;
}
// Moebius function mu(x) by trial division:
//   returns 0 if some prime divides x at least twice,
//   otherwise (-1)^(number of distinct prime factors of x); mu(1) = 1.
// The prime exponent is counted in a local variable, so calling this no longer
// overwrites the global scratch `cnt`. The loop bound i <= x / i is the same
// test as i * i <= x but cannot overflow for x close to LLONG_MAX.
inline long long mobius(long long x){
    long long f = 1;
    for (long long i = 2; i <= x / i; ++i){
        int exponent = 0;
        while (x % i == 0){
            x /= i;
            ++exponent;
        }
        if (exponent > 1) return 0;   // squared prime factor
        if (exponent == 1) f = -f;    // one more distinct prime
    }
    // Whatever remains above 1 is a single prime factor larger than sqrt.
    if (x != 1) f = -f;
    return f;
}
// f(x): number of non-empty subsequences whose LCM is exactly x, via Moebius
// inversion  f(x) = sum over d | x of mu(d) * g(x / d)  (mod `mod`).
// Divisors are enumerated in pairs (i, x / i) with i <= sqrt(x); explicit
// braces replace the original ambiguous dangling `else`.
inline long long f(long long x){
    long long res = 0;
    for (long long i = 1; i * i <= x; ++i){
        if (x % i != 0) continue;
        const long long j = x / i;  // complementary divisor
        // mu is in {-1, 0, 1} and g in [0, mod), so "+ mod" keeps each term non-negative.
        res = (res + (mobius(i) * g(j) + mod) % mod) % mod;
        // When i == j the square root is a single divisor; count it once.
        if (i != j){
            res = (res + (mobius(j) * g(i) + mod) % mod) % mod;
        }
    }
    return res;
}
// Reads n, m and the sequence a[1..n], then prints f(m): the number of
// non-empty subsequences whose LCM equals m, modulo `mod`.
int main(){
    cin >> n >> m;
    int idx = 0;
    while (idx < n) cin >> a[++idx];
    cout << f(m);
}
之后你就会发现过不去:对 $m$ 的每个约数都要 $O(n)$ 地计算一次 $g$,还要试除求一次 $\mu$,只能通过部分测试点。
那咋办?换做法。
让我们进行神秘优化吧。
首先,从莫比乌斯函数入手。
- 若 $x$ 含有平方因子(即存在质数 $p$ 使 $p^2\mid x$),则 $\mu(x)=0$,这样的约数对答案没有贡献;
- 否则 $\mu(x)=(-1)^k$,其中 $k$ 为 $x$ 的不同质因数个数。
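可以考虑先对 $m$ 质因数分解,再用 dfs 枚举每种质数选或不选,这样恰好枚举到所有 $\mu(d)\neq 0$ 的约数 $d$。
为什么这里可以将 dfs 枚举质数选不选作为优化手段呢?因为可以发现不同种类的质数个数最多只会有 $13$ 个(前 $14$ 个质数之积 $2\times 3\times\cdots\times 43\approx 1.3\times 10^{16}$ 已超过 $10^{16}$),所以最多只需计算 $2^{13}=8192$ 次 $g$。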
之后正常做就好了,再补上一点特判,代码如下:
//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// N: array capacity (elements are stored 1-indexed, so n is assumed < N);
// mod: modulus for the answer.
const int N = 2e5 + 5;
const int mod = 998244353;
// n: number of elements, read into a[1..n]; m: the target LCM.
int n;
// tms: copy of m divided down while factorising; cnt: exponent scratch;
// seq[1..tot]: distinct primes of m; res: running answer.
// NOTE(review): c is declared but not used anywhere in this listing.
long long a[N],tms,m,cnt,seq[N],c[N],tot,res;
// g(x): 2^k - 1 (mod `mod`), where k counts the elements a[1..n] dividing x;
// this is the number of non-empty subsequences whose LCM divides x.
inline long long g(long long x){
    long long ways = 1;
    for (int i = 1; i <= n; ++i){
        if (x % a[i] != 0) continue;
        ways = ways * 2 % mod;
    }
    return (ways - 1 + mod) % mod;
}
// Adds sign * g(m / num) to res for the current square-free divisor num, then
// extends num by one more prime seq[k] (k >= x) in every possible way, so each
// subset of seq[x..tot] is visited exactly once. sign flips with every prime
// added, i.e. it equals mu(num) when the walk is started as dfs(1, 1, 1).
inline void dfs(int x,long long num,long long sign){
    const long long term = (sign * g(m / num) + mod) % mod;
    res = (res + term) % mod;
    int k = x;
    while (k <= tot){
        dfs(k + 1, num * seq[k], -sign);
        ++k;
    }
}
// Reads the input, answers two trivial cases directly, otherwise factorises m
// into its distinct primes seq[1..tot] by trial division and sums
// mu(d) * g(m / d) over the square-free divisors d of m via dfs.
int main(){
    cin >> n >> m;
    if (n == 1){
        // The single element is the only subsequence; its LCM is itself.
        cin >> a[1];
        cout << (a[1] == m ? 1 : 0);
        return 0;
    }
    if (m == 1){
        // Only elements equal to 1 can take part: every non-empty subset of them works.
        long long ways = 1;
        for (int i = 1; i <= n; ++i){
            cin >> a[i];
            if (a[i] == 1) ways = ways * 2 % mod;
        }
        cout << ways - 1;
        return 0;
    }
    // Trial-division factorisation, recording only the distinct primes.
    tms = m;
    for (long long p = 2; p * p <= tms; ++p){
        for (cnt = 0; tms % p == 0; ++cnt) tms /= p;
        if (cnt != 0) seq[++tot] = p;
    }
    // A leftover value other than 1 is one more prime factor.
    if (tms != 1) seq[++tot] = tms;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    dfs(1, 1, 1);
    cout << res;
}
但是你会发现还是过不去,只能通过部分测试点。
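思考一下这时候的瓶颈在哪里?发现每次计算 $g(x)$ 都要 $O(n)$ 地扫描整个序列,总复杂度为 $O(n\cdot 2^{\omega(m)})$,其中 $\omega(m)$ 为 $m$ 的不同质因子个数。
考虑先丢掉所有不整除 $m$ 的 $a_i$,再对 $m$ 的每个质因子 $p_j$ 和每个指数 $k$ 预处理一个 bitset,记录哪些 $a_i$ 满足 $p_j^k\mid a_i$。
分解质因数后如何快速求个数呢?使用 bitset 就可以了:设 $e_j$ 为 $p_j$ 在 $x$ 中的指数,则 $a_i\nmid x$ 当且仅当存在某个 $j$ 使 $p_j^{e_j+1}\mid a_i$,因此把这些 bitset 按位或起来,剩下未被标记的元素个数就是整除 $x$ 的元素个数。代码如下: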
//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// N: array capacity (kept elements use indices 1..tor, so n is assumed < N);
// mod: modulus for the answer.
const int N = 2e5 + 5;
const int mod = 998244353;
// n, m: element count and target LCM; a[1..tor]: kept elements (divided down
// in main); seq[1..tot]: distinct primes of m; tmper: their product;
// res: running answer; cnt, tms, pr, ps: scratch.
// NOTE(review): tag is assigned but never read.
long long n,a[N],tms,m,cnt,seq[N],tag,tot,res,tor,tmper,pr,ps;
// c[j][k]: bit t is set iff seq[j+1]^k divides kept element t (filled in main).
// The 16 x 56 shape assumes m has at most 16 distinct primes and prime
// exponents of at most 54 -- presumably guaranteed by the bound on m; confirm.
// banned: scratch set rebuilt on every g call.
bitset <N> c[16][56],banned;
// nc(): buffered single-character read from stdin; yields EOF at end of input.
#define nc() (p1==p2 && (p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
// Buffer state for nc(); ch: character scratch.
char *p1,*p2,buf[100000],ch;
// Reads the next non-negative decimal integer from stdin (through the buffered
// nc() reader) into x, skipping any non-digit characters before it.
// The character is held in an int so EOF stays distinguishable from data: if
// the input ends before a digit appears, x is left as 0 instead of the
// original behaviour of spinning forever on EOF.
inline void read(long long &x){
    int code = nc();
    x = 0;
    while (code != EOF && (code < '0' || code > '9')) code = nc();
    while (code >= '0' && code <= '9'){
        x = x * 10 + (code - '0');
        code = nc();
    }
}
// Returns a^b modulo `mod` using right-to-left binary exponentiation.
inline long long fastpow(long long a,long long b){
    long long result = 1;
    for (long long base = a; b != 0; b >>= 1){
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// g(x): 2^k - 1 (mod `mod`), where k is the number of kept elements a[1..tor]
// that divide x. For each prime seq[j+1], with e its exponent in x, the
// elements whose exponent exceeds e are exactly c[j][e + 1]; OR-ing these into
// `banned` marks every kept element that fails to divide x (assuming kept
// elements only use primes from seq, as main arranges).
inline long long g(long long x){
    banned.reset();
    for (int j = 0; j < tot; ++j){
        const long long p = seq[j + 1];
        for (cnt = 0; x % p == 0; ++cnt) x /= p;
        banned |= c[j][cnt + 1];
    }
    const long long usable = tor - static_cast<long long>(banned.count());
    return (fastpow(2, usable) - 1 + mod) % mod;
}
// Visits every subset of the primes seq[x..tot] layered on top of the
// square-free divisor num, adding sign * g(m / num) to res at each node.
// sign flips per added prime, so it is mu(num) when started as dfs(1, 1, 1).
inline void dfs(int x,long long num,long long sign){
    res = (res + (sign * g(m / num) + mod) % mod) % mod;
    for (int k = x; k <= tot; k++){
        dfs(k + 1, num * seq[k], -sign);
    }
}
// Entry point: reads n, m and the sequence, keeps only elements that divide m,
// builds the prime-exponent bitsets c, then sums mu(d) * g(m / d) over the
// square-free divisors d of m via dfs and prints the result.
int main(){
read(n),read(m);
// Factorise m by trial division: seq[1..tot] = distinct primes,
// tmper = their product.
tms = m,tmper = 1;
for (long long i = 2;i * i <= tms;++i){
cnt = 0;
while (tms % i == 0) tms /= i,++cnt;
if (cnt) seq[++tot] = i,tmper *= seq[tot];
}
// A remainder other than 1 is one more (large) prime factor.
if (tms != 1) seq[++tot] = tms,tmper *= seq[tot];
for (int i = 1;i <= n;++i){
// Tentatively store the element at slot ++tor; --tor discards it again.
read(a[++tor]);
// An element that does not divide m can never be part of a subsequence with LCM m.
if (m % a[tor] != 0){
--tor;
continue;
}
// Strip every prime of m from the element; a leftover > 1 would mean a
// prime factor outside seq.
// NOTE(review): a[tor] divides m at this point, so all its primes divide
// tmper and pr always reaches 1 -- this branch never fires. tag is never read.
tag = 0,pr = a[tor];
while ((ps = __gcd(tmper,pr)) != 1) pr /= ps;
if (pr != 1){
--tor;
continue;
}
// Record exponents: bit tor of c[j-1][k] is set for k = 1..(exponent of
// seq[j] in the element). a[tor] itself is divided down in the process.
for (int j = 1;j <= tot;++j){
cnt = 0;
while (a[tor] % seq[j] == 0) a[tor] /= seq[j],c[j - 1][++cnt][tor] = true;
}
}
// No element divides m, so no subsequence can have LCM m.
if (!tor){
cout << 0;
return 0;
}
dfs(1,1,1);
cout << res;
}
至此,可以得到满分。
但这跑的还是太慢了,如何优化?
观察到现在的瓶颈其实是在对 $m$ 用试除法分解质因数上,这需要 $O(\sqrt m)$,即约 $10^8$ 次运算。
使用 Pollard-Rho(配合 Miller-Rabin 素性测试)在期望约 $O(m^{1/4})$ 的时间内分解 $m$ 即可。
代码如下:
//code by xuanxuanmeow
#include <bits/stdc++.h>
using namespace std;
// Pollard-rho step v -> v^2 + c (mod n). Being a macro, it uses whichever `n`
// and `c` are visible where it is expanded (in prs: the local n and c, which
// shadow the globals below).
// NOTE(review): a function-like macro named `f` rewrites every later `f(...)`
// in this file; a small function or lambda would be safer.
#define f(x) (((((x) % n) * ((x) % n) % n) + c) % n)
// NOTE(review): shadows std::abs for the rest of the file and evaluates x twice.
#define abs(x) ((x) > 0?(x):-(x))
// Random engine used by random(), seeded from std::random_device.
mt19937 rnd(random_device{}());
const int N = 2e5 + 5;
const int mod = 998244353;
// n, m: element count and target LCM; a: input slots, compacted so kept elements
// occupy 1..tor; seq[1..tot]: distinct primes of m; tmper: their product;
// res: running answer; cnt, pr, ps: scratch. NOTE(review): tag is never read.
long long n,a[N],m,cnt,seq[N],tag,tot,res,tor,tmper,pr,ps;
// c[j][k]: bit t is set iff seq[j+1]^k divides kept element t (filled in main).
// banned: scratch set rebuilt on every g call.
bitset <N> c[16][56],banned;
// nc(): buffered single-character read from stdin; yields EOF at end of input.
#define nc() (p1==p2 && (p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
// NOTE(review): unused in this listing, and its arguments are not parenthesised.
#define lcm(x,y) (x / __gcd(x,y) * y)
// Buffer state for nc(); ch: character scratch.
char *p1,*p2,buf[100000],ch;
// Reads the next non-negative decimal integer from stdin (through the buffered
// nc() reader) into x, skipping any non-digit characters before it.
// The character is held in an int so EOF stays distinguishable from data: if
// the input ends before a digit appears, x is left as 0 instead of the
// original behaviour of spinning forever on EOF.
inline void read(long long &x){
    int code = nc();
    x = 0;
    while (code != EOF && (code < '0' || code > '9')) code = nc();
    while (code >= '0' && code <= '9'){
        x = x * 10 + (code - '0');
        code = nc();
    }
}
// Returns a^b modulo `mod` (binary exponentiation, low bit first).
inline long long fastpow(long long a,long long b){
    long long result = 1;
    long long base = a;
    while (b != 0){
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return result;
}
// g(x): 2^k - 1 (mod `mod`), where k is the number of kept elements a[1..tor]
// that divide x. For each prime seq[j+1], with e its exponent in x, the
// elements whose exponent exceeds e are exactly c[j][e + 1]; OR-ing these into
// `banned` marks every kept element that fails to divide x (assuming kept
// elements only use primes from seq, as main arranges).
inline long long g(long long x){
    banned.reset();
    for (int j = 0; j < tot; ++j){
        const long long p = seq[j + 1];
        for (cnt = 0; x % p == 0; ++cnt) x /= p;
        banned |= c[j][cnt + 1];
    }
    const long long usable = tor - static_cast<long long>(banned.count());
    return (fastpow(2, usable) - 1 + mod) % mod;
}
// Visits every subset of the primes seq[x..tot] layered on top of the
// square-free divisor num, adding sign * g(m / num) to res at each node.
// sign flips per added prime, so it is mu(num) when started as dfs(1, 1, 1).
// (The parameter is not called `f`, to stay clear of the `f(x)` macro.)
inline void dfs(int x,long long num,long long sign){
    const long long term = (sign * g(m / num) + mod) % mod;
    res = (res + term) % mod;
    int k = x;
    while (k <= tot){
        dfs(k + 1, num * seq[k], -sign);
        ++k;
    }
}
// Returns a pseudo-random integer in [l, r] (assumes l <= r).
// A single draw of the 32-bit mt19937 `rnd` can never exceed 2^32 - 1, so the
// original could only return values in [l, l + 2^32 - 1] even for the ~1e16-wide
// ranges Pollard-rho asks for. Two draws are combined into 64 bits here; the
// remaining modulo bias is negligible for these ranges.
inline __int128 random(__int128 l,__int128 r){
    const unsigned long long hi = rnd();
    const unsigned long long lo = rnd();
    const unsigned long long draw = (hi << 32) | lo;
    return static_cast<__int128>(draw) % (r - l + 1) + l;
}
// Returns a^b mod moder with 128-bit intermediates, so operands below 2^63 can
// be multiplied without overflow (right-to-left binary exponentiation).
inline __int128 fastpows(__int128 a,__int128 b,__int128 moder){
    __int128 acc = 1;
    for (__int128 sq = a; b != 0; b >>= 1, sq = sq * sq % moder){
        if (b & 1) acc = acc * sq % moder;
    }
    return acc;
}
inline bool checker(long long n){
if (n < 3 || !(n & 1)) return n == 2;
if (n % 3 == 0) return n == 3;
__int128 u = n - 1,t = 0;
while (!(u & 1)) u >>= 1,++t;
for (int i = 1;i <= 10;++i){
__int128 a = random(2,n - 2),v = fastpows(a,u,n);
if (v == 1) continue;
int s;
for (s = 0;s < t;++s){
if (v == n - 1) break;
v = 1ll * v * v % n;
}
if (s == t) return false;
}
return true;
}
// Pollard's rho (Floyd cycle detection with batched gcds) on n.
// Returns a divisor of n: a proper factor 1 < d < n on success, or n itself
// when the walk closes into a cycle (x == y) first, which signals failure and
// makes the caller retry. n == 4 is answered directly.
// NOTE(review): presumably only called on composite n (primes() calls it after
// checker() rejects n) -- confirm before reusing elsewhere.
inline long long prs(__int128 n){
if (n == 4) return 2;
// Random start x and random constant c for the step macro f(v) = v^2 + c mod n,
// which expands using this local n and c.
__int128 x = random(0,n - 1),y = x,c = random(3,n - 1),d = 1,cnt,tmp;
// Tortoise x advances one step per iteration, hare y advances two.
x = f(x),y = f(f(y));
// Batch size lim doubles each round, capped at 128.
for (int lim = 1;x != y;lim = min(128,lim << 1)){
cnt = 1;
for (int i = 1;i <= lim;++i){
// Accumulate the product of |x - y| mod n so one gcd covers the whole batch.
tmp = cnt * abs(x - y) % n;
// Product became 0 mod n: stop and use the previous product, which lies in
// [1, n - 1] so its gcd with n cannot be n.
if (!tmp) break;
cnt = tmp;
x = f(x),y = f(f(y));
}
d = __gcd(cnt,n);
if (d != 1) return d;
}
return n;
}
// Appends every distinct prime factor of m (expects m >= 1) to seq[++tot] and
// multiplies it into tmper. A composite m is split with prs; all primes of the
// first part are divided out of the second, so each prime is recorded once.
// If prs fails it returns m itself, and the recursion simply retries on m.
inline void primes(long long m){
    if (m == 1) return;
    if (!checker(m)){
        const long long part = prs(m);
        long long rest = m / part;
        while ((ps = __gcd(part, rest)) != 1) rest /= ps;
        primes(part);
        primes(rest);
        return;
    }
    seq[++tot] = m;
    tmper *= m;
}
// Entry point: reads n and m, factorises m with Pollard-rho / Miller-Rabin,
// keeps only elements that divide m, builds the prime-exponent bitsets c, then
// sums mu(d) * g(m / d) over the square-free divisors d of m via dfs.
int main(){
read(n),read(m);
// primes() fills seq[1..tot] with the distinct primes of m and multiplies
// them into tmper.
tmper = 1,primes(m);
for (int i = 1;i <= n;++i){
// Tentatively store the element at slot ++tor; --tor discards it again.
read(a[++tor]);
// An element that does not divide m can never be part of a subsequence with LCM m.
if (m % a[tor] != 0){
--tor;
continue;
}
// Strip every prime of m from the element; a leftover > 1 would mean a
// prime factor outside seq.
// NOTE(review): a[tor] divides m here, so (provided primes() found every prime
// of m) pr always reaches 1 and this branch never fires. tag is never read.
tag = 0,pr = a[tor];
while ((ps = __gcd(tmper,pr)) != 1) pr /= ps;
if (pr != 1){
--tor;
continue;
}
// Record exponents: bit tor of c[j-1][k] is set for k = 1..(exponent of
// seq[j] in the element). a[tor] itself is divided down in the process.
for (int j = 1;j <= tot;++j){
cnt = 0;
while (a[tor] % seq[j] == 0) a[tor] /= seq[j],c[j - 1][++cnt][tor] = true;
}
}
// No element divides m, so no subsequence can have LCM m.
if (!tor){
cout << 0;
return 0;
}
dfs(1,1,1);
cout << res;
}