题解:CF2146F Bubble Sort
MaxBlazeResFire
·
·
题解
$\rm Observation2$:往前缀 $p[1,i)$ 中加入 $p_i$,则 $p_i$ 的代价是它的当前 $\rm rnk$。如果我们建立排列与当前 $\rm rnk$ 序列(满足 $r_i<i$ 的任意序列)建立一一映射(例如 $[1,3,4,2]\rightarrow[0,0,0,2]$),则显然 $\displaystyle b_i=\max_{1\leq j\leq i}r_j$。
$\rm Observation3$:现在问题已经和排列没关系了。直接转为对合法 $r$ 序列计数。一个限制 $(l,r,k)$ 相当于 $r$ 序列的前缀最大值 $\leq k$ 的个数在 $[l,r]$ 内,由于前缀最大值序列单增,那么一个必须要考虑到的核心转化是,**这个限制等价于 $l$ 位置的前缀最大值 $\leq k$ 且 $r+1$ 位置的前缀最大值 $>k$**。
此时显然有了一个 $\rm DP$:$f_{i,j}$ 表示,考虑前 $i$ 个位置,前缀最大值为 $j$ 的方案数。考虑 $i+1$ 这个位置的所有限制一定可以写成前缀最大值在一个区间 $[l,r]$ 内,那么转移方程就是 $\displaystyle f_{i+1,j}=\sum_{k<j}f_{i,k}+j\times f_{i,j}$,特别地,要求 $j\leq i,j\in[l,r]$,可以用前缀和轻松解决。
考虑优化复杂度。显然,需要 $m\leq 1000$ 的条件。序列有 $O(m)$ 个关键点,在这些关键点上有限制;并且在值域上也恰好有 $O(m)$ 个要求,考虑把值域也分段 $[0,v_1],(v_1,v_2],\cdots$。记第 $i$ 个关键点在 $p_i$,值域上的第 $j$ 个关键点为 $v_j$。这么分段之后,对于任意一个序列段 $i$(记为 $[l,r]$),只需要满足两个条件:$r_i<i$;$r_i$ 位于值域段 $[L,R]$ 内。其中 $[L,R]$ 指的是第 $L$ 到第 $R$ 个值域段的并。
那么就好做了。$f_{i,j}$ 表示考虑前 $i$ 段,目前填的最大值域段是 $j$ 的方案数。那么转移的时候枚举 $j\in[L,R]$,$\displaystyle f_{i,j}\leftarrow(c(l_i,r_i,v_j)-c(l_i,r_i,v_{j-1}))\sum_{k=1}^{j-1}f_{i-1,k}$;$\displaystyle f_{i,j}\leftarrow f_{i-1,j}\times c(l_i,r_i,v_j)$。其中 $c(l,r,v)$ 表示在 $[l,r]$ 内填 $\leq v$ 的数并且要求 $r_i<i$ 的方案数。显然,可以把区间分成受 $r_i<i$ 限制的前缀和不受限制的后缀。对于前缀来说一定是一个区间乘积的形式,可以预处理阶乘解决;后者一定是可以用快速幂计算的指数形式。
复杂度 $O(m^2\log n)$。
::::success[Code]
```cpp
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define mod 998244353
#define MAXN 1000005
int n,m,K[MAXN],L[MAXN],R[MAXN],P[MAXN],V[MAXN];
//LimL 表示第 i 段的数至少应该是多少 LimR 表示至多是多少
int LimL[MAXN],LimR[MAXN],fac[MAXN],inv[MAXN],ifac[MAXN];
int f[2005][1005],s[2005][1005];
inline void chkadd( int &x , int k ){ x += k; if( x >= mod ) x -= mod; }
inline void chkequ( int &x , int k ){ x = k; if( x >= mod ) x -= mod; }
inline int fp( int x , int p ){
int res = 1;
while( p ){
if( p & 1 ) res = res * x % mod;
x = x * x % mod;
p >>= 1;
}
return res;
}
inline int calc( int l , int r , int v ){
if( v < l ) return fp( v + 1 , r - l + 1 );
v = min( v , r );
return fac[v] * ifac[l - 1] % mod * fp( v + 1 , r - v ) % mod;
}
inline void solve(){
scanf("%lld%lld",&n,&m);
int pcnt = 0,vcnt = 0;
for( int i = 1 ; i <= m ; i ++ ){
scanf("%lld%lld%lld",&K[i],&L[i],&R[i]); R[i] ++;
P[++pcnt] = L[i],V[++vcnt] = K[i];
if( R[i] <= n ) P[++pcnt] = R[i];
}
P[++pcnt] = n,V[++vcnt] = n - 1;
sort( P + 1 , P + pcnt + 1 ),pcnt = unique( P + 1 , P + pcnt + 1 ) - ( P + 1 );
for( int i = 1 ; i <= m ; i ++ )
L[i] = lower_bound( P + 1 , P + pcnt + 1 , L[i] ) - P,
R[i] = lower_bound( P + 1 , P + pcnt + 1 , R[i] ) - P;
sort( V + 1 , V + vcnt + 1 ),vcnt = unique( V + 1 , V + vcnt + 1 ) - ( V + 1 );
// for( int i = 1 ; i <= pcnt ; i ++ ) cerr << V[i] << "\n";
for( int i = 1 ; i <= m ; i ++ ) K[i] = lower_bound( V + 1 , V + vcnt + 1 , K[i] ) - V;
for( int i = 1 ; i <= pcnt ; i ++ ) LimL[i] = 1,LimR[i] = vcnt;
for( int i = 1 ; i <= m ; i ++ ){
LimR[L[i]] = min( LimR[L[i]] , K[i] );
if( R[i] <= pcnt ) LimL[R[i]] = max( LimL[R[i]] , K[i] + 1 );
}
for( int i = 2 ; i <= pcnt ; i ++ ) LimL[i] = max( LimL[i] , LimL[i - 1] );
for( int i = pcnt - 1 ; i >= 1 ; i -- ) LimR[i] = min( LimR[i] , LimR[i + 1] );
// for( int i = 1 ; i <= pcnt ; i ++ ){
// cerr << i << " " << LimL[i] << " " << LimR[i] << "\n";
// }
f[0][0] = 1; V[0] = -1;
for( int i = 1 ; i <= pcnt ; i ++ ){
s[i - 1][0] = f[i - 1][0];
for( int j = 1 ; j <= vcnt ; j ++ ) chkequ( s[i - 1][j] , s[i - 1][j - 1] + f[i - 1][j] );
for( int j = LimL[i] ; j <= LimR[i] ; j ++ ){
int coef = ( calc( P[i - 1] + 1 , P[i] , V[j] ) - calc( P[i - 1] + 1 , P[i] , V[j - 1] ) + mod ) % mod;
chkadd( f[i][j] , s[i - 1][j - 1] * coef % mod );
coef = calc( P[i - 1] + 1 , P[i] , V[j] );
chkadd( f[i][j] , f[i - 1][j] * coef % mod );
}
// for( int j = 0 ; j <= vcnt ; j ++ ){
// cerr << "f[" << i << "][" << j << "]=" << f[i][j] << "\n";
// }
}
int Ans = 0;
for( int i = 1 ; i <= vcnt ; i ++ ) chkadd( Ans , f[pcnt][i] );
printf("%lld\n",Ans);
for( int i = 0 ; i <= pcnt ; i ++ ){
LimL[i] = LimR[i] = 0;
for( int j = 0 ; j <= vcnt ; j ++ ) f[i][j] = s[i][j] = 0;
}
}
signed main(){
fac[0] = inv[1] = ifac[0] = 1;
for( int i = 1 ; i < MAXN ; i ++ ) fac[i] = fac[i - 1] * i % mod;
for( int i = 2 ; i < MAXN ; i ++ ) inv[i] = ( mod - mod / i ) * inv[mod % i] % mod;
for( int i = 1 ; i < MAXN ; i ++ ) ifac[i] = ifac[i - 1] * inv[i] % mod;
int testcase; scanf("%lld",&testcase);
while( testcase -- ) solve();
return 0;
}
```
::::