题解:CF1906J Count BFS Graph

· · 题解

Solution

bfs 按层进行,每个点能 push 进去一段点。设 f_{i,j} 表示当前正考虑点 P_i ,队列里已经 push 到了位置 j 的方案数,即,P_i 带来的所有更新都是 push 在 j 后面。这个状态有意义应满足 j \ge i

那么 f_{i,j} 可以更新 f_{i+1,k},满足 j \le k \le R_{j+1}。其中 R_j 表示从 j 开始最长能延续多长的一段递增。原因就是一个点 push 进去的一定是一段编号递增的点。现在已经 push 到 j 了,从 j+1 开始要编号递增。

前面这 j 个点已经被标记,无论 P_i 和它们有没有边都不会被 push,所以这些边连不连均可,后面的被 push 进去的点与 P_i 必须有边,没被 push 的必须没有边。但是如果对于每个点我们都考虑它与其他点所有的连边方案的话会算重,比如 (1,2) 会被 1 考虑一次 2 考虑一次,所以强制只考虑位置大于自己的点即可,这样的点共有 j-i 个。我们更新是向后更新,所以考虑编号大于自己的会比较好做。

整理一下就是 f_{i+1,k}\larr 2^{j-i}f_{i,j},其中 j \le k \le R_{j+1}。初始时 f_{1,1}=1,就是说 1 被 push 了但还没考虑它的连边。答案是 f_{n+1,n},因为要考虑完这 n 个点。

这样直接做是三次方的,然后你注意到一个转移是区间加一个固定的数,所以前缀和 + 差分即可,时间复杂度 \mathcal{O}(n^2)

Code

#include<algorithm>
#include<iostream>
#include<cstdio>
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
typedef long long ll;
namespace FastIO{
    template<typename T=int> T read(){
        T x=0;int f=1;char c=getchar();
        while(!isdigit(c)){if(c=='-') f=~f+1;c=getchar();}
        while(isdigit(c)){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
        return x*f;
    }
    template<typename T> void write(T x){
        if(x<0){putchar('-');x=-x;}
        if(x>9) write(x/10);
        putchar(x%10+'0');
    }
    template<typename T> void Write(T x,char c='\n'){write(x);putchar(c);}
}
using namespace FastIO;
const int MOD=998244353;
const int maxn=5005;
int f[maxn][maxn],pw[maxn],r[maxn],a[maxn],c[maxn];
int main(){
    int n=read();
    f[1][1]=1;pw[0]=1;
    for(int i=1;i<=n;i++) pw[i]=(pw[i-1]*2)%MOD;
    for(int i=1;i<=n;i++) a[r[i]=i]=read();
    for(int i=1;i<=n;i++) while(a[r[i]]<a[r[i]+1]) r[i]++;
    r[n+1]=n;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++) c[j]=0;
        for(int j=i;j<=n;j++){
        //  for(int k=j;k<=r[j+1];k++){
        //      (f[i+1][k]+=(1ll*pw[j-i]*f[i][j])%MOD)%=MOD; 
        //      printf("%d %d %d\n",i,j,f[i][j]);
        //  }
            int x=(1ll*pw[j-i]*f[i][j])%MOD;
            c[j]=(c[j]+x)%MOD,c[r[j+1]+1]=(c[r[j+1]+1]-x+MOD)%MOD;
        }
        for(int j=1;j<=n;j++) c[j]=(c[j]+c[j-1])%MOD;
        for(int j=i;j<=n;j++) f[i+1][j]=c[j];
    }
    write(f[n+1][n]);
    return 0;
}