题解:CF439E Devu and Birthday Celebration

· · 题解

很标准的一道容斥题。
首先我们研究这个 \gcd(a_1,a_2,...,a_f)= 1,直接算公因数为 1 显然很难算,考虑逆向思维,用序列总数减去 \gcd(a_1,a_2,...,a_f)\ne 1 的序列个数。
序列总数很好算,简单的插板法 \binom{n-1}{f-1} 即可算出。
那这个不合法部分怎么算呢?显然我们可以枚举 \gcd 的值,然而不同值的贡献有重合部分,所以就要容斥。推导这个容斥许多题解用了莫比乌斯函数,这里我用通俗的语言来解释一下。容易想到我们算了 2 的情况,那么 2 的幂次显然就不用算了,也就是说我们只需要算 n 的每个质因数的贡献。但是直接加 23 的贡献会多算 6 的贡献,考虑容斥模型 \lvert \bigcup_{i=1}^{n} S_i \rvert =\sum_{m=1}^{n} (-1)^{m-1} \sum_{a_i<a_{i+1}} \lvert \bigcap_{i=1}^{m} S_{a_i} \rvert

至于具体贡献大小,仍然是插板法 \binom{\frac{n}{p}-1}{f-1}
容斥系数在做线性筛的时候算就好了,代码很好写:

#include<bits/stdc++.h>
#define int long long
#define db double
#define maxn 1000005
#define mod 1000000007
#define fir first
#define sec second
#define pr pair<int,int>
#define pb push_back
#define mk make_pair
#define inf 10000000000000000
using namespace std;
inline int read()
{
    int SS=0,WW=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-')WW=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        SS=(SS<<1)+(SS<<3)+(ch^48);
        ch=getchar();
    }
    return SS*WW;
}
inline void write(int XX)
{
    if(XX<0)putchar('-'),XX=-XX;
    if(XX>9)write(XX/10);
    putchar(XX%10+'0');
}
int T,n,m,ans,fac[maxn],ifac[maxn],p[maxn],cnt,mu[maxn];
bool ip[maxn];
int ksm(int b,int p)
{
    int s=1;
    while(p)
    {
        if(p&1)s=s*b%mod;
        b=b*b%mod,p>>=1;
    }
    return s;
}
int C(int x,int y)
{
    if(x>y)return 0;
    return fac[y]*ifac[x]%mod*ifac[y-x]%mod;
}
void prime()
{
    mu[1]=1; 
    for(int i=2;i<=100000;i++)
    {
        if(!ip[i])p[++cnt]=i,mu[i]=-1;
        for(int j=1;j<=cnt&&i*p[j]<=100000;j++)
        {
            ip[i*p[j]]=1;
            if(i%p[j]==0)break;
            mu[i*p[j]]=-mu[i];
        }
    }
}
signed main()
{
    prime();
    fac[0]=ifac[0]=1;
    for(int i=1;i<=100000;i++)fac[i]=fac[i-1]*i%mod,ifac[i]=ksm(fac[i],mod-2);
    for(T=read();T;T--)
    {
        n=read(),m=read(),ans=0;
        for(int i=1;i*i<=n;i++)
        {
            if(n%i==0)
            {
                ans=(ans+mu[i]*C(m-1,n/i-1)%mod)%mod;
                if(i!=n/i)ans=(ans+mu[n/i]*C(m-1,i-1))%mod;
            }
        }
        write((ans+mod)%mod),puts("");
    }
    return 0;
}