CF821E题解

· · 题解

[]()

你在 (0,0)

(x,y) 时,每次移动可以到达 (x+1,y+1),(x+1,y),(x+1,y-1)

平面上有 n 条线段,平行于 x 轴,参数为a_ib_ic_i,表示在 (a_i,c_i)(b_i,c_i) 的一条线段,保证 b_i=a_{i+1}

要求你一直在线段的下方且在 x 轴上方,即 a_i \leq x \leq b_i 时, 0 \leq y \leq c_i

问:到达 (k,0) 的方案数,方案数对 10^9+7 取模。

我们设 dp_{i,j} 表示走到 (i,j) 的方案数,k_i 表示 x 坐标为 i 时行动的上界。

按照题意,递推式:

dp_{i,j}=\begin{cases}dp_{i-1,j-1}+dp_{i-1,j}&(j=k_i)\\dp_{i-1,j-1}+dp_{i-1,j}+dp_{i-1,j+1}&(0<j<k_i)\\dp_{i-1,j}+dp_{i-1,j+1}&(j=0)\end{cases}

那么,我们可以构造一个形如 \begin{pmatrix}1&1&0&\cdots&0&0&0\\1&1&1&\cdots&0&0&0\\\vdots&&&\ddots&&&\vdots\\0&0&0&\cdots&1&1&1\\0&0&0&\cdots&0&1&1\end{pmatrix} 的矩阵进行转移。用快速幂优化即可。

对于矩阵快速幂,这是一种优化递推的算法,可以用 O(\lg n) 的时间复杂度进行 n 次递推。

首先,定义矩阵乘法:

对于一个 n\times k 的矩阵 A 和一个 k\times m 的矩阵 B 还有一个矩阵 C=A\times B。那么 C 是一个 n\times m 的矩阵,且 C_{ij}=\sum\limits_{p=1}^kA_{i,p}\times B_{p,j}(i=\overline{1,2,\cdots,n},j=\overline{1,2,\cdots,m})

然后,因为矩阵乘法满足结合律,所以我们有了矩阵快速幂。矩阵快速幂的话就是把快速幂的板子套在矩阵上而已。

代码:

  #include<bits/stdc++.h>
  #define int long long 
  using namespace std;
  const int mod=1e9+7;
  int n,k,a,b,c;
  struct M{
      int a[20][20];
      M operator*(M t){
          M res;
          for(int i=0;i<=15;i++)for(int j=0;j<=15;j++)res.a[i][j]=0;
          for(int i=0;i<=15;i++){
              for(int j=0;j<=15;j++){
                  for(int k=0;k<=15;k++){
                      res.a[i][j]=(res.a[i][j]+a[i][k]*t.a[k][j]%mod)%mod;
                  }
              }
          }
          return res;
      }
      M operator^(int k){
          M res,_=*this;
          for(int i=0;i<=15;i++)for(int j=0;j<=15;j++)res.a[i][j]=(i==j);
          while(k){
              if(k&1)res=res*_;
              _=_*_;
              k>>=1;
          }
          return res;
      }
  }ans,base;
  signed main()
  {
      scanf("%lld%lld",&n,&k);
      for(int j=0;j<=15;j++)for(int k=0;k<=15;k++)ans.a[j][k]=0;
      ans.a[0][0]=1;
      for(int i=1;i<=n;i++){
          scanf("%lld%lld%lld",&a,&b,&c);
          if(b>=k)b=k;
          for(int j=0;j<=15;j++)for(int l=0;l<=15;l++)base.a[j][l]=0;
          for(int j=0;j<=c;j++){
              base.a[j][j]=1;
              if(j<c)base.a[j][j+1]=1;
              if(j)base.a[j][j-1]=1;
          }
          ans=(base^(b-a))*ans;
          if(b>=k)break;
      }
      printf("%lld",ans.a[0][0]);
      return 0;
  }