题解 P2183 【[国家集训队]礼物】

· · 题解

没人在这里发这题题解啊,我来发一篇。

明显答案是C_n^{w_1}\times C_{n-w_1}^{w_2} \times \cdots,如果要求的礼物总数超过了买的礼物,则直接输出Impossible

求解一个组合数C_n^mp_i^{k_i}的过程可以分别求分子和分母,然后再进行运算,即,求出n!\mod p_i^{k_i}m!\mod p_i^{k_i}(n-m)!\mod p_i^{k_i}后合并。注意分子可以含有大于k_ip_i,这样的话求出来就直接为0了,而除掉分母的时候是可能会除掉一些p_i的;同时分母本身也可能有含有p_i的约数,于是统一把p_i移除来,不乘进答案,接下来用分子有的p_i的质数减掉分母的,然后乘上去即可。不可能出现减掉后为负数的情况,因为那样意味着组合数除不尽;有可能出现减掉后仍然大于k_i的情况,那么这个组合数在模p_i^{k_i}下为0

具体地,对于求n! \mod p_i^{k_i}可以分为几个部分求解。将每个p_i的倍数提出一个p_i来,组成p_i^{\lfloor \frac{n}{p_i} \rfloor};对于剩余的部分,容易发现除掉了p_i的数(原来是p_i的倍数)组成了一个新的阶乘,即\lfloor \frac{n}{p_i} \rfloor !,这个我们可以递归求解。剩下的数形如1,2,\cdots,p-1,p+1,p+2,\cdots,2p-1 \cdots(因为p_i的倍数已经通过前两步瓜分到阶乘和次幂里去了),由于a \times b \mod p=(a \mod p) \times (b \mod p),易得\prod _{i=1}^{p_i^{k_i}-1}i[i \mod p \neq 0]\prod _{i=p_i^{k_i}+1}^{2p_i^{k_i}-1}i[i \mod p \neq 0] \mod p_i^{k_i}下是相等的,那么只需求一个p^k的长度的乘积在模p_i^{k_i}下的结果即可,而这个阶乘里包含了 \lfloor \frac{n}{p_i^{k_i}} \rfloor个这个余数,快速幂即可;还有剩下来的、没有达到p_i^{k_i}的一段数,由于长度不会超过p_i^{k_i}(即10^5),直接计算即可。

考虑求解答案,因为模的是一个和数,所以不能直接用普通的lucas求解。将和数P分解,得到P=\prod_{i=1}^{k}p_i^{k_i},我们要求的答案模p_i^{k_i}的结果,对于单独求组合数在模每个p_i^{k_i}的剩余系下的结果要相同。于是考虑对于每个组合数,单独求出对p_i^{k_i}的结果,然后通过中国剩余定理求解即可。

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=6;
const int maxp=(int)1e5+5;
int P,n,m,w[maxn],pri[maxp],cnt[maxp],tot=0,c[maxp],R[maxp],M[maxp];
map<int ,int > re;
inline int maybs(int x) {return x<0?-x:x;}
void div(int x)
{
    for(int i=2;i*i<=x;i++) {
        if(x%i==0) {
            pri[++tot]=i;re[i]=tot;
            while(x%i==0) x/=i,cnt[tot]++;
        }
    }
    if(x!=1) {
        if(re.find(x)==re.end()) pri[++tot]=x,re[x]=tot;
        cnt[re[x]]++;
    }
}
int quickpow(int x,int k)
{
    int ret=1;
    while(k>0) {
        if(k&1) ret=ret*x;
        x=x*x; k>>=1;
    }
    return ret;
}
int quickpow(int x,int k,int mo)
{
    int ret=1;
    while(k>0) {
        if(k&1) ret=1ll*ret*x%mo;
        x=1ll*x*x%mo; k>>=1;
    }
    return ret;
}
int gcd(int a,int b) {return b==0?a:gcd(b,a%b);}
int extgcd(int a,int b,int &x,int &y)
{
    if(b==0) {x=1,y=0; return a;}
    else {int d=extgcd(b,a%b,y,x);y-=a/b*x;return d;}
}
int inv(int q,int p)
{
    assert(gcd(p,q)==1);
    int x,y;
    extgcd(q,p,x,y);
    x=(x%p+p)%p; if(!x) x+=p;
    return x;
}
int Cnt(int x,int p) {
    if(x==0) return 0;
    return Cnt(x/p,p)+x/p;
}
typedef pair<int ,int > pii;
#define fir first
#define sec second
#define MP make_pair
pii Cal(int x,int p,int k)
{
    if(x==1 || x==0) return MP(1,0);
    int ret=1,del=0,tmp=quickpow(p,k);
    int cou=Cnt(x,p);
    del=cou;

    pii nw=Cal(x/p,p,k);
    ret=1ll*ret*nw.fir%tmp;

    if(x>=tmp) {
        int lev=1;
        for(int i=1;i<tmp;i++) {
            if(i%p==0) continue;
            lev=1ll*lev*i%tmp;
        }
        ret=1ll*ret*quickpow(lev,x/tmp,tmp)%tmp;
    }

    for(int i=x;i>=1;i--) {
        if(i%tmp==0) break;
        if(i%p==0) continue;
        ret=1ll*ret*i%tmp;
    }
    return MP(ret,del);
}
int getAns()
{
    for(int i=1;i<=tot;i++) M[i]=1;
    for(int i=1;i<=tot;i++) {
        for(int j=1;j<=tot;j++)
            if(i==j) continue;
            else M[i]=1ll*M[i]*R[j]%P;
    }
    int ret=0;
    for(int i=1;i<=tot;i++) ret=(ret+((1ll*c[i]*M[i]%P)*1ll*inv(M[i],R[i]))%P)%P;
    return ret;
}
int main()
{
    scanf("%d",&P);
    div(P);
    scanf("%d%d",&n,&m);
    int sum=0;
    for(int i=1;i<=m;i++) {scanf("%d",&w[i]);sum+=w[i];}
    if(sum>n) {printf("Impossible\n");exit(0);}
    int ans=1,N=n,M,res=1;
    pii up,dw1,dw2;
    for(int i=1;i<=m;i++) {
        M=w[i];
        for(int j=1;j<=tot;j++) R[j]=quickpow(pri[j],cnt[j]);
        for(int j=1;j<=tot;j++) {
            ans=1;
            up=Cal(N,pri[j],cnt[j]);
            dw1=Cal(M,pri[j],cnt[j]),dw2=Cal(N-M,pri[j],cnt[j]);
            up.sec-=dw1.sec; up.sec-=dw2.sec;
            assert(up.sec>=0);

            if(up.sec>=cnt[j]) ans=0;
            else {
                ans=1ll*ans*up.fir%R[j];
                ans=1ll*ans*inv(dw1.fir,R[j])%R[j]; ans=1ll*ans*inv(dw2.fir,R[j])%R[j];
                ans=1ll*ans*quickpow(pri[j],up.sec,R[j])%R[j];
            }
            c[j]=ans;
        }
        res=1ll*res*getAns()%P;
        N-=w[i];
    }
    printf("%d\n",res);
    return 0;
}