题解:P5469 [NOI2019] 机器人

· · 题解

[NOI2019] 机器人

观察到任意方案中,所有位置能走到的区间都是包含或不相交的区间,而反过来看,每个区间中最大值中最靠右者,可以走到区间中所有的位置,考虑区间 DP。设 dp_{l, r, h} 表示在 [l, r] 中,最高的位置的高度不超过 h 的方案数,则有转移 dp_{l, r, h}=dp_{l, r, h-1}+\sum_{i\in[l, r], |(i-l)-(r-i)|\leq2, a_i\leq h\leq b_i}dp_{l, i-1, h}\times dp_{i+1, r, h},表示的是枚举出发点 i,其能走到的位置左右两边距离不超过 2,且这个位置可以被设为 h,令这个位置上的高度为 h,则方案数为左边最高高度不超过 h,右边小于 h 的方案数相乘。而最高为 h 的加上最高不超过 h-1 的,即不超过 h 的方案数。

这样子是 O(n^2V) 的。由于最后只需要知道 dp_{1, n, V} 的值,则会对最终区间 [1, n] 产生贡献的区间 [l, r] 的个数会比 O(n^2) 小很多。则可以先预处理出这些区间。打表得知,实际有用的区间个数不到 3000,设个数为 w,则只需要处理这 w 个区间即可。时间复杂度 O(wV)

既然 V 这么大,就可以去考虑离散化它。离散化后的 V 被分成了 O(n) 个区间。此时,把高度在一个区间内的情况一起考虑,会发现转移中 a_i\leq h\leq b_i 的条件对于每一个位置是固定的,不随在此区间内的 h 变化而变化。纵观每一个位置 i,其每一个 h 的转移都是一样的,这时就可以大胆地猜测,一个长度为 n 的区间的答案是一个与 h 有关的 n 次多项式。这个可以使用数学归纳法证明。首先长度为 1 的区间,其方案数是随 h 的增加线性递增的,然后是长度为 k 的区间,设转移时左右区间的长度分别为 x, y(x+y=k-1),则左边是一个 x 次多项式,右边是一个 y 次多项式,转移时是将其相乘,此时为 x+y=k-1 次,求和操作算作一次,为 k 次,则证明了长度为 k 的区间是与 h 有关的 k 次多项式。

那么对于任意的 h,只需要求出这个 n 次多项式,就可以直接带入求值。使用拉格朗日插值法,先求出高度为 [1, n+1] 的答案,最后代入 h 即可,这一部分的复杂度是 O(n) 的。

则最终的做法即在高度区间长度不超过 n+1 时暴力计算,超过时先计算出 [1, n+1] 时的值,再代入整个区间的长度,求出答案,做下一个高度区间时将前一个区间的答案视作 0(因为比所有下一个区间的高度都低),一层层做下去即可。时间复杂度 O(n^2w)

#include <iostream>
#include <algorithm>
#include <string.h>
#include <iomanip>
#include <bitset>
#include <math.h>
#include <string>
#include <vector>
#include <queue>
#include <set>
#include <map>
#define fst first
#define scd second
#define db double
#define ll long long
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector <int>
#define pii pair <int, int>
#define sz(x) ((int)x.size())
#define ms(f, x) memset(f, x, sizeof(f))
#define L(i, j, k) for (int i=(j); i<=(k); ++i)
#define R(i, j, k) for (int i=(j); i>=(k); --i)
#define ACN(i, H_u) for (int i=H_u; i; i=E[i].nxt)
using namespace std;
template <typename INT> void rd(INT &res) {
    res=0; bool f=false; char ch=getchar();
    while (ch<'0'||ch>'9') f|=ch=='-', ch=getchar();
    while (ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch^48), ch=getchar();
    res=(f?-res:res);
}
template <typename INT, typename...Args>
void rd(INT &x, Args &...y) { rd(x), rd(y...); }
//dfs
const int mod=1e9+7; 
const int maxn=300, maxm=3000;
const int N=maxn+10, M=maxm+10;
int fac[N], inv[N], d[N<<1], a[N], b[N], f[M][N], id[N][N], idx, n, dcnt, pre[N], suf[N];
//wmr
int mods(int x) { return x<0?x+mod:x; }
int moda(int x) { return x>=mod?x-mod:x; }
struct node {
    int l, r;
    node(int _l, int _r) { l=_l, r=_r; }
    node() {}
    bool operator < (const node &k) const { return r-l<k.r-k.l; }
} p[M];
int quick_power(int x, int y) {
    int res=1;
    while (y) {
        if (y&1) res=(ll)res*x%mod;
        x=(ll)x*x%mod, y>>=1;
    }
    return res; 
}
//incra
void dfs(int l, int r) {
    if (l>r||id[l][r]) return;
    p[id[l][r]=++idx]=node(l, r);
    if (l==r) return;
    L(i, l, r) if (abs(i-l-r+i)<=2) dfs(l, i-1), dfs(i+1, r);
}
void lag(int rh) {
    pre[0]=1; L(i, 1, n+1) pre[i]=(ll)pre[i-1]*(rh-i)%mod;
    suf[n+2]=1; R(i, n+1, 1) suf[i]=(ll)suf[i+1]*(rh-i)%mod;
    L(i, 1, idx) f[i][0]=0;
    L(i, 1, n+1) {
        int v=(ll)pre[i-1]*suf[i+1]%mod*inv[i-1]%mod*inv[n+1-i]%mod*(((n+1-i)&1)?mod-1:1)%mod;
        L(j, 1, idx) f[j][0]=((ll)v*f[j][i]+f[j][0])%mod;
    }
}
//lottle
signed main() {
//  ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
//  freopen(".in", "r", stdin);
//  freopen(".out", "w", stdout);
    fac[0]=1; L(i, 1, maxn) fac[i]=(ll)fac[i-1]*i%mod;
    inv[maxn]=quick_power(fac[maxn], mod-2); R(i, maxn-1, 0) inv[i]=(ll)inv[i+1]*(i+1)%mod;
    rd(n);
    L(i, 1, n) rd(a[i], b[i]), d[++dcnt]=a[i], d[++dcnt]=++b[i]; //左闭右开
    sort(d+1, d+dcnt+1); dcnt=unique(d+1, d+dcnt+1)-d-1;
    L(i, 1, n) a[i]=lower_bound(d+1, d+dcnt+1, a[i])-d, b[i]=lower_bound(d+1, d+dcnt+1, b[i])-d;
    dfs(1, n); sort(p+1, p+idx+1);
    L(i, 0, n+1) f[0][i]=1;
    L(i, 1, dcnt-1) {
        int hn=min(n+1, d[i+1]-d[i]);
        L(j, 1, idx) {
            int l=p[j].l, r=p[j].r;
            L(k, l, r) if (abs(k-l-r+k)<=2&&a[k]<=i&&i<b[k])
            L(h, 1, hn) f[id[l][r]][h]=((ll)f[id[l][k-1]][h]*f[id[k+1][r]][h-1]+f[id[l][r]][h])%mod;
            L(h, 1, hn) f[id[l][r]][h]=moda(f[id[l][r]][h]+f[id[l][r]][h-1]);
        }
        int rh=d[i+1]-d[i];
        if (rh<=n+1) L(j, 1, idx) f[j][0]=f[j][hn];
        else lag(rh);
        L(j, 1, idx) fill(f[j]+1, f[j]+hn+1, 0);
    }
    printf("%d\n", f[id[1][n]][0]);
    return 0;
}