题解:P13871 [蓝桥杯 2024 省 Java/Python A] 吊坠

· · 题解

题意

给定 n 个串 s_{1\sim n},有一个完全无向图,点 i 到点 j 的距离为 \operatorname{lcp}(s_i,s_j),其中 s_is_j 均可以循环移动若干次。询问这个无向图的最大生成树。

思路

先考虑用破环成链处理循环移动的情况,然后暴力建图并直接跑最大生成树。我们发现,暴力计算一条边权的时间复杂度是 O(m^3),所以朴素实现的复杂度是 O(n^2m^3)。考虑用字符串哈希优化这个东西。枚举两个串循环移动的步数 i,j,并单次 O(\log m) 询问 lcp,这样可以做到 O(n^2m^2\log m)。不过还能优化,考虑到边权一定小于等于 m,所以可以直接把这两个串复制一遍拼到各自串的末尾,然后求这两个串小于等于 m 的最长公共子串的长度,这是等价于上面算的那个东西的。这样计算边权的时间复杂度被优化为 O(m\log m),总时间复杂度 O(n^2m\log m),可以通过。

Code

注意 Py 代码一定要用 PyPy 提交,不然会 T 飞!

比 cpp 跑得慢了两百倍

N = 210
M = 60
mod1 = 1145141
mod2 = 10**9 + 7
B = 131
s = [['' for _ in range(2 * M)] for _ in range(N)]
n = m = 0
hs1 = [[0 for _ in range(2 * M)] for _ in range(N)]
hs2 = [[0 for _ in range(2 * M)] for _ in range(N)]
pw1 = [0 for _ in range(2 * M + 1)]
pw2 = [0 for _ in range(2 * M + 1)]

pw1[0] = pw2[0] = 1

for i in range(1, 2 * M + 1):
    pw1[i] = pw1[i - 1] * B % mod1
    pw2[i] = pw2[i - 1] * B % mod2

def f1(x: int, l: int, r: int) -> int:
    return (hs1[x][r] - hs1[x][l - 1] * pw1[r - l + 1] % mod1 + mod1) % mod1

def f2(x: int, l: int, r: int) -> int:
    return (hs2[x][r] - hs2[x][l - 1] * pw2[r - l + 1] % mod2 + mod2) % mod2

def check(x: int, y: int, mid: int) -> bool:
    if mid == 0:
        return True
    ss = set()
    for i in range(2 * m - mid):
        ss.add((f1(x, i, i + mid - 1), f2(x, i, i + mid - 1)))
    for i in range(2 * m - mid):
        if (f1(y, i, i + mid - 1), f2(y, i, i + mid - 1)) in ss:
            return True
    return False
def lcp(x: int, y: int) -> int:
    l = 0
    r = m
    while l < r:
        mid = (l + r + 1) // 2
        if check(x, y, mid):
            l = mid
        else:
            r = mid - 1
    return l

f = [i for i in range(N)]

def find(x: int) -> int:
    if x != f[x]:
        fx = find(f[x])
        f[x] = fx
        return fx
    return x

if __name__ == '__main__':
    n, m = map(int, input().split())
    for i in range(n):
        ss = input()
        for j in range(m):
            s[i][j] = s[i][j + m] = ss[j]
    for i in range(n):
        for j in range(2 * m):
            hs1[i][j] = (hs1[i][j - 1] * B % mod1 + ord(s[i][j])) % mod1
            hs2[i][j] = (hs2[i][j - 1] * B % mod2 + ord(s[i][j])) % mod2
    e = []
    for i in range(n):
        for j in range(i + 1, n):
            e.append([i, j, lcp(i, j)])
    e.sort(key = lambda x : x[2], reverse = True)
    ans = 0
    for i in e:
        if find(i[0]) == find(i[1]):
            continue
        ans += i[2]
        f[find(i[0])] = find(i[1])
    print(ans)

cpp 选手写 py 真快炸了

管理大大求过!qwq