题解 P4472 【[BJWC2018]八维】

· · 题解

题目大意

给一个由 n\times m 的字符矩阵无限复制得到的矩阵。等概率随机选择一个起点,并等概率随机选择一个方向(上、下、左、右、左上、左下、右上、右下),将这个方向上 k 个字符组成一个字符串。询问这样选出来的两个字符串相同的概率。

## 题解 我们设 $S_{i,j}$ 表示第 $i+1$ 行第 $j+1$ 列的字符。由无限复制的定义可以得到, $S_{i,j}=S_{i\bmod n,j\bmod m}$ 。 考虑计算出每个起点、每个方向上组成的字符串的哈希值。这里我们采用如下哈希方式: $$H(X)=\sum_{i=0}^{len} X_iP^{len-i-1} \pmod{2^{64}}$$ 其中 $X$ 是一个长度为 $len$ ,下标范围是 $[0,len)$ 的字符串。 如果直接暴力计算所有的字符串,这样的时间复杂度是 $\mathcal O(nmk)$ ,显然是不行的。但事实上,我们可以找到 $X$ 的循环节。可以证明, $X$ 的循环节长度不会超过 $nm$ (考虑 $X$ 的下标的模意义)。 假定有一个循环节是 $\verb!"aab"!$ 的字符串,它的长度为 $11$ 。 $$X=\verb![(aab)(aab)(aab)aa]!$$ 我们可以计算出它循环节的哈希值(设为 $H_0$ ,它的长度为 $L_0$ ),那么完整的循环部分的哈希值应该为: $$\begin{aligned}H(\verb!aabaabaab!)&=H_0\times \left(P^3\right)^2+H_0\times \left(P^3\right)^1+H_0\times \left(P^3\right)^0\cr &=H_0\left(\left(P^3\right)^2+\left(P^3\right)^1+\left(P^3\right)^0\right) \end{aligned}$$ 显然,括号里的是一个首项为 $1$ ,公比为 $P^3$ 的等比数列。对于一个等比数列,我们有如下公式: $$q^0+q^1+\cdots +q^{n-1}=\frac{q^n-1}{q-1}$$ 但是由于模数是 $2^{64}$ ,而 $q-1=P^{L_0}-1$ 是不存在乘法逆元的。考虑使用分治法。 假设我们要计算 $F(q,n)=q^0+q^1+\cdots +q^{n-1}$ ,那么有: $$F(q,n)=\begin{cases} 0 & n=0 \cr 1 & n=1 \cr (1+q)\times F(q^2,\lfloor n\div2\rfloor) & n\equiv 0\pmod 2 \cr (1+q)\times F(q^2,\lfloor n\div2\rfloor)\times q+1 & n\equiv 1\pmod 2 \cr \end{cases}$$ 这么做的复杂度是 $\mathcal O(\log k)$ 。 对于 $H(X)$ 的后半部分,我们可以直接暴力。这样的复杂度不超过 $\mathcal O(L_0)$ 。但由于 $L_0$ 可能达到 $n\times m$ ,于是总复杂度为 $\mathcal O(n^2m^2)$ 。显然,我们还要进一步优化。 不妨以样例二为例: $$\begin{bmatrix} \tt\color{grey}\textcolor{black}{b\ a\ n}\ b\ a\ n\ b\ a\ n\ b\ a\ n \cr \tt\color{grey}\textcolor{black}{a}\ \underlinesegment{\textcolor{black}{n\ a}\ a\ n\ a\ a\ n\ a\ a\ n}\ a \cr \tt\color{grey}n\ a\ b\ n\ a\ b\ n\ a\ b\ b\ a\ n \cr \tt\color{grey}b\ a\ n\ b\ a\ n\ b\ a\ n\ a\ n\ a \cr \tt\color{grey}a\ n\ a\ a\ n\ a\ a\ n\ a\ b\ a\ n \cr \tt\color{grey}n\ a\ b\ n\ a\ b\ n\ a\ b\ a\ n\ a \cr \end{bmatrix}$$ 假如我们已经计算出了 $X_0=\verb![naanaanaan]!$ 这一段的哈希值,现在要计算它右边的一个字符串 $X_1=\verb![aanaanaana]!$ 。我们能够发现,它们共用一个循环节。它的哈希值就是把 $H(X_0)$ 乘上 $P$ 、减去最左侧的字母 $\verb!n!$ 的贡献、再加上最右侧 $\verb!a!$ 的贡献。 也就是说,只要我们计算出任意一个长度为 $k$ 的字符串的哈希值,就能 $\mathcal O(\log k)$ 转移到将它向右移动一位的字符串的哈希值。但事实上,我们可以预处理 $P^k$ ,那么单次转移的复杂度就降为了 $\mathcal O(1)$ 。 计算一个长度为 $k$ 的字符串的哈希值的复杂度是 $\mathcal O(L_0\log k)$ ,我们再花费 $\mathcal O(L_0)$ 的复杂度将共用这个循环节的其他的长度为 $k$ 的字符串的哈希值求出来。于是平均每个字符串的复杂度是 $\mathcal O(\log k)$ 。又因为一共只有 $n\times m\times 8$ 个需要计算的字符串,于是总复杂度为 $\mathcal O(nm\log k)$ ,可以通过本题。 另外,我们需要开一个哈希表将哈希值相同的字符串插入进去,然后统计相同哈希值的字符串有多少个,计算出共有多少种选法使得两个字符串哈希值相同。除以总方案数 $n\times m\times 8$ 就行了。 实测跑的飞起( ## 参考代码 ```cpp #include<bits/stdc++.h> #define up(l,r,i) for(int i=l,END##i=r;i<=END##i;++i) #define dn(r,l,i) for(int i=r,END##i=l;i>=END##i;--i) using namespace std; typedef long long i64; typedef unsigned int u32; typedef unsigned long long u64; const int INF =2147483647; int qread(){ int w=1,c,ret; while((c=getchar())> '9'||c< '0') w=(c=='-'?-1:1); ret=c-'0'; while((c=getchar())>='0'&&c<='9') ret=ret*10+c-'0'; return ret*w; } const int MAXN=500+3,P=13331; char S[MAXN][MAXN],T[MAXN*MAXN]; u64 h; bool vis[MAXN][MAXN]; int n,m,p,s,t; i64 ans1,ans2; const int dir[8][2]={{1,0},{0,1},{-1,0},{0,-1},{1,1},{1,-1},{-1,1},{-1,-1}}; u64 pwr(u64 a,u64 b){ u64 r=1; while(b){if(b&1) r=r*a; a=a*a,b>>=1;} return r; } const int SIZ =999997,MAXM=MAXN*MAXN*8; int head[SIZ],val[MAXM],nxt[MAXM],tot; u64 ver[MAXM]; void add(int u,u64 v){ ver[++tot]=v,nxt[tot]=head[u],val[tot]=1,head[u]=tot; } void inc(u64 h){ for(int p=head[h%SIZ];p;p=nxt[p]) if(ver[p]==h){ans1+=2ll*val[p]+1,++val[p];return;} add(h%SIZ,h),++ans1; } u64 calc(u64 a,u64 b){ //calc a^0+a^1+a^2+...+a^(b-1) if(b==0) return 0; if(b==1) return 1; if(b&1) return (1ull+a)*calc(a*a,b>>1)*a+1; else return (1ull+a)*calc(a*a,b>>1); } int main(){ n=qread(),m=qread(),p=qread(); u64 q=pwr(P,p); up(0,n-1,i) scanf("%s",S[i]); up(0,7,d){ up(0,n-1,i) up(0,m-1,j) if(!vis[i][j]){ int a=i,b=j; h=0; while(!vis[a][b]){ h=h*P+S[a][b],T[t++]=S[a][b],vis[a][b]=true; a=(n+a+dir[d][0])%n,b=(m+b+dir[d][1])%m; } int x=p%t; h*=calc(pwr(P,t),p/t); up(0,x-1,k) h=h*P+T[k]; up(0,t-1,k) inc(h),h=h*P+T[(x+k)%t]-T[k]*q; t=0; } up(0,n-1,i) up(0,m-1,j) vis[i][j]=0; } ans2=1ll*n*m*8*(1ll*n*m*8); i64 d=__gcd(ans1,ans2); printf("%lld/%lld\n",ans1/d,ans2/d); return 0; } ```