题解 P3990 【[SHOI2013]超级跳马】
矩阵优化DP
矩阵乘法几乎不会,我瞎扯证明了1个小时才差不多搞明白怎么弄。
-
DP 设
\large dp_{i,j} 表示跳到第i 行第j 列的方案数。于是有了一个初步的状态转移方程(假设
i,j 都不会越界):\large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i-1,j-3}+dp_{i-1,j-5}+...\ ... 很明显,它复杂度高,不方便求。
然后
手玩一下你就会发现后面那一大串(dp_{i-1,j-3}+...\ ... 以及往后的)跟dp_{i,j-2} 相等。证明:
首先,奇数+偶数(2)=奇数,
小学知识因为马每次只能横向跳奇数格,所以能一次性达到
(i\ ,\ j-2) 的点同时也能一次性达到(i\ ,\ j) ,且没有其他点(出去第j-2 列上的三个点)能到达(i\ ,\ j) 。然后状态转移方程就变成了
\large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i,j-2} 复杂度是
O(nm) 的,T飞。 -
矩阵优化:
转移矩阵的大小与
n 有关。如果
n=3 ,转移矩阵是这样哒:1 & 1 & 0 & 1 & 0 & 0\\ 1 & 1 & 1 & 0 & 1 & 0\\ 0 & 1 & 1 & 0 & 0 & 1\\ 1 & 0 & 0 & 0 & 0 & 0\\ 0 & 1 & 0 & 0 & 0 & 0\\ 0 & 0 & 1 & 0 & 0 & 0\\ \end{Bmatrix} \times \begin{Bmatrix} dp_{i,1} \\ dp_{i,2} \\ dp_{i,3} \\ dp_{i-1,1} \\ dp_{i-1,2} \\ dp_{i-1,3} \\ \end{Bmatrix} \ = \ \begin{Bmatrix} dp_{i+1,1} \\ dp_{i+1,2} \\ dp_{i+1,3} \\ dp_{i,1} \\ dp_{i,2} \\ dp_{i,3} \\ \end{Bmatrix} 然后矩阵快速幂求一下就好了。
-
计算答案
你以为怎么着就结束了吗?
其实,还有一个细节问题没有注意到:
当
i=1,j=3 时,回归转移方程式:\large dp_{1,3}=dp_{1,2}+dp_{1,2}+dp_{1,1} 然而这时候
dp_{1,1} 却并不是由前面的点转移过来的,而是特殊初值!!!也就是说对于
dp_{i,j} ,答案都恰好会多算dp_{i,j-2} (这里如果不明白可以自己手玩几组小数据)。综上所述,记录答案时,应该求
dp_{n,m}-dp_{n,m-2} 。当然也可以求
dp_{n,m-1}+dp_{n-1,m-1} 。
多说一句:为何看着各位大佬的矩阵初值都是对角线为1啊,光把虽然没多大区别
对了,答案别忘了%30011。
code(没写注释,个人认为上面写的比较详细了):
#include<bits/stdc++.h>
#define p 30011
using namespace std;
int read()
{
int xsef = 0,yagx = 1;char cejt = getchar();
while(cejt < '0'||cejt > '9'){if(cejt == '-')yagx = -1;cejt = getchar();}
while(cejt >= '0'&&cejt <= '9'){xsef = (xsef << 1) + (xsef << 3) + cejt - '0';cejt = getchar();}
return xsef * yagx;
}
int n,m;
struct node{
int a[110][110];
node(){
memset(a, 0, sizeof(a));
}
}b,c,d;
node operator * (node x, node y){
node z;
for(int i = 1;i <= n * 2;i++){
for(int j = 1;j <= n * 2;j++){
for(int k = 1;k <= n * 2;k++){
z.a[i][j] = (z.a[i][j] + x.a[i][k] * y.a[k][j]) % p;
}
}
}
return z;
}
void build(){
for(int i=1;i<=n;i++){
if(i!=1)
b.a[i][i-1]=1;
b.a[i][i]=1;
if(i!=n)
b.a[i][i+1]=1;
}
for(int i=1;i<=n;i++){
b.a[i][i+n]=1;
}
for(int i=1;i<=n;i++){
b.a[i+n][i]=1;
}
}
void ksm(){
while(m){
if(m&1)
c=c*b;
b=b*b;
m>>=1;
}
}
signed main(){
n=read(),m=read();
build();
d = b;
m-=2;
for(int i=1;i<=n*2;i++)
c.a[i][i]=1;
ksm();
int ans = -c.a[1][n+n];
c=c*d;
ans += c.a[1][n];
printf("%d",ans + p % p);
return 0;
}