AT_abc436_g の题解

· · 题解

终于搞明白官解的非 BM 做法了,自认为理解透了。

官解用 \times 表示内积,有病。

先一笔带过 BM 做法。就是很明显答案是 \sum_{k=0}^M[x^k]\prod_{i=1}^N\frac{1}{1-x^{A_i}},直接做即可。

以下约定 \cdot 表示内积,即 A\cdot B=\sum A_iB_i,同时小写表示数,大写(除了 N,M)表示序列。

我们设 X=(x_1,x_2,\dots,x_N),那么要求的就是 A\cdot X\le M 的非负整数解个数,记 f(M)A\cdot X\le M 的非负整数解个数。

因为我们知道,对于任意一个非负整数 x=2\lfloor\frac{x}{2}\rfloor+(x\bmod 2) 记作 x=q+r 是一个双射,类似的记 X=2Q+R,其中 \forall r\in R,r\in\{0,1\}

因此 A\cdot X=A\cdot(2Q+R)=2\times A\cdot Q+A\cdot R,我们改为枚举 R 计数。

所以 A\cdot X\le M\Leftrightarrow A\cdot Q\le\lfloor\frac{M-A\cdot R}{2}\rfloor

改为枚举 R 计数 Q 的方案数就有 \boxed{f(M)=\sum_{R\in\{0,1\}^N}f(\lfloor\frac{M-A\cdot R}{2}\rfloor)}

最关键的一步转化已经做完了,考虑怎么算方案数。

我们拎一个辅助权重数组(序列)C,初始时 C_i\gets[i=M],也就是始终维护 \sum_iC_if(i)=f(M)

为什么要这样做呢,因为我们可以把一个 f(M) 拆成若干个 f(?) 之和,我们已知 f(0)=1,f(x<0)=0,所以我们只用对 C 变换,最后只保留 C_0 项,那么就是答案。

我们要求的 \sum_iC_if(i)=\sum_iC_i\sum_{R\in\{0,1\}^N}f(\lfloor\frac{i-A\cdot R}{2}\rfloor)

我们记 S=\{A\cdot R\mid R\in \{0,1\}^N\},就有:

\begin{aligned} \sum_iC_if(i)&=\sum_iC_i\sum_{R\in\{0,1\}^N}f(\lfloor\frac{i-A\cdot R}{2}\rfloor) \\&=\sum_iC_i\sum_{s\in S}f(\lfloor\frac{i-s}{2}\rfloor). \end{aligned}

我们改为枚举 \lfloor\frac{i-s}{2}\rfloor,即 i\gets 2i+j+s

\begin{aligned} \sum_iC_if(i)&=\sum_iC_i\sum_{s\in S}f(\lfloor\frac{i-s}{2}\rfloor). \\&=\sum_i\sum_{j\in\{0,1\}}\sum_{s\in S}C_{2i+j+s}f(i). \end{aligned}

对比左右两边,得出 \boxed{C_i\gets\sum_{j\in\{0,1\}}\sum_{s\in S}C_{2i+j+s}} 并不会改变 \sum_iC_if(i) 的值,所以考虑直接对 C 进行 C_i\gets\sum_{s\in S}C_{2i+s}+C_{2i+s+1} 的变换。

现在已经不需要 f 了,直接维护 C 即可,每次 M'=\max_{C_i>0}i 都至少会减半,所以复杂度我不会证是对的。

[代码乱写的](https://atcoder.jp/contests/abc436/submissions/71747121),摘一部分。 :::info[Code] ```cpp #endif // ATCODER_CONVOLUTION_HPP #include <bits/stdc++.h> using namespace std; using ll = long long; #define int ll const int N = 1e6 + 10; const int INF = 1e18; const int MOD = 998244353; using it2 = array<int, 2>; template <typename T> inline T &M(T &x) { return x %= MOD; } int n, m; signed main() { cin.tie(0)->sync_with_stdio(false), cout.setf(ios::fixed), cout.precision(10); cin >> n >> m; vector<int> s{1LL}; for (int i = 0, a; i < n; ++i) { cin >> a; vector<int> tmp(a + 1, 0); tmp.front() = 1, tmp.back() = 1; s = atcoder::convolution(s, tmp); vector<int>().swap(tmp); } vector<int> c{1LL}; // 逆序存储方便实现 c_i=c_{i+s} while (m) { vector<int> nxt(m / 2 - (m - min<int>(m, s.size() + c.size() - 2)) / 2 + 1, 0), res = atcoder::convolution(s, c); // cerr << nxt.size() << '\n'; for (unsigned i = 0; i < res.size() && i <= m; ++i) { M(nxt[m / 2 - (m - i) / 2] += res[i]); } c.swap(nxt); vector<int>().swap(nxt); vector<int>().swap(res); m /= 2; } cout << c.front() << '\n'; return 0; } ``` :::