CF1750D

· · 题解

题目链接

\text{Solution}

容易发现本题的答案为:

\prod_{i=1}^{n-1} \sum_{w=1}^m [\operatorname{gcd}(a_i,w)=a_{i+1}]

转化一下式子:

\prod_{i=1}^{n-1} \sum_{w=1}^m [\operatorname{gcd}(a_i,w)=a_{i+1}] = \prod_{i=1}^{n-1} \sum_{w=1}^{\tiny{\left\lfloor\dfrac{m}{a_{i+1}}\right\rfloor}} [\operatorname{gcd}(\frac{a_i}{a_{i+1}},w)=1]

k_i={\tiny{\left\lfloor\dfrac{m}{a_{i+1}}\right\rfloor}} , x_i=\large{\frac{a_i}{a_{i+1}}}.

对后面的求和式反演一下:

\sum_{w=1}^{k_i} [\operatorname{gcd}(x_i,w)=1] = \sum_{w=1}^{k_i} \sum_{d|x_i,d|w} \mu(d) = \sum_{d|x_i} \mu(d) \sum_{w=1}^{k_i} [d|w] = \sum_{d|x_i} \mu(d) \left\lfloor\dfrac{k_i}{d}\right\rfloor

最后的整个式子:

\prod_{i=1}^{n-1} \sum_{d|x_i} \mu(d) \left\lfloor\dfrac{k_i}{d}\right\rfloor

然后再分析一下复杂度:

  1. 如果 {\exists \; i \in [1,n) , a_{i+1} \nmid a_{i}} 那么答案一定是 0. 所以说数组 a 单调非严格递减且后一个数是前一个数的因数。

  2. a_i=a_{i+1}, x_i = 1, 每次求后面这个和式是 O(1) 的。

  3. a_i \not = a_{i+1}, 每次根号求 \mu ,后面这个求和式单次复杂度为 O(\sqrt{x_i}) = O(x_i^{1/2}), 注意到此时 a_{i+1} 一定小于 a_i 的一半,所以这个操作最多只会执行 \log a_1 次。

  4. 由均值不等式可得每组数据总复杂度为 O(\log a_1 \times \sum\limits_{i=1}^{n-1} \sqrt{x_i}) = O(\log a_1 \times \sum\limits_{i=1}^{n-1}x_i^{1/2}) \le O(\log a_1 \times \sqrt{\frac{\sum\limits_{i=1}^{n-1}x_i}{n-1}}) \le O(\sqrt m \log m) .

点击查看缺省源"Jairon.h"

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define lint long long

#include <"Jairon.h">

#define maxn 1000010
#define SIZE 5010

const int mod = 998244353;

int Mu(int x){
    int res=0;
    for(int i=2;i*i<=x;i++){
        if(x%i==0){
            ++res;
            int cnt=0;
            while(x%i==0){ x/=i; ++cnt; if(cnt>1){ return 0; } }
        }
    } if(x!=1){ ++res; }

    return (res&1)?(-1):1;
}

int n,m;
int a[maxn];

signed main(){
    int T=read(_);
    while(T--){
        bool flag=true;
        read(n,m);
        FOR(i,1,n){ read(a[i]); }
        FOR(i,1,n-1){ if(a[i]%a[i+1]!=0){ flag=false; break; } }
        if(flag==false){ puts("0"); continue; }

        int ans=1;

        FOR(i,1,n-1){
            int tmp=0;
            int x=a[i]/a[i+1];
            for(int j=1;j*j<=x;j++){
                if(x%j==0){
                    int d1=j;
                    int d2=x/j;
                    tmp=(tmp+1ll*Mu(d1)*(m/a[i+1]/d1)%mod+mod)%mod;
                    if(d1!=d2){
                        tmp=(tmp+1ll*Mu(d2)*(m/a[i+1]/d2)%mod+mod)%mod;
                    }
                }
            } ans=1ll*ans*tmp%mod;
        } assi(),outn(ans);
    }
    return 0;
}

/**/