题解 【AT3623 [ARC084D] XorShift】
command_block
2021-05-09 12:51:30
**题意** :有 $n$ 个数 $A_{i...n}$ 写在黑板上(以二进制形式给出),现在有两种可以执行无限次的操作:
当 $x$ 在黑板上时把 $2x$ 也写在黑板上。
当 $x$ 和 $y$ 都在黑板上时,把 $x{\ \rm xor\ }y$ 写在黑板上。(选择的 $x,y$ 可以相同)
求最终有多少个 $≤M$ 的数能被写在黑板上。答案对 $998244353$ 取模。
$n≤6$, $M,A_i≤2^{4000}$ ,时限$\texttt{2s}$。
------------
将给出的数看做 $F_2$ 下的多项式。
两个操作分别对应 “乘 $x$” 和 “两个多项式相加/减” 。
利用这两个操作,可以构造出多项式 $F,G$ 的 $\gcd$。
具体地,求 $\gcd(F,G)$ 时,不妨设 ${\rm deg}(F)\geq {\rm deg}(G)$ ,记 $P=G*x^{{\rm deg}(F)-{\rm deg}(G)}$。
则有 $\gcd(F,G)=\gcd(F-P,G)$。
显然,这样的转化每次都会使得 $F$ 的次数减少,故(当 $F,G$ 其中一者为 $0$ 时)能求得答案。
记 $D$ 为 $A_{1...n}$ 的 $\gcd$ 。不难发现,题目中的操作只能构造出所有 $D$ 的倍数。
- 于是问题转化为 :
有多少个多项式 $P$ 满足下列条件 :
- $P$ 是 $D$ 的倍数。
- $P$ 转化为二进制数后 $\leq M$。
根据数位 $\rm DP$ 的套路,可以将第二个条件转化成若干如下不相交的约束 :
- 钦定前面的若干高位,其余低位随意。
记 $d={\rm deg}(D),m={\rm deg}(M)$ 。若 $d>m$ 则答案为 $1$ (仅有 $0$ 符合条件)。
不难证明,若我们钦定了 $P$ 的高位,只剩下 $d-1$ 个低位没有钦定,那么有唯一一种方法填写低位使得其是 $D$ 的倍数。
- 具体方法 :记被钦定的部分为 $P_0$ ,在低位填上 $P_0\bmod D$ 。
于是,若询问中钦定了前 $k$ 个高位,则 :
- 若 $m-k\geq d-1$ ,方案数为 $2^{m-k-d+1}$。
- 若 $m-k<d-1$ ,这些询问的前 $m-d+1$ 位都是相同的(与 $M$ 相同),只需处理一次。(补全后判定是否 $\leq M$)
利用 `bitset` ,$F_2$ 下的多项式 $\rm gcd$ ,多项式取模的复杂度均可以做到 $O({\rm deg}^2/\omega)$。
于是总复杂度为 $O(nm^2/\omega)$。
```cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<bitset>
#define MaxN 4050
using namespace std;
const int mod=998244353;
struct Poly{bitset<MaxN> f;int m;};
Poly gcd(Poly A,Poly B)
{
while(A.f.any()&&B.f.any()){
if (A.m>B.m){A.f^=B.f<<(A.m-B.m);while(A.m&&!A.f[A.m-1])A.m--;}
else {B.f^=A.f<<(B.m-A.m);while(B.m&&!B.f[B.m-1])B.m--;}
}return A.f.any() ? A : B;
}
Poly fmod(Poly A,Poly B)
{
while(A.m>=B.m){
A.f^=B.f<<(A.m-B.m);
while(A.m&&!A.f[A.m-1])A.m--;
}return A;
}
char s[MaxN],m[MaxN];
Poly d,sav;
int n,pw2[MaxN];
int main()
{
scanf("%d%s",&n,m);
for (int i=1;i<=n;i++){
scanf("%s",s);int len=strlen(s);
sav.f.reset();sav.m=len;
for (int j=0;j<len;j++)
sav.f[len-j-1]=(s[j]=='1');
d=gcd(d,sav);
}
int len=strlen(m),ans=0;
if (len<d.m){puts("1");return 0;}
pw2[0]=1;for (int i=1;i<=len;i++)pw2[i]=(pw2[i-1]<<1)%mod;
for (int k=0;len-k-d.m>=0;k++)
if (m[k]=='1')
ans=(ans+pw2[len-k-d.m])%mod;
sav.f.reset();sav.m=len;
for (int j=d.m-1;j<len;j++)
sav.f[j]=(m[len-j-1]=='1');
sav=fmod(sav,d);
bool fl=1;
for (int i=d.m-2;i>=0;i--)
if (sav.f[i]>(m[len-i-1]=='1')){fl=0;break;}
else if (sav.f[i]<(m[len-i-1]=='1')){fl=1;break;}
ans+=fl;
printf("%d",(ans+mod)%mod);
return 0;
}
```