题解 P3990 【[SHOI2013]超级跳马】

· · 题解

矩阵优化DP

矩阵乘法几乎不会,我瞎扯证明了1个小时才差不多搞明白怎么弄。

  1. DP

    \large dp_{i,j}表示跳到第i行第j列的方案数。

    于是有了一个初步的状态转移方程(假设i,j都不会越界):

    \large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i-1,j-3}+dp_{i-1,j-5}+...\ ...

    很明显,它复杂度高,不方便求。

    然后手玩一下你就会发现后面那一大串(dp_{i-1,j-3}+...\ ... 以及往后的)跟dp_{i,j-2}相等。

    证明:

    首先,奇数+偶数(2)=奇数,小学知识

    因为马每次只能横向跳奇数格,所以能一次性达到(i\ ,\ j-2)的点同时也能一次性达到(i\ ,\ j),且没有其他点(出去第j-2列上的三个点)能到达(i\ ,\ j)

    然后状态转移方程就变成了

    \large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i,j-2}

    复杂度是 O(nm)的,T飞。

  2. 矩阵优化:

    转移矩阵的大小与n有关。

    如果n=3,转移矩阵是这样哒:

    1 & 1 & 0 & 1 & 0 & 0\\ 1 & 1 & 1 & 0 & 1 & 0\\ 0 & 1 & 1 & 0 & 0 & 1\\ 1 & 0 & 0 & 0 & 0 & 0\\ 0 & 1 & 0 & 0 & 0 & 0\\ 0 & 0 & 1 & 0 & 0 & 0\\ \end{Bmatrix} \times \begin{Bmatrix} dp_{i,1} \\ dp_{i,2} \\ dp_{i,3} \\ dp_{i-1,1} \\ dp_{i-1,2} \\ dp_{i-1,3} \\ \end{Bmatrix} \ = \ \begin{Bmatrix} dp_{i+1,1} \\ dp_{i+1,2} \\ dp_{i+1,3} \\ dp_{i,1} \\ dp_{i,2} \\ dp_{i,3} \\ \end{Bmatrix}

    然后矩阵快速幂求一下就好了。

  3. 计算答案

    你以为怎么着就结束了吗?

    其实,还有一个细节问题没有注意到:

    i=1,j=3时,回归转移方程式:

    \large dp_{1,3}=dp_{1,2}+dp_{1,2}+dp_{1,1}

    然而这时候dp_{1,1}却并不是由前面的点转移过来的,而是特殊初值!!!

    也就是说对于dp_{i,j},答案都恰好会多算dp_{i,j-2}(这里如果不明白可以自己手玩几组小数据)。

    综上所述,记录答案时,应该求dp_{n,m}-dp_{n,m-2}

    当然也可以求dp_{n,m-1}+dp_{n-1,m-1}

多说一句:为何看着各位大佬的矩阵初值都是对角线为1啊,光把dp_{1,1}初始化为1,计算答案时只用第一行不就行吗?虽然没多大区别

对了,答案别忘了%30011。

code(没写注释,个人认为上面写的比较详细了):


#include<bits/stdc++.h> 
#define p 30011
using namespace std;
int read()
{
    int xsef = 0,yagx = 1;char cejt = getchar();
    while(cejt < '0'||cejt > '9'){if(cejt == '-')yagx = -1;cejt = getchar();}
    while(cejt >= '0'&&cejt <= '9'){xsef = (xsef << 1) + (xsef << 3) + cejt - '0';cejt = getchar();}
    return xsef * yagx;
}
int n,m;
struct node{
    int a[110][110];
    node(){
        memset(a, 0, sizeof(a));
    }
}b,c,d;
node operator * (node x, node y){
    node z;
    for(int i = 1;i <= n * 2;i++){
        for(int j = 1;j <= n * 2;j++){
            for(int k = 1;k <= n * 2;k++){
                z.a[i][j] = (z.a[i][j] + x.a[i][k] * y.a[k][j]) % p;
            }
        }
    }
    return z;
}
void build(){
    for(int i=1;i<=n;i++){
            if(i!=1)
                b.a[i][i-1]=1;
            b.a[i][i]=1;
            if(i!=n)
                b.a[i][i+1]=1;
        }
    for(int i=1;i<=n;i++){
        b.a[i][i+n]=1;
    }
    for(int i=1;i<=n;i++){
        b.a[i+n][i]=1;
    }
}
void ksm(){
    while(m){
        if(m&1)
            c=c*b;
        b=b*b;
        m>>=1;
    }
}
signed main(){
    n=read(),m=read();
    build();
    d = b;
    m-=2;
    for(int i=1;i<=n*2;i++)
        c.a[i][i]=1;
    ksm();
    int ans = -c.a[1][n+n];
    c=c*d;
    ans += c.a[1][n];
    printf("%d",ans + p % p);
    return 0;
}