AT_abc436_g の题解
aaron0919
·
·
题解
终于搞明白官解的非 BM 做法了,自认为理解透了。
官解用 \times 表示内积,有病。
先一笔带过 BM 做法。就是很明显答案是 \sum_{k=0}^M[x^k]\prod_{i=1}^N\frac{1}{1-x^{A_i}},直接做即可。
以下约定 \cdot 表示内积,即 A\cdot B=\sum A_iB_i,同时小写表示数,大写(除了 N,M)表示序列。
我们设 X=(x_1,x_2,\dots,x_N),那么要求的就是 A\cdot X\le M 的非负整数解个数,记 f(M) 为 A\cdot X\le M 的非负整数解个数。
因为我们知道,对于任意一个非负整数 x=2\lfloor\frac{x}{2}\rfloor+(x\bmod 2) 记作 x=q+r 是一个双射,类似的记 X=2Q+R,其中 \forall r\in R,r\in\{0,1\}。
因此 A\cdot X=A\cdot(2Q+R)=2\times A\cdot Q+A\cdot R,我们改为枚举 R 计数。
所以 A\cdot X\le M\Leftrightarrow A\cdot Q\le\lfloor\frac{M-A\cdot R}{2}\rfloor。
改为枚举 R 计数 Q 的方案数就有 \boxed{f(M)=\sum_{R\in\{0,1\}^N}f(\lfloor\frac{M-A\cdot R}{2}\rfloor)}。
最关键的一步转化已经做完了,考虑怎么算方案数。
我们拎一个辅助权重数组(序列)C,初始时 C_i\gets[i=M],也就是始终维护 \sum_iC_if(i)=f(M)。
为什么要这样做呢,因为我们可以把一个 f(M) 拆成若干个 f(?) 之和,我们已知 f(0)=1,f(x<0)=0,所以我们只用对 C 变换,最后只保留 C_0 项,那么就是答案。
我们要求的 \sum_iC_if(i)=\sum_iC_i\sum_{R\in\{0,1\}^N}f(\lfloor\frac{i-A\cdot R}{2}\rfloor)。
我们记 S=\{A\cdot R\mid R\in \{0,1\}^N\},就有:
\begin{aligned}
\sum_iC_if(i)&=\sum_iC_i\sum_{R\in\{0,1\}^N}f(\lfloor\frac{i-A\cdot R}{2}\rfloor)
\\&=\sum_iC_i\sum_{s\in S}f(\lfloor\frac{i-s}{2}\rfloor).
\end{aligned}
我们改为枚举 \lfloor\frac{i-s}{2}\rfloor,即 i\gets 2i+j+s。
\begin{aligned}
\sum_iC_if(i)&=\sum_iC_i\sum_{s\in S}f(\lfloor\frac{i-s}{2}\rfloor).
\\&=\sum_i\sum_{j\in\{0,1\}}\sum_{s\in S}C_{2i+j+s}f(i).
\end{aligned}
对比左右两边,得出 \boxed{C_i\gets\sum_{j\in\{0,1\}}\sum_{s\in S}C_{2i+j+s}} 并不会改变 \sum_iC_if(i) 的值,所以考虑直接对 C 进行 C_i\gets\sum_{s\in S}C_{2i+s}+C_{2i+s+1} 的变换。
现在已经不需要 f 了,直接维护 C 即可,每次 M'=\max_{C_i>0}i 都至少会减半,所以复杂度我不会证是对的。
[代码乱写的](https://atcoder.jp/contests/abc436/submissions/71747121),摘一部分。
:::info[Code]
```cpp
#endif // ATCODER_CONVOLUTION_HPP
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define int ll
const int N = 1e6 + 10;
const int INF = 1e18;
const int MOD = 998244353;
using it2 = array<int, 2>;
template <typename T>
inline T &M(T &x) { return x %= MOD; }
int n, m;
signed main()
{
cin.tie(0)->sync_with_stdio(false), cout.setf(ios::fixed), cout.precision(10);
cin >> n >> m;
vector<int> s{1LL};
for (int i = 0, a; i < n; ++i)
{
cin >> a;
vector<int> tmp(a + 1, 0);
tmp.front() = 1, tmp.back() = 1;
s = atcoder::convolution(s, tmp);
vector<int>().swap(tmp);
}
vector<int> c{1LL}; // 逆序存储方便实现 c_i=c_{i+s}
while (m)
{
vector<int> nxt(m / 2 - (m - min<int>(m, s.size() + c.size() - 2)) / 2 + 1, 0), res = atcoder::convolution(s, c);
// cerr << nxt.size() << '\n';
for (unsigned i = 0; i < res.size() && i <= m; ++i)
{
M(nxt[m / 2 - (m - i) / 2] += res[i]);
}
c.swap(nxt);
vector<int>().swap(nxt);
vector<int>().swap(res);
m /= 2;
}
cout << c.front() << '\n';
return 0;
}
```
:::