题解 P7483 【50年后的我们】

· · 题解

给个不一样的做法。

考虑到“被过的题”的限制是覆盖它的区间中至少有一个被选,这个很难受。自然的想法是转换成算“没被过的题”,限制是覆盖它的区间中一个都不能选,这个就舒服多了。

但是求这个有啥用呢?

考虑拆一下式子,设所有题的价值和为 S

\begin{aligned}&\mathbb E(w({\rm covered})^k)\\=\ &\mathbb E((S-w({\rm uncovered}))^k)\\=\ &\mathbb E\left(\sum_{i=0}^k(-1)^{k-i}S^iw({\rm uncovered})^{k-i}\dbinom{k}{i}\right)\\=\ &\sum_{i=0}^k\dbinom{k}{i}(-1)^{k-i}S^i\mathbb E\left(w({\rm uncovered})^{k-i}\right)\end{aligned}

所以只要能对于所有 1\leq k\leq K,求出没人过的题的价值和的 k 次方的期望,就可以在线性时间内还原有人过的题的价值和的 K 次方的期望。

下面考虑前缀 DP,设 f_{i,j} 为考虑到前 i 个题,第 i 个题没人过(强制结尾),没人过的题的价值和的 j 次方的期望。

转移直接枚举上一个没人过的题 k<i,处理 j 次幂还可以考虑二项式定理展开:

\begin{aligned}&\mathbb E((a+b)^k)\\=\ &\mathbb E\left(\sum_{i=0}^ka^ib^{k-i}\dbinom{k}{i}\right)\\=\ &\sum_{i=0}^k\dbinom{k}{i}\mathbb E(a^ib^{k-i})\end{aligned}

在转移中,b 是常量 c_i。所以转移式如下:

f_{i,j}=\sum_{k<i}\sum_{x\leq j}f_{k,x}\dbinom{j}{x}c_i^{j-x}P(k+1,i)

其中 P(l,r) 的含义是 [l,r-1] 中所有元素被包含于 [l,r-1] 中的区间覆盖,且 r 没有被任何左端点在 [l,r] 之间,右端点在 [r,n] 之间的区间覆盖的概率。

首先考察如何快速计算 P(l,r)

容易观察到 P(l,r) 的定义割裂为两部分,结果是两部分的结果相乘,因此分别考虑两部分。前一部分 P_1(l,r) 可以容斥:如果没有被完全覆盖,枚举第一个没被覆盖的位置 k。那么前面被完全覆盖的概率是 P_1(l,k-1)k 没有被覆盖的概率是所有跨过它的区间的不出现的概率乘起来(记这个值是 v_k)。后面长啥样都没有关系。

所以可以如下递推:

P_1(l,r)=1-\sum_{k=l}^{r} P_1(l,k-1)v_k

边界是 P_1(l,l-1)=1

