题解:CF1967E2 Again Counting Arrays (Hard Version)

· · 题解

题意

给定 n,m,b_0,求有多少整数序列 a_{1\sim n},满足:

答案对 998244353 取模。

多测,1\leq n,m\leq 2\times 10^60\leq b_0\leq 2\times 10^6\sum{n}\leq 10^7

题解

反射容斥好题。

先考虑如何 check 一个 a 是否合法。注意到 b\geq 0 的下界限制,但是没有上界限制。考虑从左往右贪心确定 b_i,若 b_{i-1}+1\neq a_i,则我们一定取 b_i=b_{i-1}+1,否则只能取 b_i=b_{i-1}-1。于是只要不在位置取到 b_i=-1 就是合法的。同时我们也可以注意到,若某个位置取到 b_i=m,则接下来 a_{i+1\sim n} 无论怎么取都是合法的。

正难则反,用总方案数减去不合法的方案数。容易想到一个暴力 DP:令 f_{i,j}\ (0\leq j<m) 表示 b_i=j 对应的 a_{1\sim i} 的方案数。转移考虑是 +1 还是 -1

枚举最小的 i 使得 b_i=-1,答案为 \sum\limits_{i=1}^nf_{i-1,0}\cdot m^{n-i}。时间复杂度为 \mathcal{O}(nm)

这个形式同样是格路计数的形式:还是枚举最早碰到 -1 的位置 t,相当于从 (0,b_0) 开始走,每一步可以向右上或右下走一步,不能碰到 y=my=-1 两条直线,走到 (t-1,0) 的方案数。旋转一下坐标系,把 (0,b_0) 视作原点,每一步可以向上或向右走一步,不能碰到 y=x+(m-b_0)y=x-(b_0+1) 两条直线,走到 \left(\dfrac{t-1+b_0}{2},\dfrac{t-1-b_0}{2}\right) 的方案数。反射容斥直接做即可,注意还要乘上向上走的方案数 (m-1)^{(t-1-b_0)/2}t+1\sim n 随便走的方案数 m^{n-t}。时间复杂度为 \mathcal{O}\left(\dfrac{n^2}{m}\right)

结合两种做法,根号分治一下即可做到 \mathcal{O}(n\sqrt{n})。可以通过 Easy Version。

:::success[Easy Version 的代码]

#include <bits/stdc++.h>

using namespace std;

#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pii;
const int N = 2e5 + 5, MOD = 998244353;

template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> inline T add(T x, T y) { return x += y, x >= MOD ? x - MOD : x; }
template<typename T> inline T sub(T x, T y) { return x -= y, x < 0 ? x + MOD : x; }
template<typename T> inline void cadd(T &x, T y) { x += y, x < MOD || (x -= MOD); }
template<typename T> inline void csub(T &x, T y) { x -= y, x < 0 && (x += MOD); }

int T, n, m, b0, f[2][N], pw1[N], pw2[N];
int fac[N], ifac[N];

