P5469 机器人 题解

· · 题解

这题耗费我一个下午加半个晚上加小半个早上,作此题解以记之。

如果有什么写错的地方还请神犇们不吝赐教,拜谢!

拉格朗日插值优化 DP

朴素 DP + 前缀和优化

本身 DP 就比较巧妙,枚举最大值的位置进行区间 DP,利用了最大值左边的区间和最大值右边的区间互不影响的性质(如果有多个最大值就选取最右边的,这样就保证了其左右两段一定互不干涉)。加上前缀和优化可以达到 O(n^3w)。状态表示:f[l][r][k] 表示在区间 [l,r] 中,最大值与下限的差不超过 k 的方案数。转移就是枚举最大值的位置。转移方程:f[l][r][k]= \sum \limits_{l \le i \le r} {f[l][i-1][k] \times f[i+1][r][k-1]}。解释一下为什么是 f[i+1][r][k-1],因为 DP 中的“最大值”就是最靠右的最大值,所以整区间最大值右边的区间的最大值一定比整区间最大值 k 小(而左区间的最大值可以为 k)。

利用有效段较少进行优化

原本 O(n^3w) 中有 O(n^2) 是枚举区间,但其实合法的区间数很少。可以预处理合法区间数进行 DP,记合法区间数为 m,打表发现 m 其实最多也就 2500 左右。状态表示:f[i][k] 表示第 i 个合法段,最大值和下限的差值不超过 k 的方案数。对于所有区间按照长度排序,使长度单调不减,保证所有区间的子区间都在其之前被求出结果。转移方程变为:f[id[l][r]][k]= \sum \limits_{l \le i \le r且左右区间均合法} {f[id[[l][i-1]][k] \times f[id[i+1][r]][k-1]}。时间复杂度 O(nmw)。但是这个 DP 的复杂度仍然不可接受,原因在于需要在值域范围枚举,而本题值域 w 很大。

拉格朗日插值优化

仍然是状态表示:f[i][k] 表示第 i 个合法段,最大值和下限的差值不超过 k 的方案数。我们发现函数 f[i] 为关于 kn 次多项式(长度为 1 的区间的 f[i][k] 是关于 k 的一次式,再根据转移方程,f[i][k] 这个关于 k 的函数的次数就是区间 i 的长度,最长的区间长度也就是 n 了),可以先求出 n+1 项的值,再用拉格朗日插值求出其它位置的值。由于下标是连续的,拉格朗日插值可以做到 O(n)(简单的说:原本的 x_i 是连续的(就是 [1,n+1]),带入 x=len 后仍然是连续的数。公式 \sum\limits_{1 \le i \le n}{y_i \times \prod\limits_{1 \le j \le n}^{j \neq i}{\frac{x-x_j}{x_i-x_j}}} 的右半部分(即 \prod\limits_{1 \le j \le n}^{j \neq i}{\frac{x-x_j}{x_i-x_j}})的分子分母其实都是一段连乘积挖去一个数,也可以理解成两端连乘积相乘,那么就可以通过预处理前后缀积来实现 O(1) 求乘积,O(n)n 个乘积式的和,具体实现见代码)。

如果区间长度超过 n+1,就先将前 n+1 项 DP 求出,后面部分插值求,否则直接 DP 求。

(注:这里提到的“段”指的是将所有上下界排序之后,相邻两个值之间的段。且上下界在读入的时候已经转化成左闭右开的形式,保证了每个连续段 DP 的正确性。记区间长度 r-llen。 ) 第 i 段的答案理论上应该是 f[i][len],当长度超过 n+1 的时候就是这个函数在 x=len 的时候的值。 每次把答案存在 f[i][0] 里,因为这一段所有合法的情况不会影响下一段的求解,根据乘法原理,答案就是每段(“段”是指值域段,不是区间)结果相乘,这里直接用前一段结果作为后一段的起点,实际上就是乘法原理。(也就是说,在第 i 段 DP 结束之后,f[i][0] 存的是前 i 段答案之积)。 最终答案就是 f[id[1][n]][0] 啦。时间复杂度 O(n^2m)

代码实现(一些注释)

#include<bits/stdc++.h>
using namespace std;
inline int read()
{
    int res=0; bool f=0;
    char ch=getchar();
    while(!isdigit(ch)) f|=(ch=='-'),ch=getchar();
    while(isdigit(ch)) res=res*10+(ch^'0'),ch=getchar();
    return f?-res:res;
}
const int N=605,mod=1000000007;
struct node{
    int l,r;
    bool operator<(const node &t)const{return r-l<t.r-t.l;}
}p[3030];
int n,tot,id[N][N],a[N],b[N];
int fac[N],inv[N];
inline int qmi(int x,int y)
{
    int res=1;
    while(y)
    {
        if(y&1) res=1ll*res*x%mod;
        x=1ll*x*x%mod;
        y>>=1;
    }
    return res;
}
void init(int n)//预处理阶乘和逆元
{
    fac[0]=inv[0]=1;
    for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
    inv[n]=qmi(fac[n],mod-2);
    for(int i=n-1;i>=1;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
}
inline void add(int &x,int y)
{
    x+=y;
    if(x>=mod) x-=mod;
}
void dfs(int l,int r)//预处理合法区间
{
    if(id[l][r]||l>r) return;
    id[l][r]=++tot;
    p[tot]={l,r};
    for(int i=l;i<=r;i++)//枚举分界线   
        if(abs((i-l)-(r-i))<=2)
            dfs(l,i-1),dfs(i+1,r);
}
int num[N],cnt;
void read_and_change()
{
    for(int i=1;i<=n;i++)
    {
        num[++cnt]=a[i]=read();
        num[++cnt]=b[i]=read()+1;//左闭右开 后面方便
    }
    sort(num+1,num+cnt+1);
    cnt=unique(num+1,num+cnt+1)-num-1;
    for(int i=1;i<=n;i++)
    {
        a[i]=lower_bound(num+1,num+cnt+1,a[i])-num;
        b[i]=lower_bound(num+1,num+cnt+1,b[i])-num;
    }
}
int f[3030][N],pre[N],suf[N];
inline void lag(int l,int r,int len)
{
    if(len<=n+1) 
    {
        for(int i=1;i<=tot;i++) f[i][0]=f[i][len];
        return;
    }
    for(int i=1;i<=tot;i++) f[i][0]=0;
    pre[0]=suf[n+2]=1;//pre、suf就是把带入len之后的式子的分子前后缀前
    for(int i=1;i<=n+1;i++) pre[i]=1ll*pre[i-1]*((len-i+mod)%mod)%mod;
    for(int i=n+1;i>=1;i--) suf[i]=1ll*suf[i+1]*((len-i+mod)%mod)%mod;
    for(int i=1;i<=n+1;i++)
    {
        int val=1ll*pre[i-1]*suf[i+1]%mod/*分子*/*inv[n+1-i]%mod*inv[i-1]%mod/*分母*/*(((n+1-i)&1)?mod-1:1)%mod/*分母的符号*/;
        for(int j=1;j<=tot;j++)
            add(f[j][0],1ll*val*f[j][i]%mod);
    }
}
int main()
{
    n=read();
    init(n+5);
    dfs(1,n); sort(p+1,p+tot+1);
    read_and_change();//读入并离散化
    for(int i=0;i<=n+1;i++) f[0][i]=1;
    for(int t=1;t<cnt;t++)
    {
        int len=min(n+1,num[t+1]-num[t]);
        for(int i=1;i<=tot;i++)
        {
            int l=p[i].l,r=p[i].r;
            for(int j=l;j<=r;j++)
                if(abs((j-l)-(r-j))<=2&&t>=a[j]&&t<b[j])
                    for(int k=1;k<=len;k++)
                        add(f[id[l][r]][k],1ll*f[id[l][j-1]][k]*f[id[j+1][r]][k-1]%mod);
                        //最初的DP中的“最大值”就是最靠右的最大值,所以右区间的最大值一定比k小(j位置是k)
            for(int k=1;k<=len;k++) 
                add(f[id[l][r]][k],f[id[l][r]][k-1]);//前缀和优化
        }
        lag(num[t],num[t+1],num[t+1]-num[t]);
        for(int i=1;i<=tot;i++)
            for(int j=1;j<=len;j++)
                f[i][j]=0;
    }
    printf("%d\n",f[id[1][n]][0]);
    return 0;
}