这样可以在 $l$ 固定时 $O(n^2+m)$ 计算所有 $r$ 的 $P_1(l,r)$ 值。 因此计算 $P_1(l,r)$ 可以做到 $O(n^3+nm)$。 另一部分是个显然的二维后缀积,也可以快速计算。 因此现在可以快速计算 $P(l,r)$。 然后是 DP 的转移,做一个换序求和: $$f_{i,j}=\sum_{x\leq j}\dbinom{j}{x}c_i^{j-x}\sum_{k<i}f_{k,x}P(k+1,i)$$ 当 $i$ 固定的时候,可以预先计算 $$s_x=\sum_{k<i}f_{k,x}P(k+1,i)$$ 这一部分的复杂度是 $O(nk)$。 然后代回到上面的式子里面就是 $$f_{i,j}=\sum_{x\leq j}\dbinom{j}{x}c_i^{j-x}s_x$$ 这样转移的复杂度就下降到了 $O(nk+k^2)$。 总复杂度 $O(n^3+n^2k+nk^2)$,足以通过。 ```cpp #include <bits/stdc++.h> using namespace std; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) char buf[1 << 21], *p1 = buf, *p2 = buf; inline int qread() { register char c = getchar(); register int x = 0, f = 1; while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); } while (c >= '0' && c <= '9') { x = (x << 3) + (x << 1) + c - 48; c = getchar(); } return x * f; } const long long mod = 998244353; pair <int, int> q[405]; int n, m, K, l[100005], r[100005], c0[405]; long long p[100005], s[405][405], c[405][405], sum, f[405][405], cov[405][405], tmp[405], tmpd[405], sf[405], pwr[405][405], inv[100005]; vector <int> adj[405]; inline long long Power(long long x, long long y) { long long ans = 1; while (y) { if (y & 1) ans = ans * x % mod; x = x * x % mod; y >>= 1; } return ans; } inline void Read() { n = qread(); m = qread(); K = qread(); for (int i = 1;i <= n;i++) { q[i].first = qread(); q[i].second = qread(); sum = (sum + q[i].second) % mod; } q[n + 1] = make_pair(1000000001, 0); sort(q + 1, q + n + 1); for (int i = 1;i <= m;i++) { l[i] = qread(); r[i] = qread(); p[i] = qread(); l[i] = lower_bound(q + 1, q + n + 1, make_pair(l[i], 0)) - q; r[i] = lower_bound(q + 1, q + n + 1, make_pair(r[i] + 1, 0)) - q - 1; inv[i] = Power((mod + 1 - p[i]) % mod, mod - 2); } } inline void Prefix() { for (int l = 1;l <= n + 2;l++) { for (int r = 1;r <= n + 2;r++) s[l][r] = 1; } for (int i = 1;i <= m;i++) { if (l[i] <= r[i]) s[l[i]][r[i]] = s[l[i]][r[i]] * (mod + 1 - p[i]) % mod; } for (int r = n;r >= 1;r--) { for (int l = 1;l <= r;l++) s[l][r] = s[l][r] * s[l][r + 1] % mod; } for (int r = n;r >= 1;r--) { for (int l = r;l >= 1;l--) s[l][r] = s[l][r] * s[l + 1][r] % mod; } c[0][0] = 1; for (int i = 1;i <= K;i++) { c[i][0] = 1; for (int j = 1;j <= i;j++) c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod; } for (int i = 1;i <= n;i++) { for (int j = 1;j <= n;j++) adj[j].clear(); for (int j = 1;j <= m;j++) { if (l[j] >= i) adj[r[j]].push_back(j); } for (int k = 0;k <= n;k++) tmp[k] = 1; cov[i][i - 1] = 1; for (int j = i;j <= n;j++) { int siz = adj[j].size(); for (int k = 0;k <= n;k++) { c0[k] = 0; tmpd[k] = 1; } for (int k = 0;k < siz;k++) { if (p[adj[j][k]] == 1) { c0[l[adj[j][k]]]++; c0[r[adj[j][k]] + 1]--; } else { tmpd[l[adj[j][k]]] = tmpd[l[adj[j][k]]] * (mod + 1 - p[adj[j][k]]) % mod; tmpd[r[adj[j][k]] + 1] = tmpd[r[adj[j][k]] + 1] * inv[adj[j][k]] % mod; } } for (int k = 1;k <= n;k++) { c0[k] += c0[k - 1]; tmpd[k] = tmpd[k] * tmpd[k - 1] % mod; } for (int k = 1;k <= n;k++) { if (c0[k]) tmp[k] = 0; else tmp[k] = tmp[k] * tmpd[k] % mod; } cov[i][j] = 1; for (int k = i - 1;k < j;k++) cov[i][j] = (cov[i][j] - cov[i][k] * tmp[k + 1] % mod + mod) % mod; } } cov[n + 1][n] = 1; for (int i = 1;i <= n + 1;i++) { pwr[i][0] = 1; for (int j = 1;j <= K;j++) pwr[i][j] = pwr[i][j - 1] * q[i].second % mod; } } inline void Solve() { f[0][0] = 1; for (int i = 1;i <= n + 1;i++) { for (int k = 0;k <= K;k++) { sf[k] = 0; for (int ii = 0;ii < i;ii++) sf[k] = (sf[k] + f[ii][k] * s[ii + 1][i] % mod * cov[ii + 1][i - 1]) % mod; } for (int j = 0;j <= K;j++) { for (int k = 0;k <= j;k++) { f[i][j] = (f[i][j] + pwr[i][j - k] % mod * c[j][k] % mod * sf[k]) % mod; } } } long long ans = 0; for (int j = 0;j <= K;j++) { ans = (ans + Power(sum, K - j) * c[K][j] % mod * Power(mod - 1, j) % mod * f[n + 1][j]) % mod; } cout << ans << endl; } int main() { Read(); Prefix(); Solve(); return 0; } ```