题解 P5481 【[BJOI2015] 糖果】

· · 题解

糖果

思路

$\ \ \ \ \ \ $ 那怎么算$s$呢?我们可以把这个东西转换一下模型,变成:把$m$个小球放到$k$个盒子中,盒子可以为空的方案数。 $\ \ \ \ \ \ $ 为什么两个问题是等价的呢?这里简单说明一下,可以想象成,一个盒子$a$里面放有$b$个小球,就相当于在序列中出现了$b$个$a$。 $\ \ \ \ \ \ $ 那上面那个模型怎么算呢?很简单,首先考虑到盒子可以为空的缘故,我们先一开始每个盒子都放入一个小球,然后就相当于把$m+k$放入$k$个盒子中每个盒子必放的方案数,这个直接用插板法可以算出: $$s=C_{m+k-1}^{k-1}=\frac{1}{m!}\prod_{i=0}^{m-1} (k+i)$$ $\ \ \ \ \ \ $ 但是这道题的$p$不一定为质数,所以我们不能直接用逆元。那我们怎么算$s$呢?其实,我们可以直接分解质因数,然后分子分母抵消就好了。算分母的时候因为是阶乘所以可以用$\text {Legendre}$定理优化一下,但是对总时间复杂度还是不影响的。 $\ \ \ \ \ \ $ 然后。。。就完了 ## $\text {Code}
#include <bits/stdc++.h>
using namespace std;

#define Int register int
#define ll long long
#define MAXN 100005

int tot;
int prime[MAXN];

bool vis[MAXN];

void Prime (int n)
{
    for (Int i = 2;i <= n;++ i)
    {
        if (!vis[i]) prime[++ tot] = i;
        for (Int j = 1;j <= tot && (int)i * prime[j] <= n;++ j)
        {
            vis[i * prime[j]] = 1;
            if (i % prime[j] == 0) break;
        }
    }
}

ll up[MAXN],Index[MAXN];

int n,m,k,p;

int read ()
{
    int x = 0;char c = getchar();int f = 1;
    while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}
    while (c >= '0' && c <= '9'){x = (x << 3) + (x << 1) + c - '0';c = getchar();}
    return x * f;
}

void write (int x)
{
    if (x < 0){x = -x;putchar ('-');}
    if (x > 9) write (x / 10);
    putchar (x % 10 + '0');
}

signed main()
{
    n = read (),m = read (),k = read (),p = read ();
    if (!p) return puts ("0"),0;
    for (Int i = 1;i <= m;++ i) up[i] = m + k - i;
    Prime (m);
    for (Int i = 1;i <= tot && prime[i] <= m;++ i)
    {
        for (Int j = prime[i];j <= m;j *= prime[i])
            Index[i] += m / j;
    }
    for (Int i = 1;i <= tot;++ i)
    {
        int s = m + k - 1 - (m + k - 1) % prime[i];
        for (Int j = m + k - s;j <= m;j += prime[i])
        {
            while (Index[i] && up[j] % prime[i] == 0) Index[i] --,up[j] /= prime[i];
            if (!Index[i]) break;
        }
    }
    int s = 1,sum = 1;
    for (Int i = 1;i <= m;++ i) s = (ll)s * up[i] % p;
    for (Int i = 1;i <= n;++ i) sum = (ll)sum * (s - i + 1) % p;
    write (sum),putchar ('\n');
    return 0;
}