题解:CF1794D Counting Factorizations

· · 题解

upd:2025-03-05 修改了一些错误的内容

upd:2025-03-06 又修改了一些错误的内容。。。

upd:2025-11-03 修改了多重集组合公式推导

传送门

题意

将某一个数分解质因数,将质因数和指数存入一个数组 a 里(长度为 2n),求出这个数组中的数分别作质因数或指数时能构成的数的数量(质因数和指数必须两两配对)。

思路

1h 场切,很有意思的 dp。

首先明确两个性质:

根据题目,每个质数最多选一次作质因数,我们假设质因数已经确定,还剩 k 种数,每种数有 r_k 个,考虑方案数。

根据多重组合:

\begin{aligned} \begin{pmatrix} n\\ r_{1},r_{2},\dots,r_{k}\\ \end{pmatrix} &= \begin{pmatrix} n\\ r_{1}\\ \end{pmatrix} + \begin{pmatrix} n-r_{1}\\ r_{2}\\ \end{pmatrix} + \cdots + \begin{pmatrix} n-(r_{1}+r_{2}+\dots+r_{k-1})\\ r_k\\ \end{pmatrix} \\&= \frac{n!}{r_{1}!(n-r_{1})!} \times \frac{(n-r_{1})!}{r_{2}!(n-(r_{1}+r_{2}))!} \times \cdots \times \frac{(n-(r_{1}+r_{2}+\dots+r_{k-1}))!}{r_{k}!} \\&= \frac{n!}{\prod r_{i}!} \end{aligned}

所以只需要确定质因数再计算上面这个式子就可以算出所有方案数了。但 n \le 2022,找出所有质因数的情况数量非常大,无法枚举,考虑用 dp 维护质因数个数。

定义:dp_{i} 表示目前选了 i 个数作质因数(其他遍历过的数全作指数)时的方案数。

线性筛求出所有质数,将数组 a 中的数去重并记录个数,注意到 mod=998244353,考虑用乘法逆元代替除法,这样我们可以在转移的同时计算将剩下的数当作指数分配给 n 个位置的方案数(可以将 dp_{0} 初始化为 n! 或在最后将答案乘 n!)。

其中 j 要类似 01 背包的方式从大往小枚举,ny_k 表示 k! 的乘法逆元,sum_i 表示第 i 种数的个数,时间复杂度 O(\max_{i=1}^{n}a_i + n^2)据说还有更快的做法(逃

code

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,no=0,mod=998244353,jc[6005],ny[6005],dp[6005],zs[1000005],sum=0;
bool z[1000005]={0},nz[1000005]={0};
struct node
{
    int x,sum;
}a[6005];
int ksm(int x,int y)
{
    int now=1;
    while(y!=0)
    {
        if(y%2==1)
        {
            now*=x;
            now%=mod;
        }
        x*=x;
        x%=mod;
        y/=2;
    }
    return now;
}
bool cmp(node x,node y)
{
    return x.x<y.x;
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    jc[0]=1;
    for(int i=1;i<=6000;i++)
    {
        jc[i]=(jc[i-1]*i)%mod;
    }
    for(int i=0;i<=6000;i++)
    {
        ny[i]=ksm(jc[i],mod-2);
    }
    cin>>n;
    n*=2;
    for(int i=1;i<=n;i++)
    {
        cin>>a[i].x;
        a[i].sum=1;
    }
    int m=0;
    for(int i=2;i<=1e6;i++)
    {
        if(z[i]==0)
        {
            m++;
            zs[m]=i;
        }
        z[i]=1;
        for(int j=1;j<=m;j++)
        {
            if(i*zs[j]>1e6)
            {
                break;
            }
            z[i*zs[j]]=1;
            if(i%zs[j]==0)
            {
                break;
            }
        }
    }
    for(int i=1;i<=m;i++)
    {
        nz[zs[i]]=1;
    }
    sort(a+1,a+n+1,cmp);
    no=n;
    for(int i=1;i<=n;i++)
    {
        if(a[i].x==a[i+1].x)
        {
            no--;
            a[i+1].sum=a[i].sum+1;
            a[i].x=INT_MAX;
        }
    }
    sort(a+1,a+n+1,cmp);
    n/=2;
    dp[0]=1;
    for(int i=1;i<=no;i++)
    {
        if(nz[a[i].x]==0)
        {
            for(int j=0;j<=n;j++)
            {
                dp[j]=dp[j]*ny[a[i].sum];
                dp[j]%=mod;
            }
        }
        else
        {
            for(int j=n;j>=0;j--)
            {
                if(j!=n)
                {
                    dp[j+1]+=dp[j]*ny[a[i].sum-1];
                    dp[j+1]%=mod;
                }
                dp[j]=dp[j]*ny[a[i].sum];
                dp[j]%=mod;
            }
        }
    }
    cout<<(dp[n]*jc[n])%mod;
    return 0;
}