题解 P5772 【[JSOI2016]位运算】

· · 题解

强烈推荐,戳此↓,看看我的博客

更好的阅读体验

题目大意

题目链接

给定两个整数n, k和一个01串S。我们设R是一个二进制数,它的二进制表示,就是S重复k次。请你选出n个不同的、小于R的非负整数(也就是值在[0,R-1]之间),使得它们的异或和为0

数据范围:3\leq n\leq 7,1\leq k\leq 10^5,1\leq |S|\leq 50

本题题解

假设我们已经知道了R(也就是把S重复k次大力展开),该怎么做?设选出的这n个数为x_1,x_2,\dots,x_n。为了保证这n个数互不相同且都小于R,不妨设R>x_1>x_2>\dots >x_n。不妨设x_0=R

因为n很小,考虑状压DP。设dp[i][\text{mask}]表示考虑了R的前i位(从高到低);\text{mask}是一个长度为n的二进制数,对于第j个位置 (1\leq j\leq n),如果当前(只考虑数的前i位)x_j=x_{j-1},则\text{mask}j位为1,否则\text{mask}j位为0。转移时,枚举x_1\dots x_n的第i位分别填什么,这样可以计算出新的\text{mask}'。令dp[i][\text{mask}']\texttt{+=}dp[i-1][\text{mask}]即可。

这个DP的时间复杂度是O(|R|\cdot (2^n)^2\cdot n)=O(k\cdot |S|\cdot 2^{2n}\cdot n)的,无法通过本题。

发现上面的这个DP,没有用到“R是由一个非常短的串S,重复k次得来的”,这一特殊条件。我们发现,在Sk次重复中,每一次的转移其实都是一样的。那么就容易想到做矩阵快速幂

那么关键就是要求出转移矩阵:\text{trans}[\text{mask}_1][\text{mask}_2]表示从dp[x\cdot|S|][\text{mask}_1]转移到dp[(x+1)\cdot|S|][\text{mask}_2]的系数 (0\leq x<k)。可以枚举\text{mask}_1,令dp[0][\text{mask}_1]=1,然后对S做一遍上面的那个DP,就可以对所有\text{mask}_2求出\text{trans}[\text{mask}_1][\text{mask}_2]了。这部分时间复杂度O(|S|\cdot 2^{3n}\cdot n)。可以承受。

然后就对这个大小为2^n\times 2^n的矩阵做矩阵快速幂即可。这部分时间复杂度O(2^{3n}\log k)

总时间复杂度O(|S|\cdot 2^{3n}\cdot n+2^{3n}\log k)

参考代码:

//problem:LOJ2075
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

const int MOD=1e9+7;
inline int mod1(int x){return x<MOD?x:x-MOD;}
inline int mod2(int x){return x<0?x+MOD:x;}
inline void add(int& x,int y){x=mod1(x+y);}
inline void sub(int& x,int y){x=mod2(x-y);}
inline int pow_mod(int x,int i){int y=1;while(i){if(i&1)y=(ll)y*x%MOD;x=(ll)x*x%MOD;i>>=1;}return y;}

const int MAXN=7,MAXK=1e5,MAXS=50;
const int SIZE2=1<<MAXN;
int n,K,len,bitcnt[SIZE2],dp[MAXS+5][SIZE2];
char s[MAXS+5];
struct Matrix{
    int a[SIZE2][SIZE2],sz;
    Matrix(){
        memset(a,0,sizeof(a));
        sz=0;
    }
};
Matrix operator*(const Matrix& X,const Matrix& Y){
    Matrix Z;
    assert(X.sz==Y.sz);
    Z.sz=X.sz;
    for(int i=0;i<=X.sz;++i){
        for(int j=0;j<=X.sz;++j){
            for(int k=0;k<=X.sz;++k){
                add(Z.a[i][j],(ll)X.a[i][k]*Y.a[k][j]%MOD);
            }
        }
    }
    return Z;
}
Matrix mat_pow(Matrix X,int i){
    Matrix Y;
    Y.sz=X.sz;
    for(int j=0;j<=X.sz;++j)Y.a[j][j]=1;
    while(i){
        if(i&1)Y=Y*X;
        X=X*X;
        i>>=1;
    }
    return Y;
}

int main() {
    cin>>n>>K;
    cin>>(s+1);len=strlen(s+1);
    int sz=(1<<n)-1;
    for(int i=1;i<=sz;++i)bitcnt[i]=bitcnt[i>>1]+(i&1);
    Matrix trans;
    trans.sz=sz;
    for(int st=0;st<=sz;++st){
        memset(dp,0,sizeof(dp));
        dp[0][st]=1;
        for(int i=1;i<=len;++i){
            for(int j=0;j<=sz;++j)if(dp[i-1][j]){
                for(int k=0;k<=sz;++k)if(bitcnt[k]%2==0){
                    static int curs[MAXN+5];
                    curs[0]=s[i]-'0';
                    for(int l=1;l<=n;++l)curs[l]=((k>>(l-1))&1);
                    bool fail=false;
                    int newj=0;
                    for(int l=1;l<=n;++l){
                        if((j>>(l-1))&1){
                            //之前是等于的
                            if(curs[l]>curs[l-1]){fail=1;break;}
                            if(curs[l]==curs[l-1])newj|=(1<<(l-1));
                        }
                    }
                    if(fail)continue;
                    add(dp[i][newj],dp[i-1][j]);
                }
            }
        }
        for(int ed=0;ed<=sz;++ed){
            trans.a[st][ed]=dp[len][ed];
        }
    }
    trans=mat_pow(trans,K);
    Matrix res;
    res.a[0][sz]=1;
    res.sz=sz;
    res=res*trans;
    cout<<res.a[0][0]<<endl;
    return 0;
}