题解:CF2146F Bubble Sort

· · 题解

$\rm Observation2$:往前缀 $p[1,i)$ 中加入 $p_i$,则 $p_i$ 的代价是它的当前 $\rm rnk$。如果我们建立排列与当前 $\rm rnk$ 序列(满足 $r_i<i$ 的任意序列)建立一一映射(例如 $[1,3,4,2]\rightarrow[0,0,0,2]$),则显然 $\displaystyle b_i=\max_{1\leq j\leq i}r_j$。 $\rm Observation3$:现在问题已经和排列没关系了。直接转为对合法 $r$ 序列计数。一个限制 $(l,r,k)$ 相当于 $r$ 序列的前缀最大值 $\leq k$ 的个数在 $[l,r]$ 内,由于前缀最大值序列单增,那么一个必须要考虑到的核心转化是,**这个限制等价于 $l$ 位置的前缀最大值 $\leq k$ 且 $r+1$ 位置的前缀最大值 $>k$**。 此时显然有了一个 $\rm DP$:$f_{i,j}$ 表示,考虑前 $i$ 个位置,前缀最大值为 $j$ 的方案数。考虑 $i+1$ 这个位置的所有限制一定可以写成前缀最大值在一个区间 $[l,r]$ 内,那么转移方程就是 $\displaystyle f_{i+1,j}=\sum_{k<j}f_{i,k}+j\times f_{i,j}$,特别地,要求 $j\leq i,j\in[l,r]$,可以用前缀和轻松解决。 考虑优化复杂度。显然,需要 $m\leq 1000$ 的条件。序列有 $O(m)$ 个关键点,在这些关键点上有限制;并且在值域上也恰好有 $O(m)$ 个要求,考虑把值域也分段 $[0,v_1],(v_1,v_2],\cdots$。记第 $i$ 个关键点在 $p_i$,值域上的第 $j$ 个关键点为 $v_j$。这么分段之后,对于任意一个序列段 $i$(记为 $[l,r]$),只需要满足两个条件:$r_i<i$;$r_i$ 位于值域段 $[L,R]$ 内。其中 $[L,R]$ 指的是第 $L$ 到第 $R$ 个值域段的并。 那么就好做了。$f_{i,j}$ 表示考虑前 $i$ 段,目前填的最大值域段是 $j$ 的方案数。那么转移的时候枚举 $j\in[L,R]$,$\displaystyle f_{i,j}\leftarrow(c(l_i,r_i,v_j)-c(l_i,r_i,v_{j-1}))\sum_{k=1}^{j-1}f_{i-1,k}$;$\displaystyle f_{i,j}\leftarrow f_{i-1,j}\times c(l_i,r_i,v_j)$。其中 $c(l,r,v)$ 表示在 $[l,r]$ 内填 $\leq v$ 的数并且要求 $r_i<i$ 的方案数。显然,可以把区间分成受 $r_i<i$ 限制的前缀和不受限制的后缀。对于前缀来说一定是一个区间乘积的形式,可以预处理阶乘解决;后者一定是可以用快速幂计算的指数形式。 复杂度 $O(m^2\log n)$。 ::::success[Code] ```cpp #include<bits/stdc++.h> using namespace std; #define int long long #define mod 998244353 #define MAXN 1000005 int n,m,K[MAXN],L[MAXN],R[MAXN],P[MAXN],V[MAXN]; //LimL 表示第 i 段的数至少应该是多少 LimR 表示至多是多少 int LimL[MAXN],LimR[MAXN],fac[MAXN],inv[MAXN],ifac[MAXN]; int f[2005][1005],s[2005][1005]; inline void chkadd( int &x , int k ){ x += k; if( x >= mod ) x -= mod; } inline void chkequ( int &x , int k ){ x = k; if( x >= mod ) x -= mod; } inline int fp( int x , int p ){ int res = 1; while( p ){ if( p & 1 ) res = res * x % mod; x = x * x % mod; p >>= 1; } return res; } inline int calc( int l , int r , int v ){ if( v < l ) return fp( v + 1 , r - l + 1 ); v = min( v , r ); return fac[v] * ifac[l - 1] % mod * fp( v + 1 , r - v ) % mod; } inline void solve(){ scanf("%lld%lld",&n,&m); int pcnt = 0,vcnt = 0; for( int i = 1 ; i <= m ; i ++ ){ scanf("%lld%lld%lld",&K[i],&L[i],&R[i]); R[i] ++; P[++pcnt] = L[i],V[++vcnt] = K[i]; if( R[i] <= n ) P[++pcnt] = R[i]; } P[++pcnt] = n,V[++vcnt] = n - 1; sort( P + 1 , P + pcnt + 1 ),pcnt = unique( P + 1 , P + pcnt + 1 ) - ( P + 1 ); for( int i = 1 ; i <= m ; i ++ ) L[i] = lower_bound( P + 1 , P + pcnt + 1 , L[i] ) - P, R[i] = lower_bound( P + 1 , P + pcnt + 1 , R[i] ) - P; sort( V + 1 , V + vcnt + 1 ),vcnt = unique( V + 1 , V + vcnt + 1 ) - ( V + 1 ); // for( int i = 1 ; i <= pcnt ; i ++ ) cerr << V[i] << "\n"; for( int i = 1 ; i <= m ; i ++ ) K[i] = lower_bound( V + 1 , V + vcnt + 1 , K[i] ) - V; for( int i = 1 ; i <= pcnt ; i ++ ) LimL[i] = 1,LimR[i] = vcnt; for( int i = 1 ; i <= m ; i ++ ){ LimR[L[i]] = min( LimR[L[i]] , K[i] ); if( R[i] <= pcnt ) LimL[R[i]] = max( LimL[R[i]] , K[i] + 1 ); } for( int i = 2 ; i <= pcnt ; i ++ ) LimL[i] = max( LimL[i] , LimL[i - 1] ); for( int i = pcnt - 1 ; i >= 1 ; i -- ) LimR[i] = min( LimR[i] , LimR[i + 1] ); // for( int i = 1 ; i <= pcnt ; i ++ ){ // cerr << i << " " << LimL[i] << " " << LimR[i] << "\n"; // } f[0][0] = 1; V[0] = -1; for( int i = 1 ; i <= pcnt ; i ++ ){ s[i - 1][0] = f[i - 1][0]; for( int j = 1 ; j <= vcnt ; j ++ ) chkequ( s[i - 1][j] , s[i - 1][j - 1] + f[i - 1][j] ); for( int j = LimL[i] ; j <= LimR[i] ; j ++ ){ int coef = ( calc( P[i - 1] + 1 , P[i] , V[j] ) - calc( P[i - 1] + 1 , P[i] , V[j - 1] ) + mod ) % mod; chkadd( f[i][j] , s[i - 1][j - 1] * coef % mod ); coef = calc( P[i - 1] + 1 , P[i] , V[j] ); chkadd( f[i][j] , f[i - 1][j] * coef % mod ); } // for( int j = 0 ; j <= vcnt ; j ++ ){ // cerr << "f[" << i << "][" << j << "]=" << f[i][j] << "\n"; // } } int Ans = 0; for( int i = 1 ; i <= vcnt ; i ++ ) chkadd( Ans , f[pcnt][i] ); printf("%lld\n",Ans); for( int i = 0 ; i <= pcnt ; i ++ ){ LimL[i] = LimR[i] = 0; for( int j = 0 ; j <= vcnt ; j ++ ) f[i][j] = s[i][j] = 0; } } signed main(){ fac[0] = inv[1] = ifac[0] = 1; for( int i = 1 ; i < MAXN ; i ++ ) fac[i] = fac[i - 1] * i % mod; for( int i = 2 ; i < MAXN ; i ++ ) inv[i] = ( mod - mod / i ) * inv[mod % i] % mod; for( int i = 1 ; i < MAXN ; i ++ ) ifac[i] = ifac[i - 1] * inv[i] % mod; int testcase; scanf("%lld",&testcase); while( testcase -- ) solve(); return 0; } ``` ::::