题解:CF1749D Counting Arrays

· · 题解

luogu link

codeforces link

Solution

因为直接求模糊的数组很难,考虑正难则反,求出不模糊的数组数,然后用总数减去它得到答案。

先求总数 total

total=\sum_{i=1}^{n} {m^i}=\frac{m^{n+1} - m}{m - 1}

:::info[证明] 把 total 乘以 m 可知:

\begin{align} total &= \sum_{i=1}^{n+1} m^i \\ m \times total &= \sum_{i=2}^{n+1} m^i \end{align}

所以用二式减去一式得到:

(m-1) \times total = m^{n+1} - m

也就是:

total = \frac{m^{n+1} - m}{m - 1}

:::

然后再求不模糊的数组数量。因为我们知道 \forall x \in \Z,\gcd(x,1)=1,所以肯定有一个移除序列是全 1 的。

对于一个下标为 i 的数 a_i,它会来到下标为 1 \sim i 的位置。

:::info[证明] 每当一个在它前面的数被删除,他就往前一位。 :::

所以要让这一位不破坏唯一性,它肯定要保证 \forall 2 \le j \le i,\gcd(a_i,j) \neq 1

说直接一点,就是 \prod\limits_{p \in P,p \le i} {p} \mid a_i,其中 P 是质数集合,也就是 a_i 要被 2 \sim i 中的所有质数之积整除才满足条件,定义 prod_i = \prod\limits_{p \in P,p \le i}

对于最终式子,定义 dp_i 为长度为 i 的不模糊数组数量,则:

dp_i = dp_{i-1} \times \lfloor \frac{m}{prod_i} \rfloor

边界 dp_1=m,因为第一位没有限制。

然后不模糊数组数就是:

\sum_{i=1}^{n} dp_i

因为 dp,prod 都只需要前一位递推,不模糊数组数可以在循环中算,所以只开了单变量。

Code

:::success[code]

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef char ch;
typedef string str;
typedef double db;
typedef __int128 i128;
const ll inf=9e18,maxn=3e5+1,mod=998244353;
const i128 Inf=1e35;
ll n,m,prod=1,cnt,ans,tot,prime[maxn],total;
bool isprime[maxn];
void get()
{
    memset(isprime,1,sizeof(isprime));
    isprime[1]=0;
    for(int i=2;i<=n;i++)
    {
        if(isprime[i]) prime[++total]=i;
        for(int j=1;j<=total&&i*prime[j]<=n;j++)
        {
            isprime[i*prime[j]]=0;
            if(i%prime[j]==0) break;
        }
    }
}
ll qpow(ll a,ll b)
{
    a%=mod;
    if(b==0) return 1;
    if(b==1) return a;
    ll temp=qpow(a,b/2);
    temp=(temp*temp)%mod;
    if(b%2) temp=(temp*a)%mod;
    return temp;
}
ll inv(ll x)
{
    return qpow(x,mod-2);
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m;
    get();
    tot=(qpow(m,n+1)-(m%mod)+mod)%mod*inv(m-1)%mod;
    if(m==1) tot=n;
    ans=cnt=m%mod;
    for(int i=2;i<=n;i++)
    {
        if(isprime[i]) prod*=i;
        if(prod>m) break;//到这里,已经没有符合要求的数了,直接结束。
        cnt=cnt*((m/prod)%mod)%mod;
        ans=(ans+cnt)%mod;
    }
    cout<<(tot-ans+mod)%mod;
}

:::

Thanks

感谢 Ivan20121212 提供思路(从开头到 \forall 2 \le j \le i,\gcd(a_i,j) \neq 1)。