题解 P7483 【50年后的我们】
Sol1
·
·
题解
给个不一样的做法。
考虑到“被过的题”的限制是覆盖它的区间中至少有一个被选,这个很难受。自然的想法是转换成算“没被过的题”,限制是覆盖它的区间中一个都不能选,这个就舒服多了。
但是求这个有啥用呢?
考虑拆一下式子,设所有题的价值和为 S:
\begin{aligned}&\mathbb E(w({\rm covered})^k)\\=\ &\mathbb E((S-w({\rm uncovered}))^k)\\=\ &\mathbb E\left(\sum_{i=0}^k(-1)^{k-i}S^iw({\rm uncovered})^{k-i}\dbinom{k}{i}\right)\\=\ &\sum_{i=0}^k\dbinom{k}{i}(-1)^{k-i}S^i\mathbb E\left(w({\rm uncovered})^{k-i}\right)\end{aligned}
所以只要能对于所有 1\leq k\leq K,求出没人过的题的价值和的 k 次方的期望,就可以在线性时间内还原有人过的题的价值和的 K 次方的期望。
下面考虑前缀 DP,设 f_{i,j} 为考虑到前 i 个题,第 i 个题没人过(强制结尾),没人过的题的价值和的 j 次方的期望。
转移直接枚举上一个没人过的题 k<i,处理 j 次幂还可以考虑二项式定理展开:
\begin{aligned}&\mathbb E((a+b)^k)\\=\ &\mathbb E\left(\sum_{i=0}^ka^ib^{k-i}\dbinom{k}{i}\right)\\=\ &\sum_{i=0}^k\dbinom{k}{i}\mathbb E(a^ib^{k-i})\end{aligned}
在转移中,b 是常量 c_i。所以转移式如下:
f_{i,j}=\sum_{k<i}\sum_{x\leq j}f_{k,x}\dbinom{j}{x}c_i^{j-x}P(k+1,i)
其中 P(l,r) 的含义是 [l,r-1] 中所有元素被包含于 [l,r-1] 中的区间覆盖,且 r 没有被任何左端点在 [l,r] 之间,右端点在 [r,n] 之间的区间覆盖的概率。
首先考察如何快速计算 P(l,r)。
容易观察到 P(l,r) 的定义割裂为两部分,结果是两部分的结果相乘,因此分别考虑两部分。前一部分 P_1(l,r) 可以容斥:如果没有被完全覆盖,枚举第一个没被覆盖的位置 k。那么前面被完全覆盖的概率是 P_1(l,k-1)。k 没有被覆盖的概率是所有跨过它的区间的不出现的概率乘起来(记这个值是 v_k)。后面长啥样都没有关系。
所以可以如下递推:
P_1(l,r)=1-\sum_{k=l}^{r} P_1(l,k-1)v_k
边界是 P_1(l,l-1)=1。
这样可以在 $l$ 固定时 $O(n^2+m)$ 计算所有 $r$ 的 $P_1(l,r)$ 值。
因此计算 $P_1(l,r)$ 可以做到 $O(n^3+nm)$。
另一部分是个显然的二维后缀积,也可以快速计算。
因此现在可以快速计算 $P(l,r)$。
然后是 DP 的转移,做一个换序求和:
$$f_{i,j}=\sum_{x\leq j}\dbinom{j}{x}c_i^{j-x}\sum_{k<i}f_{k,x}P(k+1,i)$$
当 $i$ 固定的时候,可以预先计算
$$s_x=\sum_{k<i}f_{k,x}P(k+1,i)$$
这一部分的复杂度是 $O(nk)$。
然后代回到上面的式子里面就是
$$f_{i,j}=\sum_{x\leq j}\dbinom{j}{x}c_i^{j-x}s_x$$
这样转移的复杂度就下降到了 $O(nk+k^2)$。
总复杂度 $O(n^3+n^2k+nk^2)$,足以通过。
```cpp
#include <bits/stdc++.h>
using namespace std;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int qread() {
register char c = getchar();
register int x = 0, f = 1;
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + c - 48;
c = getchar();
}
return x * f;
}
const long long mod = 998244353;
pair <int, int> q[405];
int n, m, K, l[100005], r[100005], c0[405];
long long p[100005], s[405][405], c[405][405], sum, f[405][405], cov[405][405], tmp[405], tmpd[405], sf[405], pwr[405][405], inv[100005];
vector <int> adj[405];
inline long long Power(long long x, long long y) {
long long ans = 1;
while (y) {
if (y & 1) ans = ans * x % mod;
x = x * x % mod;
y >>= 1;
}
return ans;
}
inline void Read() {
n = qread(); m = qread(); K = qread();
for (int i = 1;i <= n;i++) {
q[i].first = qread(); q[i].second = qread();
sum = (sum + q[i].second) % mod;
}
q[n + 1] = make_pair(1000000001, 0);
sort(q + 1, q + n + 1);
for (int i = 1;i <= m;i++) {
l[i] = qread();
r[i] = qread();
p[i] = qread();
l[i] = lower_bound(q + 1, q + n + 1, make_pair(l[i], 0)) - q;
r[i] = lower_bound(q + 1, q + n + 1, make_pair(r[i] + 1, 0)) - q - 1;
inv[i] = Power((mod + 1 - p[i]) % mod, mod - 2);
}
}
inline void Prefix() {
for (int l = 1;l <= n + 2;l++) {
for (int r = 1;r <= n + 2;r++) s[l][r] = 1;
}
for (int i = 1;i <= m;i++) {
if (l[i] <= r[i]) s[l[i]][r[i]] = s[l[i]][r[i]] * (mod + 1 - p[i]) % mod;
}
for (int r = n;r >= 1;r--) {
for (int l = 1;l <= r;l++) s[l][r] = s[l][r] * s[l][r + 1] % mod;
}
for (int r = n;r >= 1;r--) {
for (int l = r;l >= 1;l--) s[l][r] = s[l][r] * s[l + 1][r] % mod;
}
c[0][0] = 1;
for (int i = 1;i <= K;i++) {
c[i][0] = 1;
for (int j = 1;j <= i;j++) c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
for (int i = 1;i <= n;i++) {
for (int j = 1;j <= n;j++) adj[j].clear();
for (int j = 1;j <= m;j++) {
if (l[j] >= i) adj[r[j]].push_back(j);
}
for (int k = 0;k <= n;k++) tmp[k] = 1;
cov[i][i - 1] = 1;
for (int j = i;j <= n;j++) {
int siz = adj[j].size();
for (int k = 0;k <= n;k++) {
c0[k] = 0;
tmpd[k] = 1;
}
for (int k = 0;k < siz;k++) {
if (p[adj[j][k]] == 1) {
c0[l[adj[j][k]]]++;
c0[r[adj[j][k]] + 1]--;
} else {
tmpd[l[adj[j][k]]] = tmpd[l[adj[j][k]]] * (mod + 1 - p[adj[j][k]]) % mod;
tmpd[r[adj[j][k]] + 1] = tmpd[r[adj[j][k]] + 1] * inv[adj[j][k]] % mod;
}
}
for (int k = 1;k <= n;k++) {
c0[k] += c0[k - 1];
tmpd[k] = tmpd[k] * tmpd[k - 1] % mod;
}
for (int k = 1;k <= n;k++) {
if (c0[k]) tmp[k] = 0;
else tmp[k] = tmp[k] * tmpd[k] % mod;
}
cov[i][j] = 1;
for (int k = i - 1;k < j;k++) cov[i][j] = (cov[i][j] - cov[i][k] * tmp[k + 1] % mod + mod) % mod;
}
}
cov[n + 1][n] = 1;
for (int i = 1;i <= n + 1;i++) {
pwr[i][0] = 1;
for (int j = 1;j <= K;j++) pwr[i][j] = pwr[i][j - 1] * q[i].second % mod;
}
}
inline void Solve() {
f[0][0] = 1;
for (int i = 1;i <= n + 1;i++) {
for (int k = 0;k <= K;k++) {
sf[k] = 0;
for (int ii = 0;ii < i;ii++) sf[k] = (sf[k] + f[ii][k] * s[ii + 1][i] % mod * cov[ii + 1][i - 1]) % mod;
}
for (int j = 0;j <= K;j++) {
for (int k = 0;k <= j;k++) {
f[i][j] = (f[i][j] + pwr[i][j - k] % mod * c[j][k] % mod * sf[k]) % mod;
}
}
}
long long ans = 0;
for (int j = 0;j <= K;j++) {
ans = (ans + Power(sum, K - j) * c[K][j] % mod * Power(mod - 1, j) % mod * f[n + 1][j]) % mod;
}
cout << ans << endl;
}
int main() {
Read();
Prefix();
Solve();
return 0;
}
```