题解:P7538 [COCI2016-2017#4] Osmosmjerka

· · 题解

P7538 [COCI2016-2017#4] Osmosmjerka

R188752538 记录详情

Main Idea

给定一个字符矩阵,我们将这个矩阵无限复制,形成一个无限大的矩阵。然后,从矩阵中的任意位置出发,可以沿着八个方向(上下左右以及四个斜对角方向)选取连续的 K 个字符,形成一个字符串。题目要求计算,随机选择两个位置和两个方向,得到两个相同的字符串的概率。

Solution

枚举起点 ij 和八个方向 d,于是考虑计算哈希值,由于 2 \leq K \leq 10^9,数据太大了,于是可以考虑倍增计算哈希值。

我们发现对于每个点,只有八个方向,最终能得到 8 \times n \times m 个字符串,不是很多,那我们可以考虑求出这些字符串的哈希值,相同的哈希值代表选到相同字符串的一种可能,直接统计即可。

但是,模数取 998244353 时被卡掉一个点。

你可以选择:

  1. 使用双哈希。
  2. 使用一些不那么有名的模数。
  3. 自然溢出

优化

当你做完后,开心地提交后,发现错了,为什么呢?由于空间复杂度是 O(8 \times n ^ 2 \times \log_{2}{K}) 内存超限了,怎么办呢?

肯定是用类似滚动数组的方式来优化空间复杂度,由于八个方向之间的计算互不影响,所以我们分开算八次,优化复杂度为 O(n ^ 2 \times \log_{2}{K}) 刚好可以过,具体实现见代码。

当你改完后,开心地提交后,发现还是错了,为什么呢?

注意,如果哈希的时候取模,是很慢。可以用自然溢出或不要再同一行代码取太多了模。

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define INF 0x7ffffff

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9')
    {
        if(c == '-')
        {
            f = -1;
        }
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        x = (x << 1) + (x << 3) + c - '0';
        c = getchar();
    }
    return x * f;
}

const ll mod = 998244853;

char c[505][505];
int st[505][505][50];
ll poww[100];
ll pow2[100];
ll ans[20000010], top;

int dx[10] = {0, 0, 1, -1, 1, 1, -1, -1};
int dy[10] = {1, -1, 0, 0, 1, -1, 1, -1};

ll gcd(ll a, ll b)
{
    return !b ? a : gcd(b, a % b);
}

int main()
{
    poww[0] = 133;
    for(int i = 1; i <= 35; i++)
    {
        poww[i] = (poww[i - 1] * poww[i - 1]) % mod;
    }
    pow2[0] = 1;
    for(int i = 1; i <= 35; i++)
    {
        pow2[i] = pow2[i - 1] * 2;
    }
    int n = read(), m = read(), kkk = read();
    for(int i = 0; i < n; i++)
        scanf("%s", c[i]);
    for(int i = 0; i < n; i++)
    {
        for(int j = 0; j < m; j++)
        {
            st[i][j][0] = c[i][j] - 'a' + 1;
        }
    }
    for(int k = 0; k < 8; k++)//分开算八次
    {
        for(int l = 1; l <= 30; l++)
        {
            for(int i = 0; i < n; i++)
            {
                for(int j = 0; j < m; j++)
                {
                    int x = (i + dx[k] * (pow2[l - 1])) % n;
                    int y = (j + dy[k] * (pow2[l - 1])) % m;
                    x = (x + n) % n;
                    y = (y + m) % m;
                    st[i][j][l] = (st[i][j][l - 1] * poww[l - 1] + st[x][y][l - 1]) % mod;
                }
            }
        }
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
            {
                ll Hash = 0;
                int x = i, y = j;
                for(int l = 30; l >= 0; l--)
                {
                    if(kkk & (1 << l))
                    {
                        Hash = Hash * poww[l] + st[x][y][l];
                        x = (x + pow2[l] * dx[k]) % n;
                        y = (y + pow2[l] * dy[k]) % m;
                        x = (x + n) % n;
                        y = (y + m) % m;
                    }
                }
                ans[++top] = Hash;
            }
        }
    }
    sort(ans + 1, ans + top + 1);
    ll cnt = 0;
    ll ans1 = 0;
    ll ans2 = top * top;
    for(int i = 1; i <= top; i++)
    {
        if(i > 1 && ans[i] == ans[i - 1])
        {
            cnt++;
        }
        else
        {
            ans1 += 1LL * cnt * cnt;
            cnt = 1;
        }
    }
    ans1 += cnt * cnt;
    printf("%lld/%lld", ans1 / gcd(ans1, ans2), ans2 / gcd(ans1, ans2));
    return 0;
}