题解 P5689 【[CSP-SJX2019]多叉堆】

· · 题解

写在前面

数据结构

只有根节点的答案有用。任何一个节点在更新完其父亲结点的值后,其本身的任何值将不会再有任何改动或贡献,因此用并查集维护即可,记得路径压缩。

算法思路

解决这道题的关键,在于连接两个节点时答案的更新和与答案相关的值的维护。

a_i 表示以 i 为根节点的子树的填数方案总数。

假设当前连接的两个节点的编号分别是 u,v(保证两个节点都是根节点),且本次操作需要从 u 接到 v 上去。

比较显而易见的一点是,更新答案时,节点 v 所处的位置,一定只能填 0

维护 w_i 表示以 i 为根的树的重量,即节点数。将 u 连接到 v 上时,将 w_v 的值加上 w_u

那么填入原来的 u 子树中的数字就有 w_v - 1w_u 种不同的选择方案。剩下的数字填入原来的 v 子树中。又原本 u,v 子树中的填数总方案数分别为 a_u,a_v。那么根据乘法原理,就可以将 a_v 更新为 a_v \times a_u \times \binom{n}{m}

组合数求法

  1. 杨辉三角,加法,可以取模。但是这是 P2822 的组合数求法,需要用到二维数组,数组大小开不下,只能过 50% 的数据。

  2. 直接根据组合数计算公式:

\binom{n}{m} = \frac{n!}{m!(n - m)!}

直接用计算式来求组合数。预处理出阶乘。等下,模意义下的乘法……要求逆元啊。线性递推求逆元可以参考 P3811 的题解。

这道题的除数是可以大到 10^9 + 7 的,因此不能直接递推预处理逆元,但是只会用到阶乘的逆元,预处理这个即可。

那么怎么线性递推呢?看这里:(为了更好体现逆元的形式,式子中保证了分数分子上都是 1

\frac{1}{i!} = \frac{1}{i\cdot(i - 1)!} = \frac{1}{i}\cdot\frac{1}{(i - 1)!} \frac{1}{(i - 1)!} = i\cdot\frac{1}{i!}

因此,首先根据费马小定理推论用快速幂算出 \frac{1}{n!} 随后倒序枚举预处理即可。

Tips

Code

代码可读性还是很高的。

#include<bits/stdc++.h>
#define LL long long

using namespace std;

const int Maxn = 3e5 + 5;
const LL Mod = 1e9 + 7;

int n, q, opt, x, y;
int Ans;
LL a[Maxn] = {1};
int w[Maxn] = {1}; 

/*快速幂*/ 
inline LL qpow(LL b, LL p)
{
    if(p == 0) return 1;
    LL x = 1;
    for(; p;b *= b, b %= Mod,p >>= 1) if(p & 1) x *= b, x %= Mod;
    return x;
}

/*并查集*/
int fa[Maxn];

int find(int t)
{
    return fa[t] == t ? t : fa[t] = find(fa[t]);
}

/*数学*/
LL fac[Maxn] = {1};
LL invf[Maxn];

LL C(int N, int M)
{
    return 1ll * fac[N] % Mod * invf[M] % Mod * invf[N - M] % Mod;
}

/*预处理*/ 
void Setup()
{
    for(register int i = 1; i <= n; ++i)
    {
        fa[i] = i;
        w[i] = a[i] = 1;
        fac[i] = fac[i - 1] * i;
        fac[i] %= Mod;
    }
    invf[n] = qpow(fac[n], Mod - 2);
    for(register int i = n - 1; i >= 0; --i)
    {
        invf[i] = invf[i + 1] * (i + 1);
        invf[i] %= Mod;
    }
}

/*链接 处理答案*/
void line(int u, int v)
{
    w[v] += w[u];
    a[v] = a[v] * a[u] % Mod * C(w[v] - 1, w[u]) % Mod;
    fa[u] = v;
}

/*输出答案 更新强制在线值*/
void print(int t)
{
    Ans = (int)a[t];
    printf("%d\n", Ans);
} 

/*快速读入*/ 
inline int read()
{
    int f = 1, w = 0; char ch = getchar();
    for(; (ch < '0') || (ch > '9'); ch = getchar()) if(ch == '-') f = -1;
    for(; (ch >= '0') && (ch <= '9'); ch = getchar()) w = (w << 3) + (w << 1) + (int)(ch ^ '0');
    return f * w;
}

int main()
{
    n = read(); q = read();
    Setup();
    while(q--)
    {
        opt = read();
        if(opt == 1)
        {
            x = (read() + Ans) % n;
            y = (read() + Ans) % n;
            x = find(x); y = find(y);
            line(x, y); 
        }
        else
        {
            x = (read() + Ans) % n;
            x = find(x);
            print(x);
        }
    }
    return 0;
}