[HEOI2012]赵州桥——矩阵乘法+循环节

· · 题解

矩阵乘法

一题两做,出题人是个智障

首先,推式子:

先考虑涂颜色方案数,忽略掉(2*i)^m 这个东西。

f[i] 表示前 i 行都填好颜色的方案数。转移的话,考虑当前行的第一个格子。

1、如果第一个格子颜色和上一行第二个相同,那么这一行第二个格子有 m-1 种涂法。

2、如果第一个格子颜色和上一行第二个不同,那么这一行第二个格子有 m-2 种涂法,这一行第一个格子也有 m-2 种涂法。

所以 f[1]=m*(m-1),f[i]=f[i-1]*(m-1)+f[i-1]*(m-2)^2

f[i]=f[i-1]*(m^2-3m+3)

然后化简一下就是

f[i]=(m^2-3m+3)^{i-1}*m*(m-1)

然后把题目要求的 (2*i)^m乘上。此时答案就是

2^m*m*(m-1)*\sum_{i=1}^n i^m*k^{i-1}

这里 k=(m^2-3m+3)

第二步m\le200 的做法

发现:i^m=[(i-1)+1]^m

然后二项式定理,把式子拆开就可以了。

矩阵需要维护 i^0,i^1...i^m 以及前缀和 sum,矩阵大小为 m+2

复杂度 O(m^3logn)

然后矩阵乘法的时候,如果 a[i][k] 没有值,就直接跳过就行了。否则会被卡常。

这样就有 80 的好成绩了。

代码( 80 分):

#include<bits/stdc++.h>
#define N 205
using namespace std;

int n,m,p,k;
int c[N][N];
struct nd{int a[N][N];}b,t,now;

inline int POW(int a,int b,int ans=1){
    for(;b;b>>=1,a=1ll*a*a%p)
        if(b&1) ans=1ll*ans*a%p;
    return ans;
}
void init()
{
    k=(m*m-3*m+3)%p;
    for(int i=0;i<=m;i++)
        for(int j=0;j<=i;j++)
            c[i][j]=j ? (c[i-1][j-1]+c[i-1][j])%p : 1;
    for(int i=0;i<=m;i++)
        b.a[i][0]=1;
    for(int i=0;i<=m;i++)
        for(int j=0;j<=i;j++)
            now.a[i][j]=1ll*k*c[i][j]%p;
    now.a[m+1][m+1]=now.a[m+1][m]=1;
}
inline nd mul(nd a,nd b){
    memset(t.a,0,sizeof t.a);
    for(int i=0;i<=m+1;i++)
        for(int k=0;k<=m+1;k++) if(a.a[i][k])
            for(int j=0;j<=m+1;j++) if(b.a[k][j])
                t.a[i][j]=(1ll*a.a[i][k]*b.a[k][j]+t.a[i][j])%p;
    return t;   
}
inline nd POW(nd a,int b){
    nd ans=a;b--;
    for(;b;b>>=1,a=mul(a,a))
        if(b&1) ans=mul(ans,a);
    return ans;
}
int work()
{
    b=mul(POW(now,n-1),b);
    return b.a[m+1][0]+b.a[m][0];
}
signed main(){
    cin>>n>>m>>p;init();
    int ans=work();
    cout<<1ll*ans*POW(2,m)%p*m*(m-1)%p;
}

第三步p\le3000 的做法

发现:i^m=(i % p)^m

所以这个式子是有循环节的,循环节长度就是 p ,然后对于每个循环节之间,就是等比数列。先暴力算出来首项,然后求出来公比,剩下的就是等比数列求和了。

因为 p 不是质数,可能没有逆元,所以要用别的方法求和,类似快速幂的倍增思想。

代码(AC):

#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
#define N 205
#define int long long
using namespace std;

int n,m,p,k;
int c[N][N];
struct nd{int a[N][N];}b,t,now;

inline int POW(int a,int b,int ans=1){
    for(;b;b>>=1,a=1ll*a*a%p)
        if(b&1) ans=1ll*ans*a%p;
    return ans;
}
void init()
{
    for(int i=0;i<=m;i++)
        for(int j=0;j<=i;j++)
            c[i][j]=j ? (c[i-1][j-1]+c[i-1][j])%p : 1,now.a[i][j]=1ll*k*c[i][j]%p;
    for(int i=0;i<=m;i++) b.a[i][0]=1;
    now.a[m+1][m+1]=now.a[m+1][m]=1;
}
inline nd mul(nd a,nd b){
    memset(t.a,0,sizeof t.a);
    for(int i=0;i<=m+1;i++)
        for(int k=0;k<=m+1;k++) if(a.a[i][k])
            for(int j=0;j<=m+1;j++) if(b.a[k][j])
                t.a[i][j]=(1ll*a.a[i][k]*b.a[k][j]+t.a[i][j])%p;
    return t;   
}
nd POW(nd a,int b){
    nd ans=a;b--;
    for(;b;b>>=1,a=mul(a,a))
        if(b&1) ans=mul(ans,a);
    return ans;
}
int work(){
    init(); b=mul(POW(now,n-1),b);
    return b.a[m+1][0]+b.a[m][0];
}
int cal(int b,int n){
    if(n==1) return b;
    if(n&1) return 1ll*(cal(b,n-1)+1)*b%p;
    return 1ll*cal(b,n>>1)*(1+POW(b,n>>1))%p;
}
int work2()
{
    int now=0;
    for(int i=1;i<=p;i++)
        now=(now+1ll*POW(i,m)%p*POW(k,i-1))%p;
    now=1ll*now*(cal(POW(k,p),n/p-1)+1)%p;
    for(int i=n-n%p+1;i<=n;i++)
        now=(now+1ll*POW(i,m)%p*POW(k,i-1))%p;
    return now;
}
signed main(){
    cin>>n>>m>>p;k=(m*m-3*m+3)%p;
    int ans=m<=200 ? work() : work2();
    cout<<1ll*ans*POW(2,m)%p*m*(m-1)%p;
}