题解 P7154 【[USACO20DEC] Sleeping Cows P】

· · 题解

动态规划优化

题目链接:P7154 [USACO20DEC] Sleeping Cows P

本题解同步发布于 My Blog

题意:

现给定长度均为 n 的数组 s_it_is_i 能与 t_j 匹配当且仅当 s_i\le t_j。一组匹配是极大的,当且仅当对于任意一个匹配对 (s_i,t_j) 均满足 s_i\le t_j ,且任意未被匹配的 st 中无法再组成匹配对。

求所有极大匹配的方案数。

考虑将 st 都从小到大排序,那么对于一个 t_i,能与之匹配的 s_j 应该是 s 的一段前缀。

那么考虑从小到大讨论每一个 t,每一个 t 有两种选择:选择一个合法的 s 与之匹配,或放弃将这个 t 加入匹配。

不难发现,当我们将 t 升序排序后,对于一个 s_j,若其可与 t_i 匹配,那么一定可以与 t_{i+1} 匹配。此时考虑什么情况下,我们才可以选择放弃将当前的 t 加入匹配。

若当前 t_i 所能匹配的 s 的最大下标为 \text{lim}_ i,那么若要放弃 t_i,则 s_1s_{\text{lim}_ i} 均需要被匹配。因为如果放弃 t_i 且放弃 s_1s_{\text{lim}_ i} 中的任意一个,那么他们一定可以组成一个新的匹配对,这是我们不希望看到的。

也就是说,当存在一个 s_i 被放弃加入匹配时,任意满足 i\le \text{lim}_ jt_j不可以被放弃

同时,我们实际上只关注最小的,被放弃的 s_i。那么一个 \mathcal{O}(n^3) 的算法就很显然了:

考虑最小的被放弃的 s_i 是谁,那么对于任意 j < is_j 都需要加入匹配,\text{lim}_ j \ge it_j 均不能放弃。\mathcal{O}(n) 枚举 s_i 后可以通过 \mathcal{O}(n^2)\text{dp} 解决。

考虑在这个思路上进行优化。实际上,我们并不需要枚举 s_i 是谁,而是只关心 s 的一个前缀是否全部加入了匹配。也就是说,我们可以将上面枚举 s_i 的过程转变为一个 \text{bool} 变量 0/1 表示某一个前缀是否全部选入。

那么可以维护两个指针,分别指向当前还未考虑的最小的 st,每次选择较小的一边进行更新,特殊地,对于 s=t 的情形,则优先更新 s 一边。

更简单的实现方法,可以将两个数组合并为一个,进行双关键字排序,依次分类讨论更新即可。以下的 \text{dp} 均在这个新的数组上进行讨论。

考虑 \text{dp},设 \text{dp}_ {i,j,0/1} 表示考虑到第 i 个位置,目前有 j 个被选入匹配,但是还未选定匹配谁的 s1i-1 位置中的 s 是否都被选择加入的方案数,其中第三位的 1 表示前缀中的 s 全部加入匹配。

那么可以分类进行转移:

若第 i 个位置是 s 中的元素,那么:

\text{dp}_ {i,j,0}=\text{dp}_ {i-1,j-1,0}+\text{dp}_ {i-1,j,0}+\text{dp}_ {i-1,j,1} \text{dp}_ {i,j,1}=\text{dp}_ {i-1,j-1,1}

可以考虑是否将这个 s 加入匹配,若加入,则 j\leftarrow j+1 ,否则不变。需要注意的是若选择放弃当前 s 则第三位必然为 0

若第 i 个位置是 t 中的元素,那么:

\text{dp}_ {i,j,0}=\text{dp}_ {i-1,j+1,0}\times (j+1) \text{dp}_ {i,j,1}=\text{dp}_ {i-1,j,1}+\text{dp}_ {i-1,j+1,1}\times (j+1)

若当前第三位为 0,那么不能放弃当前 tj\leftarrow j-1,同时从当前所有还未确定的 s 中选择一个与当前 t 匹配;若第三位为 1,则可以选择放弃当前位,同样也可以选择进行一次匹配。

最后的答案就是:

\text{dp}_ {n,0,0}+\text{dp}_ {n,0,1}

也就是最后没有在匹配中,却未选择与谁匹配的 s

时间复杂度 \mathcal{O}(n^2)

\texttt{Code:}
//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
#define int long long
#define debug(x) cout << #x << " = " << x << endl
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define LINE() cout << "LINE = " << __LINE__ << endl
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define vec vector

const int MAXN = 6010;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;

int n, m;
int dp[2][MAXN][2]; //dp[i][j][0/1] 还有j头牛需要加入,是否填满
int now, pre = 1;//滚动数组
pii s[MAXN];

template<typename T> inline bool read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
    return 1;
}

template<typename A, typename ...B>
inline bool read(A &x, B &...y) {return read(x) && read(y...);}

signed main () {
#ifdef FILE
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
#endif
    read(n);
    for(int i = 1; i <= n; ++i) read(s[i].fst), s[i].scd = 0;
    for(int i = 1; i <= n; ++i) read(s[i + n].fst), s[i + n].scd = 1;
    sort(s + 1, s + n * 2 + 1);
    dp[now][0][1] = 1;
    for(int i = 1; i <= n * 2; ++i) {
        swap(now, pre);//滚动数组
        memset(dp[now], 0, sizeof dp[now]);
        if(!s[i].scd) {//若当前是 s 中元素
            for(int j = 0; j <= n; ++j) {
                if(j) (dp[now][j][0] += dp[pre][j - 1][0]) %= mod;
                if(j) (dp[now][j][1] += dp[pre][j - 1][1]) %= mod;
                (dp[now][j][0] += dp[pre][j][0] + dp[pre][j][1]) %= mod;
            }
        } else {//若当前是 t 中元素
            for(int j = 0; j <= n; ++j) {
                (dp[now][j][1] += dp[pre][j][1]) %= mod;
                (dp[now][j][1] += dp[pre][j + 1][1] * (j + 1) % mod) %= mod;
                (dp[now][j][0] += dp[pre][j + 1][0] * (j + 1) % mod) %= mod;
            }
        }
    }
    printf("%lld\n", ((dp[now][0][0] + dp[now][0][1]) % mod + mod) % mod);
    return 0;
}