【题解】 [COCI2020] Semafor

· · 题解

tag:矩阵乘法

首先,每次操作都是可以转化成一次异或卷积。因为每种状态都可以看作是一个二进制数,两个二进制数异或后就是另外一种状态。而这两个状态合起来的贡献就是相乘的数值。

这样 K 次操作可以进行矩阵快速幂来优化。那么现在的问题是每 K 次都要保证所有的状态是合法的。如果做 \frac n K 次,每次清空不合法的状态的话复杂度还是炸裂。所以我们需要考虑优化一下。

我们考虑把 K 次转移后的结果都压成一个矩阵,这个矩阵的含义就是从 i 转移到到 j 的合法方案数,其中 ij 都是数字,而不是刚才的二进制状态。这样进行 \lfloor \frac n K \rfloor\times K 次转移就可以用这个矩阵来进行快速幂,而剩下 n\mod K 的部分就可以转化回二进制状态来进行快速幂。

我们来捋一下维护的东西和过程:

  1. 先预处理 K 次异或卷积,运用快速幂,得到 f_{i} 为在进行 K 次操作后从 0 到另一个二进制状态为 i 的的方案数。
  2. 再把结果压成矩阵,设 g_{i,j} 表示 K 次操作后从数字 i 转移到数字 j 的方案数,g_{i,j}=f_{mp_i\oplus mp_j}mp_i 表示数字 i 的二进制状态表示。然后和一个只有 g_{x,x}=1 其余为 0 的矩阵乘 \frac n K 次,同样用快速幂。
  3. 把得到的矩阵转回二进制状态表示,再进行 n \bmod K 次异或卷积。

这样,这道题就做完了,答案就是输出每个数字的二进制表示下的异或卷积之后的结果。

代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,K;int m,X;int sz,num,mx;
int mp[]={3,1,6,21,9,28,18,5,30,29};//每个数的二进制状态表示
ll b[102][102],g[102][102],tt[102][102];
struct martix
{
    ll p[1025];int siz;
    martix (int xx){siz=xx;for(int i=0;i<xx;i++) p[i]=0;}
};
const int mod=1e9+7;
void mul(martix &a,martix &f,martix &t)
{
    for(int i=0;i<mx;i++) 
    {
        if(!a.p[i]) continue;
        for(int j=0;j<mx;j++) 
        {
            if(!f.p[j]) continue;
            t.p[i^j]=(t.p[i^j]+a.p[i]*f.p[j]%mod)%mod;
        }
    }
    for(int i=0;i<mx;i++)
        a.p[i]=t.p[i],t.p[i]=0;
}
void ksm1(martix &a,martix &f,ll K)//异或卷积的快速幂
{
    martix t(a.siz);memset(t.p,0,sizeof(t.p));
    while(K)
    {
        if(K&1) mul(f,a,t);
        mul(a,a,t);K>>=1;
    }
}
void ksm2(ll K)//矩阵乘法的快速幂
{
    while(K)
    {
        if(K&1) 
        {
            for(int k=0;k<num;k++)
            {
                for(int i=0;i<num;i++)
                {
                    if(!g[i][k]) continue;
                    for(int j=0;j<num;j++)
                    {
                        if(!b[k][j]) continue;
                        tt[i][j]=(tt[i][j]+g[i][k]*b[k][j]%mod)%mod;
                    }

                }
            }
            for(int i=0;i<num;i++) for(int j=0;j<num;j++) g[i][j]=tt[i][j],tt[i][j]=0;
        }
        for(int k=0;k<num;k++)
        {
            for(int i=0;i<num;i++)
            {
                if(!b[i][k]) continue;
                for(int j=0;j<num;j++)
                {
                    if(!b[k][j]) continue;
                    tt[i][j]=(tt[i][j]+b[i][k]*b[k][j]%mod)%mod;
                }
            }
        }
        for(int i=0;i<num;i++) for(int j=0;j<num;j++) b[i][j]=tt[i][j],tt[i][j]=0;
        K>>=1;
    }
}
int find(int x){return (m==1)?mp[x]:((mp[x/10]<<5)|mp[x%10]);}
int main()
{
//  freopen("p.in","r",stdin);
    scanf("%d%lld%lld%d",&m,&n,&K,&X);
    if(m==1) sz=5,num=10,mx=32;else sz=10,num=100,mx=1024;
    martix a(mx),t(mx);
    memset(a.p,0,sizeof(a.p));memset(t.p,0,sizeof(t.p));
    for(int i=0;i<sz;i++) a.p[1<<i]=1; 
    martix f(mx);memset(f.p,0,sizeof(f.p));f.p[0]=1;ksm1(a,f,K);
    for(int i=0;i<num;i++)
    {
        for(int j=0;j<num;j++)
            b[i][j]=f.p[find(i)^find(j)];//把结果压成矩阵
    }
    g[X][X]=1; ksm2(n/K);
    martix ans(mx);memset(ans.p,0,sizeof(ans.p));
    for(int i=0;i<num;i++) ans.p[find(i)]=g[X][i];//把矩阵的结果转化回二进制状态
    memset(a.p,0,sizeof(a.p));memset(f.p,0,sizeof(f.p));
    for(int i=0;i<sz;i++) a.p[1<<i]=1; 
    f.p[0]=1;ksm1(a,f,n%K);
    mul(ans,f,t);
    for(int i=0;i<num;i++)
    {
        printf("%lld\n",ans.p[find(i)]);
    }
    return (0-0);
}