CF1671F Permutation Counting 题解

· · 题解

\max(k,x)=N

我想来讲一下正经一点的,不搜索不打表的 \mathcal{O}(N^7) dp 的做法。

考虑将排列划分为极小的若干个区间 [l_i,r_i] 使得 p_{l_i}\sim p_{r_i} 恰好为数字 l_i\sim r_i 的排列(本原连续段)。这样做的好处是,两段之间绝对不可能存在逆序和下降,且这种划分方式一定是唯一的。

下证:

  1. 非严格递增的 [l_i,r_i] 至多有 11 个。

原因:每个非严格递增的区间必存在下降。

  1. 非严格递增的 [l_i,r_i] 区间长度总和不超过 22

原因:假定 [l_i,r_i] 长度为 len\geq 2,对于所有 i\in[l_i,r_i-1][l_i,i][i+1,r_i] 这两个区间之间必然存在逆序对,否则 \max([l_i,i])<\min([i+1,r_i]),这个区间就不是本原连续段了。

因此长度为 len 的区间至少存在 len-1 个逆序对,也即区间长度总和不超过 2\times 11=22

由此可以发现 p_i=i 对绝大多数 i 成立,那么我们只需要考虑那些非递增的 [l_i,r_i],最后插板把 n-c(c\leq22) 个数插进去就可以。

考虑先对每段 dp。

容易想到设 dp 状态 f_{i,x,y,j} 表示长度为 i 的排列,逆序数为 x,下降数为 y,最后一个数是 j 的方案数。转移即枚举 p_i=k

但这样是错的,因为没有保证其是本原连续段。

问题出在哪里?

举个例子,i=4 时,若枚举 p_i=p_4=4,那么其实当前状态下 p_1\sim p_3 可以被划分出去。为了保证本原性,一定要在之后的某个 p_z 处,令 p_z=1/2/3,把 3 “抬升”上去才可以。

由此可以看出我们还需要记一维状态 u 表示要求之后的某一个 p_z 必须 \leq uu=0 表示无要求。

初始化 dp_{1,0,0,1,1}=1,之后每次取 p_i=i(i\geq 2) 时,将 u 更新为 \min(u,i-1),而取 p_i\leq u 时,将 u 更新为 0 即可。

转移的总复杂度为 \mathcal{O}(N^6)

然后求出 g_{i,j,x,y} 表示 j 个长度和为 i 的区间,共有 x 个逆序数和 y 个下降数的方案数,简单背包合并即可,复杂度 \mathcal{O}(N^7)

对每组询问 (n,k,m),枚举非严格递增的区间个数为 j,长度总和为 i,余下 n-i 个数插进 j+1 个空隙里,空隙可空,方案数 g_{i,j,k,m}\times \begin{pmatrix} n-i \\ j \end{pmatrix},组合数下指标较小,用定义求,然后累和即可。最后一部分复杂度 \mathcal{O}(TN^3)

code

#include <bits/stdc++.h>
const int mod = 998244353;
inline int mul(int x, int y){
  return (int)(1ll * x * y % (1ll * mod));
}
inline int add(int x, int y){
  return x + y >= mod ? x + y - mod : x + y;
}
inline int minus(int x, int y){
  return x < y ? x - y + mod : x - y;
}
inline int Qpow(int x, int y){
  int r = 1;
  while(y){
    if(y & 1) r = mul(r, x);
    x = mul(x, x);
    y >>= 1;
  }
  return r;
}
int f[13][12][12], g[23][12][12][12], ifac[12];
inline int C(int x, int y){
  int r = 1;
  for(int i = x; i >= x - y + 1; --i) r = mul(r, i);
  for(int i = 1; i <= y; ++i) r = mul(r, ifac[i]);
  return r;
}
void solve(){
  int n, k, m; int ans = 0;
  scanf("%d%d%d", &n, &k, &m);
  for(int i = 1; i <= 22 && i <= n; ++i)
    for(int j = 1; j <= 11; ++j)
       ans = add(ans, mul(g[i][j][k][m], C(n - i + j, j)));
  printf("%d\n", ans);
  return ;
}
int dp[13][12][12][13][13];
int main(){
  int T = 1;
  scanf("%d", &T);
  ifac[0] = 1;
  for(int i = 1; i <= 11; ++i) ifac[i] = Qpow(i, mod - 2);
  dp[1][0][0][1][1] = 1;
  for(int i = 2; i <= 12; ++i)
    for(int x = 0; x <= std::min(11, (i - 2) * (i - 1) / 2); ++x)
      for(int y = 0; y <= i - 2; ++y)
        for(int j = 1; j <= i - 1; ++j)
          for(int u = 0; u <= i - 1; ++u){
            if(!dp[i - 1][x][y][j][u]) continue;
            for(int k = 1; k <= i; ++k){
              int rj, nx, ny, nu;
              if(j >= k) rj = j + 1;
              else rj = j;
              if(rj > k) ny = y + 1;
              else ny = y;
              nx = x + i - k;
              if(nx > 11 || ny > 11) continue;
              if(k <= u) nu = 0;
              else nu = u;
              if(k == i && u == 0) nu = k - 1;
              dp[i][nx][ny][k][nu] = add(dp[i][nx][ny][k][nu], dp[i - 1][x][y][j][u]);
            }
          }
  for(int i = 1; i <= 12; ++i)
    for(int x = 0; x <= 11; ++x)
      for(int y = 0; y <= 11; ++y)
        for(int j = 1; j <= i; ++j)
          f[i][x][y] = add(f[i][x][y], dp[i][x][y][j][0]);
  g[0][0][0][0] = 1;
  for(int j = 1; j <= 11; ++j)
    for(int i = 1; i <= 22; ++i)
      for(int x = 0; x <= 11; ++x)
        for(int y = 0; y <= 11; ++y)
          for(int oi = 2; oi <= i && oi <= 12; ++oi)
            for(int ox = 0; ox <= x; ++ox)
              for(int oy = 0; oy <= y; ++oy)
                g[i][j][x][y] = add(g[i][j][x][y], mul(g[i - oi][j - 1][x - ox][y - oy], f[oi][ox][oy]));
  while(T--) solve();
  return 0;
}