题解 P4472 【[BJWC2018]八维】
囧仙
·
·
题解
题目大意
给一个由 n\times m 的字符矩阵无限复制得到的矩阵。等概率随机选择一个起点,并等概率随机选择一个方向(上、下、左、右、左上、左下、右上、右下),将这个方向上 k 个字符组成一个字符串。询问这样选出来的两个字符串相同的概率。
## 题解
我们设 $S_{i,j}$ 表示第 $i+1$ 行第 $j+1$ 列的字符。由无限复制的定义可以得到, $S_{i,j}=S_{i\bmod n,j\bmod m}$ 。
考虑计算出每个起点、每个方向上组成的字符串的哈希值。这里我们采用如下哈希方式:
$$H(X)=\sum_{i=0}^{len} X_iP^{len-i-1} \pmod{2^{64}}$$
其中 $X$ 是一个长度为 $len$ ,下标范围是 $[0,len)$ 的字符串。
如果直接暴力计算所有的字符串,这样的时间复杂度是 $\mathcal O(nmk)$ ,显然是不行的。但事实上,我们可以找到 $X$ 的循环节。可以证明, $X$ 的循环节长度不会超过 $nm$ (考虑 $X$ 的下标的模意义)。
假定有一个循环节是 $\verb!"aab"!$ 的字符串,它的长度为 $11$ 。
$$X=\verb![(aab)(aab)(aab)aa]!$$
我们可以计算出它循环节的哈希值(设为 $H_0$ ,它的长度为 $L_0$ ),那么完整的循环部分的哈希值应该为:
$$\begin{aligned}H(\verb!aabaabaab!)&=H_0\times \left(P^3\right)^2+H_0\times \left(P^3\right)^1+H_0\times \left(P^3\right)^0\cr
&=H_0\left(\left(P^3\right)^2+\left(P^3\right)^1+\left(P^3\right)^0\right)
\end{aligned}$$
显然,括号里的是一个首项为 $1$ ,公比为 $P^3$ 的等比数列。对于一个等比数列,我们有如下公式:
$$q^0+q^1+\cdots +q^{n-1}=\frac{q^n-1}{q-1}$$
但是由于模数是 $2^{64}$ ,而 $q-1=P^{L_0}-1$ 是不存在乘法逆元的。考虑使用分治法。
假设我们要计算 $F(q,n)=q^0+q^1+\cdots +q^{n-1}$ ,那么有:
$$F(q,n)=\begin{cases}
0 & n=0 \cr
1 & n=1 \cr
(1+q)\times F(q^2,\lfloor n\div2\rfloor) & n\equiv 0\pmod 2 \cr
(1+q)\times F(q^2,\lfloor n\div2\rfloor)\times q+1 & n\equiv 1\pmod 2 \cr
\end{cases}$$
这么做的复杂度是 $\mathcal O(\log k)$ 。
对于 $H(X)$ 的后半部分,我们可以直接暴力。这样的复杂度不超过 $\mathcal O(L_0)$ 。但由于 $L_0$ 可能达到 $n\times m$ ,于是总复杂度为 $\mathcal O(n^2m^2)$ 。显然,我们还要进一步优化。
不妨以样例二为例:
$$\begin{bmatrix}
\tt\color{grey}\textcolor{black}{b\ a\ n}\ b\ a\ n\ b\ a\ n\ b\ a\ n \cr
\tt\color{grey}\textcolor{black}{a}\ \underlinesegment{\textcolor{black}{n\ a}\ a\ n\ a\ a\ n\ a\ a\ n}\ a \cr
\tt\color{grey}n\ a\ b\ n\ a\ b\ n\ a\ b\ b\ a\ n \cr
\tt\color{grey}b\ a\ n\ b\ a\ n\ b\ a\ n\ a\ n\ a \cr
\tt\color{grey}a\ n\ a\ a\ n\ a\ a\ n\ a\ b\ a\ n \cr
\tt\color{grey}n\ a\ b\ n\ a\ b\ n\ a\ b\ a\ n\ a \cr
\end{bmatrix}$$
假如我们已经计算出了 $X_0=\verb![naanaanaan]!$ 这一段的哈希值,现在要计算它右边的一个字符串 $X_1=\verb![aanaanaana]!$ 。我们能够发现,它们共用一个循环节。它的哈希值就是把 $H(X_0)$ 乘上 $P$ 、减去最左侧的字母 $\verb!n!$ 的贡献、再加上最右侧 $\verb!a!$ 的贡献。
也就是说,只要我们计算出任意一个长度为 $k$ 的字符串的哈希值,就能 $\mathcal O(\log k)$ 转移到将它向右移动一位的字符串的哈希值。但事实上,我们可以预处理 $P^k$ ,那么单次转移的复杂度就降为了 $\mathcal O(1)$ 。
计算一个长度为 $k$ 的字符串的哈希值的复杂度是 $\mathcal O(L_0\log k)$ ,我们再花费 $\mathcal O(L_0)$ 的复杂度将共用这个循环节的其他的长度为 $k$ 的字符串的哈希值求出来。于是平均每个字符串的复杂度是 $\mathcal O(\log k)$ 。又因为一共只有 $n\times m\times 8$ 个需要计算的字符串,于是总复杂度为 $\mathcal O(nm\log k)$ ,可以通过本题。
另外,我们需要开一个哈希表将哈希值相同的字符串插入进去,然后统计相同哈希值的字符串有多少个,计算出共有多少种选法使得两个字符串哈希值相同。除以总方案数 $n\times m\times 8$ 就行了。
实测跑的飞起(
## 参考代码
```cpp
#include<bits/stdc++.h>
#define up(l,r,i) for(int i=l,END##i=r;i<=END##i;++i)
#define dn(r,l,i) for(int i=r,END##i=l;i>=END##i;--i)
using namespace std;
typedef long long i64;
typedef unsigned int u32;
typedef unsigned long long u64;
const int INF =2147483647;
int qread(){
int w=1,c,ret;
while((c=getchar())> '9'||c< '0') w=(c=='-'?-1:1); ret=c-'0';
while((c=getchar())>='0'&&c<='9') ret=ret*10+c-'0';
return ret*w;
}
const int MAXN=500+3,P=13331;
char S[MAXN][MAXN],T[MAXN*MAXN]; u64 h;
bool vis[MAXN][MAXN]; int n,m,p,s,t; i64 ans1,ans2;
const int dir[8][2]={{1,0},{0,1},{-1,0},{0,-1},{1,1},{1,-1},{-1,1},{-1,-1}};
u64 pwr(u64 a,u64 b){
u64 r=1; while(b){if(b&1) r=r*a; a=a*a,b>>=1;} return r;
}
const int SIZ =999997,MAXM=MAXN*MAXN*8;
int head[SIZ],val[MAXM],nxt[MAXM],tot; u64 ver[MAXM];
void add(int u,u64 v){
ver[++tot]=v,nxt[tot]=head[u],val[tot]=1,head[u]=tot;
}
void inc(u64 h){
for(int p=head[h%SIZ];p;p=nxt[p])
if(ver[p]==h){ans1+=2ll*val[p]+1,++val[p];return;}
add(h%SIZ,h),++ans1;
}
u64 calc(u64 a,u64 b){ //calc a^0+a^1+a^2+...+a^(b-1)
if(b==0) return 0; if(b==1) return 1;
if(b&1) return (1ull+a)*calc(a*a,b>>1)*a+1;
else return (1ull+a)*calc(a*a,b>>1);
}
int main(){
n=qread(),m=qread(),p=qread(); u64 q=pwr(P,p);
up(0,n-1,i) scanf("%s",S[i]);
up(0,7,d){
up(0,n-1,i) up(0,m-1,j) if(!vis[i][j]){
int a=i,b=j; h=0; while(!vis[a][b]){
h=h*P+S[a][b],T[t++]=S[a][b],vis[a][b]=true;
a=(n+a+dir[d][0])%n,b=(m+b+dir[d][1])%m;
}
int x=p%t; h*=calc(pwr(P,t),p/t);
up(0,x-1,k) h=h*P+T[k];
up(0,t-1,k) inc(h),h=h*P+T[(x+k)%t]-T[k]*q;
t=0;
}
up(0,n-1,i) up(0,m-1,j) vis[i][j]=0;
}
ans2=1ll*n*m*8*(1ll*n*m*8); i64 d=__gcd(ans1,ans2);
printf("%lld/%lld\n",ans1/d,ans2/d);
return 0;
}
```