int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = (ll)res * a % MOD;
        a = (ll)a * a % MOD;
    }
    return res;
}
void prework(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
    ifac[n] = qpow(fac[n], MOD - 2);
    for (int i = n - 1; ~i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
}

int C(int n, int m) {
    return n < 0 || m < 0 || n < m ? 0 : (ll)fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}

int solve(int n, int x, int y) {
    if (x + m - b0 <= y || x - b0 - 1 >= y || x + y != n) return 0;
    int tx = x, ty = y, res = C(tx + ty, tx);
    while (tx >= 0 && ty >= 0) {
        swap(tx, ty), tx -= m - b0, ty += m - b0;
        csub(res, C(tx + ty, tx));
        swap(tx, ty), tx -= -b0 - 1, ty += -b0 - 1;
        cadd(res, C(tx + ty, tx));
    }
    tx = x, ty = y;
    while (tx >= 0 && ty >= 0) {
        swap(tx, ty), tx -= -b0 - 1, ty += -b0 - 1;
        csub(res, C(tx + ty, tx));
        swap(tx, ty), tx -= m - b0, ty += m - b0;
        cadd(res, C(tx + ty, tx));
    }
    return (ll)res * pw2[y] % MOD;
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    prework(N - 5);
    cin >> T;
    while (T--) {
        cin >> n >> m >> b0;
        pw1[0] = pw2[0] = 1;
        for (int i = 1; i <= n; ++i)
            pw1[i] = (ll)pw1[i - 1] * m % MOD, pw2[i] = (ll)pw2[i - 1] * (m - 1) % MOD;
        int ans = pw1[n];
        if (b0 >= m) { cout << ans << '\n'; continue; }
        if ((ll)m * m <= n) {
            fill(f[0], f[0] + m + 1, 0), f[0][b0] = 1;
            for (int i = 1; i <= n; ++i) {
                int cur = i & 1, prv = cur ^ 1;
                for (int j = 0; j < m; ++j) {
                    f[cur][j] = f[prv][j + 1];
                    if (j) cadd<int>(f[cur][j], (ll)f[prv][j - 1] * (m - 1) % MOD);
                }
                f[cur][m] = 0;
                csub<int>(ans, (ll)f[prv][0] * pw1[n - i] % MOD);
            }
        } else {
            for (int t = 1; t <= n; ++t) if ((t - 1 + b0 & 1) == 0) {
                int c = solve(t - 1, t - 1 + b0 >> 1, t - 1 - b0 >> 1);
                csub<int>(ans, (ll)c * pw1[n - t] % MOD);
            }
        }
        cout << ans << '\n';
    }
    return 0;
}

:::

DP 做法显然没有前途,考虑如何进一步优化反射容斥的做法。我们发现枚举时刻 t,会使得反射容斥时用到的组合数上指标不同,这基本埋没了进一步优化的可能。

那么能否不去枚举碰到 -1 的时刻呢?其实是可以的。我们发现碰到 -1 之后乘上的 m^{n-t},其实等价于乱走的方案数。也就是说,如果对于 t+1\sim n 时刻的路径,我们依然给向上走赋以 m-1 的权重,向右走赋以 1 的权重,计算出的方案数依旧是对的。所以可以考虑改为枚举在原坐标系中的终点 (n,p),即新坐标系中的终点 \left(\dfrac{n-p+b_0}{2},\dfrac{n+p-b_0}{2}\right)。设 y=x+m-b_0 为直线 Ay=x-(b_0+1) 为直线 B。分类讨论计数:

这样枚举之后,反射容斥计算时用到的组合数上指标就都变为 n 了。以 p\leq -1 为例,可以列出式子:

\sum_{\substack{p\in[b_0-n,b_0+n]\\p\equiv b_0-n\pmod{2}}}(m-1)^{(n+p-b_0)/2}\sum_{k\geq 0}\binom{n}{\dfrac{n+p-b_0-2k(m+1)}{2}}-\binom{n}{\dfrac{n+p+b_0-2m-2k(m+1)}{2}}

显然答案是 \sum\limits_{i=0}^nc_i\binom{n}{i} 的形式,考虑直接维护出 c_{0\sim n}。可以发现两个组合数的下指标都是公差为 -(m+1) 的等差数列,那么我们按 \bmod{(m+1)} 分组,组内差分维护即可。时间复杂度为 \mathcal{O}(n)

注意一个细节:处理 p\geq 0 的部分时,m-1 的指数要取对称前的 \dfrac{n+p-b_0}{2}

:::success[Hard Version 的代码]

#include <bits/stdc++.h>

using namespace std;

#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pii;
const int N = 2e6 + 5, MOD = 998244353;

template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> inline T add(T x, T y) { return x += y, x >= MOD ? x - MOD : x; }
template<typename T> inline T sub(T x, T y) { return x -= y, x < 0 ? x + MOD : x; }
template<typename T> inline void cadd(T &x, T y) { x += y, x < MOD || (x -= MOD); }
template<typename T> inline void csub(T &x, T y) { x -= y, x < 0 && (x += MOD); }

int T, n, m, b0, d[N];
int pw[N], fac[N], ifac[N];

int qpow(int a, int b) {
    int res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = (ll)res * a % MOD;
        a = (ll)a * a % MOD;
    }
    return res;
}
void prework(int n) {
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
    ifac[n] = qpow(fac[n], MOD - 2);
    for (int i = n - 1; ~i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
}

int C(int n, int m) {
    return n < 0 || m < 0 || n < m ? 0 : (ll)fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    prework(N - 5);
    cin >> T;
    while (T--) {
        cin >> n >> m >> b0;
        int ans = qpow(m, n);
        if (b0 >= min(n, m)) { cout << ans << '\n'; continue; }
        pw[0] = 1;
        for (int i = 1; i <= n; ++i) pw[i] = (ll)pw[i - 1] * (m - 1) % MOD;
        fill(d, d + n + 1, 0);
        for (int p = b0 - n; p < 0; p += 2) {
            int x = n + p - b0 >> 1, v = pw[n + p - b0 >> 1];
            if (x >= 0) {
                cadd(d[x % (m + 1)], v);
                if (x + m + 1 <= n) csub(d[x + m + 1], v);
            }
            x = n + p + b0 - (m << 1) >> 1;
            if (x >= 0) {
                csub(d[x % (m + 1)], v);
                if (x + m + 1 <= n) cadd(d[x + m + 1], v);
            }
        }
        for (int p = b0 + n & 1; p <= b0 + n; p += 2) {
            int np = -2 - p;
            int x = n + np - b0 >> 1, v = pw[n + p - b0 >> 1];
            if (x >= 0) {
                cadd(d[x % (m + 1)], v);
                if (x + m + 1 <= n) csub(d[x + m + 1], v);
            }
            x = n + np + b0 - (m << 1) >> 1;
            if (x >= 0) {
                csub(d[x % (m + 1)], v);
                if (x + m + 1 <= n) cadd(d[x + m + 1], v);
            }
        }
        for (int i = m + 1; i <= n; ++i) cadd(d[i], d[i - m - 1]);
        for (int i = 0; i <= n; ++i) csub<int>(ans, (ll)d[i] * C(n, i) % MOD);
        cout << ans << '\n';
    }
    return 0;
}

:::