题解:CF1967E2 Again Counting Arrays (Hard Version)
题意
给定
-
- 存在非负整数序列
b_{0\sim n} ,使得\forall 1\leq i\leq n,|b_i-b_{i-1}|=1\land b_i\neq a_i 。
答案对
多测,
题解
反射容斥好题。
先考虑如何 check 一个
正难则反,用总方案数减去不合法的方案数。容易想到一个暴力 DP:令
- 若
b_i=b_{i-1}+1 ,则a_i 可以取[1,m] 内不为b_i 的任意整数,于是f_{i,j}\gets (m-1)f_{i-1,j-1} 。 - 若
b_i=b_{i-1}-1 ,则a_i=b_i ,于是f_{i,j}\gets f_{i-1,j+1} 。
枚举最小的
这个形式同样是格路计数的形式:还是枚举最早碰到
结合两种做法,根号分治一下即可做到
:::success[Easy Version 的代码]
#include <bits/stdc++.h>
using namespace std;
#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pii;
const int N = 2e5 + 5, MOD = 998244353;
template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> inline T add(T x, T y) { return x += y, x >= MOD ? x - MOD : x; }
template<typename T> inline T sub(T x, T y) { return x -= y, x < 0 ? x + MOD : x; }
template<typename T> inline void cadd(T &x, T y) { x += y, x < MOD || (x -= MOD); }
template<typename T> inline void csub(T &x, T y) { x -= y, x < 0 && (x += MOD); }
int T, n, m, b0, f[2][N], pw1[N], pw2[N];
int fac[N], ifac[N];
int qpow(int a, int b) {
int res = 1;
for (; b; b >>= 1) {
if (b & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD;
}
return res;
}
void prework(int n) {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
ifac[n] = qpow(fac[n], MOD - 2);
for (int i = n - 1; ~i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
}
int C(int n, int m) {
return n < 0 || m < 0 || n < m ? 0 : (ll)fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}
int solve(int n, int x, int y) {
if (x + m - b0 <= y || x - b0 - 1 >= y || x + y != n) return 0;
int tx = x, ty = y, res = C(tx + ty, tx);
while (tx >= 0 && ty >= 0) {
swap(tx, ty), tx -= m - b0, ty += m - b0;
csub(res, C(tx + ty, tx));
swap(tx, ty), tx -= -b0 - 1, ty += -b0 - 1;
cadd(res, C(tx + ty, tx));
}
tx = x, ty = y;
while (tx >= 0 && ty >= 0) {
swap(tx, ty), tx -= -b0 - 1, ty += -b0 - 1;
csub(res, C(tx + ty, tx));
swap(tx, ty), tx -= m - b0, ty += m - b0;
cadd(res, C(tx + ty, tx));
}
return (ll)res * pw2[y] % MOD;
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
prework(N - 5);
cin >> T;
while (T--) {
cin >> n >> m >> b0;
pw1[0] = pw2[0] = 1;
for (int i = 1; i <= n; ++i)
pw1[i] = (ll)pw1[i - 1] * m % MOD, pw2[i] = (ll)pw2[i - 1] * (m - 1) % MOD;
int ans = pw1[n];
if (b0 >= m) { cout << ans << '\n'; continue; }
if ((ll)m * m <= n) {
fill(f[0], f[0] + m + 1, 0), f[0][b0] = 1;
for (int i = 1; i <= n; ++i) {
int cur = i & 1, prv = cur ^ 1;
for (int j = 0; j < m; ++j) {
f[cur][j] = f[prv][j + 1];
if (j) cadd<int>(f[cur][j], (ll)f[prv][j - 1] * (m - 1) % MOD);
}
f[cur][m] = 0;
csub<int>(ans, (ll)f[prv][0] * pw1[n - i] % MOD);
}
} else {
for (int t = 1; t <= n; ++t) if ((t - 1 + b0 & 1) == 0) {
int c = solve(t - 1, t - 1 + b0 >> 1, t - 1 - b0 >> 1);
csub<int>(ans, (ll)c * pw1[n - t] % MOD);
}
}
cout << ans << '\n';
}
return 0;
}
:::
DP 做法显然没有前途,考虑如何进一步优化反射容斥的做法。我们发现枚举时刻
那么能否不去枚举碰到
- 若
p\leq -1 ,则一定会碰到B ,但是要保证先于碰到上边界A ,因此要减去先碰到A 的方案数,再加上先碰到B 再碰到A 的方案数……也就是计算f(B)-f(AB)+f(BAB)-\cdots 。反射容斥计算之,注意最后一个B 不需要对点做对称。 - 若
p>1 ,则可以对称到-2-p 变成p\leq -1 的形式。
这样枚举之后,反射容斥计算时用到的组合数上指标就都变为
显然答案是
注意一个细节:处理
:::success[Hard Version 的代码]
#include <bits/stdc++.h>
using namespace std;
#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pii;
const int N = 2e6 + 5, MOD = 998244353;
template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> inline T add(T x, T y) { return x += y, x >= MOD ? x - MOD : x; }
template<typename T> inline T sub(T x, T y) { return x -= y, x < 0 ? x + MOD : x; }
template<typename T> inline void cadd(T &x, T y) { x += y, x < MOD || (x -= MOD); }
template<typename T> inline void csub(T &x, T y) { x -= y, x < 0 && (x += MOD); }
int T, n, m, b0, d[N];
int pw[N], fac[N], ifac[N];
int qpow(int a, int b) {
int res = 1;
for (; b; b >>= 1) {
if (b & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD;
}
return res;
}
void prework(int n) {
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
ifac[n] = qpow(fac[n], MOD - 2);
for (int i = n - 1; ~i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
}
int C(int n, int m) {
return n < 0 || m < 0 || n < m ? 0 : (ll)fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
prework(N - 5);
cin >> T;
while (T--) {
cin >> n >> m >> b0;
int ans = qpow(m, n);
if (b0 >= min(n, m)) { cout << ans << '\n'; continue; }
pw[0] = 1;
for (int i = 1; i <= n; ++i) pw[i] = (ll)pw[i - 1] * (m - 1) % MOD;
fill(d, d + n + 1, 0);
for (int p = b0 - n; p < 0; p += 2) {
int x = n + p - b0 >> 1, v = pw[n + p - b0 >> 1];
if (x >= 0) {
cadd(d[x % (m + 1)], v);
if (x + m + 1 <= n) csub(d[x + m + 1], v);
}
x = n + p + b0 - (m << 1) >> 1;
if (x >= 0) {
csub(d[x % (m + 1)], v);
if (x + m + 1 <= n) cadd(d[x + m + 1], v);
}
}
for (int p = b0 + n & 1; p <= b0 + n; p += 2) {
int np = -2 - p;
int x = n + np - b0 >> 1, v = pw[n + p - b0 >> 1];
if (x >= 0) {
cadd(d[x % (m + 1)], v);
if (x + m + 1 <= n) csub(d[x + m + 1], v);
}
x = n + np + b0 - (m << 1) >> 1;
if (x >= 0) {
csub(d[x % (m + 1)], v);
if (x + m + 1 <= n) cadd(d[x + m + 1], v);
}
}
for (int i = m + 1; i <= n; ++i) cadd(d[i], d[i - m - 1]);
for (int i = 0; i <= n; ++i) csub<int>(ans, (ll)d[i] * C(n, i) % MOD);
cout << ans << '\n';
}
return 0;
}
:::