题解:P12181 DerrickLo's Buildings (UBC002D)

· · 题解

我们可以对 i 分类讨论。当 i \ge 2 时,所有的情况都相似,仅有一个系数的差别。

因为 v 只有 M! 种取值,所以我们先化简 E \times M!,最后再除掉就行了。也就是,我们要计算:

  1. \sum_{k=1}^V\left(\sum_{v \text{ is a permutation of }[M]}\sum_{j=1}^N[v_k(j) = j]\right)
  2. \sum_{k=1}^V\left(\sum_{v \text{ is a permutation of }[M]}\sum_{j=1}^N[v_k(j) = 2j]\right)

由期望的线性性,所有合理范围内的 j 贡献都是一样的。因此,第一个式子中,我们只需算出 1 的贡献,乘上 N 即可。第二个式子中,算出 1 的贡献,乘上 \min\left\{N, \frac{M}{2}\right\} 即可。

第一个式子

定义排列 v 上的一个环为由不同数字组成的序列 \{a_1, a_2, \dots, a_m\},使得 v(a_1) = a_2, v(a_2) = a_3, \dots, v(a_m) = a_1,其中 m 是环长。可以发现,一个排列总能被唯一划分成若干个环。

那么,当 k 确定时,问题可以描述为:“有多少个排列 v1 所在的环的长度是 k 的因数”。

则,我们可以枚举环长 m,则对于所有的 m 的倍数 k,它的贡献是 A_{M - 1}^{m - 1} \times (M - m)! = (M - 1)!。这个值和 m 无关,所以 m 的贡献是 \lfloor\frac{V}{m}\rfloor \times (M - 1)!。那么,最终和式的值就是 (M - 1)! \times \sum_{m=1}^M\lfloor\frac{V}{m}\rfloor

除去 M! 之后,E = \frac{N}{M} \times \sum_{m=1}^M\lfloor\frac{V}{m}\rfloor,可以 O(\sqrt{V}) 计算。

第二个式子

我们只需要 12 在同一个环即可。当 km 确定的时候,12 在环内的相对位置是固定的。将 1 设定为 a_1,若 2 位于 a_j,则 j 就必须满足 j - 1 \equiv k \pmod m。这也说明 k 不能是 m 的倍数。

k 不为 m 的倍数时,有 A_{M - 2}^{m - 2} \times (M - m)! = (M - 2)! 个这样的排列。这样的 k 一共有 V -\lfloor\frac{V}{m}\rfloor 个,因此,m 的贡献就是 (M - 2)! \times (V - \lfloor\frac{V}{m}\rfloor)。答案即 \frac{\min\{N, \lfloor\frac{M}{2}\rfloor\}}{M(M - 1)} \times \left(MV - \sum_{m=1}^M\lfloor\frac{V}{m}\rfloor\right)。同样可以在 O(\sqrt{V}) 的时间内算完。

i \ge 3 时,这个式子就是 \frac{\min\{N, \lfloor\frac{M}{i}\rfloor\}}{M(M - 1)} \times (MV - \sum^{M}_{m = 1}\lfloor\frac{V}{m}\rfloor),后面的一项和 i 无关,前面的同样在 O(\sqrt{M}) 的复杂度内算出来。

更新:有个式子写错了。

std

#include <iostream>
using namespace std;

#define MOD 998244353ll

using ll = long long;

ll pow(ll b, ll p, ll m)
{
    b %= MOD;
    ll r = 1;
    while (p)
    {
        if (p & 1)
        {
            r = r * b % m;
        }
        b = b * b % m;
        p >>= 1;
    }
    return r;
}

ll inv(ll p)
{
    return pow(p, MOD - 2, MOD);
}

int main()
{
    int t;
    cin >> t;
    while (t--)
    {
        ll n, m, v;
        cin >> n >> m >> v;
        ll s = 0;
        for (ll l = 1, r; l <= m && l <= v; l = r + 1)
        {
            r = min(m, v / (v / l));
            s = (s + (r - l + 1) % MOD * (v / l % MOD) % MOD) % MOD;
        }
        ll mid = m / n;
        ll res = (mid - 1) % MOD * (n % MOD) % MOD;
        for (ll l = mid + 1, r; l <= m; l = r + 1)
        {
            r = m / (m / l);
            res = (res + (r - l + 1) % MOD * (m / l % MOD) % MOD) % MOD;
        }
        cout << (res * (((m % MOD) * (v % MOD) % MOD - s + MOD) % MOD) % MOD * inv(m % MOD * ((m - 1) % MOD) % MOD) % MOD + n % MOD * inv(m % MOD) % MOD * s % MOD) % MOD << endl;
    }

    return 0;
}