题解:P14072 热恋

· · 题解

数据范围何意味?

首先注意到前半段和后半段至少有一个部分的积是 k 的倍数,所以相当于两段都要是 k 的倍数。正着做不太好做,考虑数有多少个排列不合法,那只要数 \{1,2,\cdots,2n\} 有多少个大小为 n 的子集使得它的乘积不是 k 的倍数。

k=\prod p_i^{\alpha_i},g_S=\prod_{i\in S}p_i,h_S=\prod_{i\in S}p_i^{\alpha_i}

考虑莫反/容斥,答案为 \sum_{S\neq \empty}(-1)^{|S|}f(\frac{h_S}{g_S}),其中 f(x) 表示有多少个大小为 n 的子集使得它的乘积和 k\gcdx 的因数。这也相当于给 S 内的质因子加了一个上界。考虑设计一个 dp,f_{i,j} 表示共选了 i 个数,目前的乘积和 h_S\gcd=j 的方案数。然后可以把 1\sim2n 的所有数按照 \gcd(i,h_S) 分类,这样的话每一类的转移是一个卷积,看起来复杂度很高。

但是考虑忽略 \gcd=1 的类,那么剩下的部分只会选 O(\log k) 个,这样可以每个类暴力枚举选几个数,然后直接转移即可,要乘上一个组合数作为系数。

复杂度不太会算,一个很松的上界是 O(n+\sigma^3(k)\log^2k),大概达不到这个复杂度,其中 \sigma(x) 表示 x 的因数个数。常数很小,跑的很快,感觉数据范围可以加 0。

这个实现不太精细。

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=998244353,N=2e5;
int n,k,cnt[N+5],nc[N+5],id[N+5],cntp,pr[20],f[__lg(N)*2+2][N>>4],
g[__lg(N)*2+2][N>>4],nd[N+5],d[N>>4],tot,fac[N+5],inv[N+5];
int Pow(int x,int y){
    int res=1;
    for(;y;y>>=1,x=1ll*x*x%mod)
        if(y&1)
            res=1ll*res*x%mod;
    return res;
}
int calc(int S){
    int c=0,p[20],vp=1,up=1,s=0,cp=0;
    for(int i=1;i<=cntp;i++)
        if(S&(1<<i-1)){
            p[++cp]=pr[i];
            int j=k;
            while(!(j%pr[i])) j/=pr[i],vp*=pr[i],up*=pr[i];
            up/=pr[i];
        }
    for(int i=1;i<=tot;i++)
        if(!(up%d[i]))
            id[d[i]]=++c,nd[c]=d[i];
    memset(nc,0,sizeof(nc));
    for(int i=1;i<=tot;i++)
        if(up%__gcd(d[i],vp)==0)
            nc[id[__gcd(d[i],vp)]]+=cnt[d[i]];
    memset(f,0,sizeof(f)),f[0][1]=1;
    int ns=0;
    for(int i=2;i<=c;i++){
        memset(g,0,sizeof(g));
        int ct=0;
        for(int j=0,pr=1;j<=nc[i]&&up%pr==0;j++){
            ct=j;
            for(int k=0;k<=ns;k++)
                for(int o=1;o<=c;o++)
                    if(up%(pr*nd[o])==0)
                        g[j+k][id[pr*nd[o]]]=(g[j+k][id[pr*nd[o]]]+
                        1ll*f[k][o]*fac[nc[i]]%mod*inv[j]%mod*inv[nc[i]-j])%mod;
            pr*=nd[i];
        }
        ns+=ct,ns=min(ns,__lg(k)); 
        memcpy(f,g,sizeof(f));
    }
    int res=0;
    for(int i=max(0,n-nc[1]);i<=min(n,ns);i++){
        int r=0;
        for(int j=1;j<=c;j++) r=(r+f[i][j])%mod;
        res=(res+1ll*fac[nc[1]]%mod*inv[n-i]%mod*inv[nc[1]-n+i]%mod*r)%mod;
    }
    return res;
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>k,fac[0]=1;
    for(int i=1;i<=2*n;i++)
        cnt[__gcd(i,k)]++,fac[i]=1ll*fac[i-1]*i%mod;
    inv[2*n]=Pow(fac[2*n],mod-2);
    for(int i=2*n-1;~i;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
    for(int i=1;i<=k;i++)
        if(cnt[i])
            d[++tot]=i;
    int kk=k;
    for(int i=2;i<=tot&&d[i]*d[i]<=kk;i++)
        if(!(kk%d[i])){
            pr[++cntp]=d[i];
            while(!(kk%d[i])) kk/=d[i];
        }
    if(kk>1) pr[++cntp]=kk;
    int ans=0;
    for(int i=1;i<(1<<cntp);i++)
        ans=(ans+(__builtin_popcount(i)&1?-1:1)*calc(i))%mod;
    ans=2ll*ans*fac[n]%mod*fac[n]%mod,ans=(fac[2*n]+ans)%mod;
    cout<<(ans+mod)%mod<<'\n';
    return 0;
}