题解 P4389 【付公主的背包】

Nemlit

2020-03-03 16:13:55

Solution

~~貌似前年初中教练把这题丢到基础背包练习题当作业来着~~ 首先我们对于每一个物品写出他的生成函数$S(x)$,对于一个体积为$V_i$的物品,我们有$S(x) = \sum_{j\%V_i=0}x^j$ $$S(x)=\sum_{j=0}^{\infty}x^{V_i\times j}, x^{V_i}\times S(x)=\sum_{j=1}^{\infty}x^{V_i\times j}$$ $S(x)-x^{V_i}\times S(x)=x^0=1$,所以我们有:$S(x)=\dfrac{1}{1-x^{V_i}}$ 于是我们可以快乐的把所有生成函数求逆之后卷在一起,复杂度为$O(NMlogN)$ 考虑优化,我们注意到:$e^{ln(S(x))}=S(x)$,这个看起来很无聊的柿子确实解题的关键,因为他把乘号变成了加号,多项式加法我们有很优秀的$O(N)$复杂度,并且看起来很好进一步优化 于是我们可以吧所有生成函数取$ln$,然后全部加起来,$exp$回去即可。 现在来考虑怎么求$ln(\dfrac{1}{1-x^{V_i}})$。根据一般套路,我们先求导: 令$G(x)=ln(\dfrac{1}{1-x^{V_i}})$,则$G'(x)=ln'(\dfrac{1}{1-x^{V_i}})$ 根据对数函数的性质,我们有$ln(x^c)=c\times ln(x)$,带入得:$G'(x)=-ln'(1-x^{V_i})=(\dfrac{V_i\times x^{V_i-1}}{1-x^{V_i}})=S(x)\times V_i\times x^{V_i-1}=V_i\times \sum_{j=0}^{\infty}x^{V_i\times j+V_i-1}$ 再来把这个柿子积分回去,我们有:$G(x)=\int G'(x)=V_i\times \sum_{j=0}^{\infty}\dfrac{x^{V_i\times j+V_i}}{V_i\times j+V_i}=\sum_{j=0}^{\infty}\dfrac{x^{V_i\times j+V_i}}{j+1}=\sum_{j=1}^{\infty}\dfrac{x^{V_i\times j}}{j}$ 所以我们现在的任务变成了,求出所有$\sum_{j=1}^{\infty}\dfrac{x^{V_i\times j}}{j}$的和,再进行$exp$ 求和用调和级数可以保证复杂度$\rm O(NlogN)$,所以总复杂度是$\rm O(NlogN)$ ### $\rm Code:$ ```cpp /* This code is written by Nemlit */ #include<bits/stdc++.h> using namespace std; #define il inline #define re register #define rep(i, a, b) for(re int i = (a); i <= (b); ++ i) #define mod 998244353 il int read() { re int x = 0, f = 1; re char c = getchar(); while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();} while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar(); return x * f; } il int Mul(int a, int b) { return 1ll * a * b % mod; } il int Inc(int a, int b) { return (a += b) >= mod ? a - mod : a; } il int Dec(int a, int b) { return (a -= b) < 0 ? a + mod : a; } il int qpow(int a, int b) { int r = 1; while(b) { if(b & 1) r = Mul(r, a); a = Mul(a, a), b >>= 1; } return r; } #define maxn 800005 int n, m, a[maxn], r[maxn], h[maxn], f[maxn], g[maxn], p[maxn], w[maxn], lim, l, invG, G, cnt[maxn], pax, inv[maxn]; il void NTT(int *a, int opt, int lim) { rep(i, 0, lim - 1) if(i > r[i]) swap(a[i], a[r[i]]); w[0] = 1; for(re int mid = 1; mid < lim; mid <<= 1) { int base = qpow(opt == 1 ? G : invG, (mod - 1) / (mid << 1)); rep(j, 1, mid - 1) w[j] = Mul(w[j - 1], base); for(re int j = 0; j < lim; j += mid << 1) { rep(k, 0, mid - 1) { int x = a[j + k], y = Mul(w[k], a[j + k + mid]); a[j + k] = Inc(x, y), a[j + k + mid] = Dec(x, y); } } } if(opt == -1) rep(i, 0, lim - 1) a[i] = Mul(a[i], inv[lim]); } il void Inv(int *a, int *b, int n) { rep(i, 0, n * 4) b[i] = h[i] = 0; b[0] = qpow(a[0], mod - 2); int k = 1, lim = 2, l = 1; while(k <= n) { k <<= 1, lim <<= 1, ++ l; rep(i, 0, k - 1) h[i] = a[i]; rep(i, 0, lim - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); NTT(h, 1, lim), NTT(b, 1, lim); rep(i, 0, lim - 1) b[i] = Dec(Mul(2, b[i]), Mul(Mul(b[i], b[i]), h[i])); NTT(b, -1, lim); rep(i, k, lim) b[i] = 0; } } il void ln(int *a, int *b, int n) { rep(i, 0, n * 2) b[i] = f[i] = 0; rep(i, 1, n - 1) b[i - 1] = Mul(a[i], i); Inv(a, f, n); int lim = 1, l = 0; while(lim <= n * 2) lim <<= 1, ++ l; rep(i, 0, lim - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); NTT(f, 1, lim), NTT(b, 1, lim); rep(i, 0, lim - 1) f[i] = Mul(b[i], f[i]); NTT(f, -1, lim), b[0] = 0; rep(i, 0, n - 2) b[i + 1] = Mul(f[i], inv[i + 1]); rep(i, n, lim - 1) b[i] = 0; } il void Exp(int *a, int *b, int n) { rep(i, 0, n * 4) b[i] = g[i] = p[i] = 0; int k = 1, lim = 2, l = 1; b[0] = 1; while(k <= n) { k <<= 1, lim <<= 1, ++ l; ln(b, g, k); rep(i, 0, k - 1) g[i] = Dec(a[i], g[i]); g[0] = Inc(g[0], 1); rep(i, 0, lim - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); NTT(b, 1, lim), NTT(g, 1, lim); rep(i, 0, lim - 1) b[i] = Mul(b[i], g[i]); NTT(b, -1, lim); rep(i, k, lim - 1) b[i] = 0; } } signed main() { n = read(), m = read(), G = 3, invG = qpow(G, mod - 2), inv[1] = 1; rep(i, 2, 800000) inv[i] = Mul(inv[mod % i], (mod - mod / i)); rep(i, 0, n - 1) pax = read(), ++ cnt[pax]; rep(i, 1, m) { if(!cnt[i]) continue; for(re int j = i; j <= m; j += i) a[j] = Inc(a[j], Mul(cnt[i], inv[j / i])); } Exp(a, p, m); rep(i, 1, m) printf("%d\n", p[i]); return 0; } ```