P11346 [KTSC 2023 R2] 会议室 2
EuphoricStar
·
·
题解
给定 n 条线段 [l_i, r_i],有一个 n 个点的无向图,若 [l_i, r_i] 和 [l_j, r_j] 交集不为空那么图上 i, j 之间有边。
你要进行 n - 1 次操作,每次操作删除一条线段。求在最小化每次操作后图的连通块数量之和的前提下,有多少种删的方案。对 10^9 + 7 取模。
----
首先考虑怎么删能使连通块数量和最小。显然是按连通块大小从小到大删,并且删的过程要保证连通块不分裂。
所以我们统计每个连通块内部删的方案,最后答案乘上 $\prod\limits_{i = 1}^n c_i!$,其中 $c_i$ 为大小为 $i$ 的连通块数量。
考虑倒过来加线段。那么限制为,设某个时刻已加入的线段并集为 $[L, R]$,那么新加入的线段 $[l, r]$ 要满足 $[L, R]$ 和 $[l, r]$ 交集不为空。
若加入某条线段后,线段并集变大,那么我们称它是一条**关键线段**。
设所有关键线段为 $[l_1, r_1], [l_2, r_2], \ldots, [l_k, r_k]$,线段并集的变化依次为 $[L_1, R_1], [L_2, R_2], \ldots, [L_k, R_k]$。
那么一条非关键线段 $[l, r]$,必须在 $[l_t, r_t]$ 之后加入,$t$ 是最小的正整数满足 $[l, r] \subseteq [L_t, R_t]$。
考虑确定了关键线段后怎么算方案数。
考虑构造一棵树。根结点为 $[l_1, r_1]$,所有关键线段 $[l_1, r_1], [l_2, r_2], \ldots, [l_k, r_k]$ 连成一条链。对每条非关键线段,将它挂到 $[l_t, r_t]$ 下面。那么方案数等价于这棵树的拓扑序数量,即 $\dfrac{n!}{\prod\limits_{i = 1}^k (n - g_{L_i, R_i})!}$,其中 $g_{i, j}$ 表示有多少线段 $[l, r] \subseteq [i, j]$,可以二维前缀和预处理。
那么考虑 DP。设 $f_{L, R}$ 表示当前线段并集为 $[L, R]$,对于所有选关键线段的方案,$\dfrac{1}{\prod (n - g_{l, r})!}$ 之和。
转移首先令 $f_{L, R} \gets f_{L, R} \times \dfrac{1}{(n - g_{L, R})!}$。然后枚举下一条关键线段 $[l, r]$(需要满足 $[l, r] \not \subseteq [L, R]$ 且 $[l, r]$ 和 $[L, R]$ 交集不为空),有 $f_{\min(L, l), \max(R, r)} \gets f_{L, R}$。容易二维前缀和优化。
时间复杂度 $O(n^2)$。
```cpp
// Problem: P11346 [KTSC 2023 R2] 会议室 2
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P11346
// Memory Limit: 1000 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 4040;
const int mod = 1000000007;
inline void fix(int &x) {
x += ((x >> 31) & mod);
}
int n, m, c[maxn], lsh[maxn], tot;
ll fac[maxn], inv[maxn];
struct node {
int l, r;
node(int a = 0, int b = 0) : l(a), r(b) {}
} a[maxn], b[maxn];
int f[maxn][maxn], f1[maxn][maxn], f2[maxn][maxn], f3[maxn][maxn], g[maxn][maxn];
bool vis[maxn][maxn];
vector<int> vl[maxn], vr[maxn];
inline int calc() {
tot = 0;
for (int i = 1; i <= m; ++i) {
lsh[++tot] = b[i].l;
lsh[++tot] = b[i].r;
}
sort(lsh + 1, lsh + tot + 1);
tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
for (int i = 1; i <= m; ++i) {
b[i].l = lower_bound(lsh + 1, lsh + tot + 1, b[i].l) - lsh;
b[i].r = lower_bound(lsh + 1, lsh + tot + 1, b[i].r) - lsh;
}
for (int i = 0; i <= tot + 1; ++i) {
vector<int>().swap(vl[i]);
vector<int>().swap(vr[i]);
for (int j = 0; j <= tot + 1; ++j) {
f[i][j] = g[i][j] = f1[i][j] = f2[i][j] = f3[i][j] = 0;
vis[i][j] = 0;
}
}
for (int i = 1; i <= m; ++i) {
f[b[i].l][b[i].r] = g[b[i].l][b[i].r] = 1;
vis[b[i].l][b[i].r] = 1;
vl[b[i].l].pb(b[i].r);
vr[b[i].r].pb(b[i].l);
}
for (int i = tot; i; --i) {
for (int j = i; j <= tot; ++j) {
g[i][j] += g[i + 1][j] + g[i][j - 1] - g[i + 1][j - 1];
}
}
for (int i = tot; i; --i) {
for (int j = i; j <= tot; ++j) {
for (int k : vl[i]) {
if (k < j) {
fix(f[i][j] += f1[i + 1][j] - mod);
fix(f[i][j] -= f1[k + 1][j]);
}
}
for (int k : vr[j]) {
if (k > i) {
fix(f[i][j] += f2[i][j - 1] - mod);
fix(f[i][j] -= f2[i][k - 1]);
}
}
if (vis[i][j]) {
fix(f[i][j] += f3[i + 1][j] - mod);
fix(f[i][j] += f3[i][j - 1] - mod);
fix(f[i][j] -= f3[i + 1][j - 1]);
}
f[i][j] = 1LL * f[i][j] * inv[m - g[i][j]] % mod;
fix(f1[i][j] = f1[i + 1][j] + f[i][j] - mod);
fix(f2[i][j] = f2[i][j - 1] + f[i][j] - mod);
fix(f3[i][j] = f3[i + 1][j] + f3[i][j - 1] - mod);
fix(f3[i][j] -= f3[i + 1][j - 1]);
fix(f3[i][j] += f[i][j] - mod);
}
}
return f[1][tot] * fac[m - 1] % mod;
}
int count_removals(vector<int> _l, vector<int> _r) {
n = (int)_l.size();
for (int i = 1; i <= n; ++i) {
a[i].l = _l[i - 1];
a[i].r = _r[i - 1];
}
sort(a + 1, a + n + 1, [&](const node &a, const node &b) {
return a.l < b.l;
});
fac[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
inv[0] = inv[1] = 1;
for (int i = 2; i <= n; ++i) {
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
int r = 0;
ll ans = 1;
for (int i = 1; i <= n; ++i) {
if (a[i].l > r) {
if (m) {
ans = ans * calc() % mod;
++c[m];
m = 0;
}
r = a[i].r;
b[++m] = a[i];
} else {
r = max(r, a[i].r);
b[++m] = a[i];
}
}
if (m) {
ans = ans * calc() % mod;
++c[m];
}
for (int i = 1; i <= n; ++i) {
ans = ans * fac[c[i]] % mod;
}
return ans;
}
// int main() {
// int n;
// scanf("%d", &n);
// vector<int> a(n), b(n);
// for (int i = 0; i < n; ++i) {
// scanf("%d%d", &a[i], &b[i]);
// }
// printf("%d\n", count_removals(a, b));
// return 0;
// }
```