[KOI 2021 Round 2] 最长公共括号子串 TJ

· · 题解

题面

给定两个括号字符串 A 和 B,求同时是这两个字符串的子串的合法括号串的最长长度。有多组数据,数据组数 T 最大可达 5 \times 10^5。

思路

把 "(" 记为 +1,")" 记为 -1,令 pre 为 A 的前缀和,其中 pre_0 = 0:

pre_i = \begin{cases} pre_{i-1}+1 & A_i = \texttt{"("} \\ pre_{i-1}-1 & A_i = \texttt{")"} \end{cases}

那么如果 A 中的子串 [l,r] 满足 \min\limits^{r}_{i=l} pre_i \ge pre_{l-1} 且 pre_r = pre_{l-1},那么它就是合法括号串。

用 ST 表维护 pre 即可做到 \mathcal{O}(1) 判断合法。

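于是我们可以枚举子串在 A 中的左端点 l,先在 [l,l+blcp_l-1] 范围内二分求出最后一个满足 \min\limits^{r}_{i=l} pre_i \ge pre_{l-1} 的位置 r_0(blcp_i 为同时是两个字符串的子串且在 A 中左端点为 i 的字符串的最大长度);这个区间最小值随 r 增大单调不增,所以可以二分。然后在 [l,r_0] 内再二分求出最后一处满足 pre_r = pre_{l-1} 的位置并更新答案:由于 [l,r_0] 内的 pre 都不小于 pre_{l-1},条件 \min\limits^{r_0}_{i=x} pre_i = pre_{l-1} 关于 x 是单调的,同样可以二分。然后就做完了。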

目前问题在于如何求出 blcp,考虑后缀数组。

非常套路地将 A 和 B 连在一起,中间随便插一个除括号外的字符,然后跑后缀数组,再分别从左往右、从右往左扫一遍 sa。对 sa 中排名为 x 的 A 的后缀,从左往右扫时它的 blcp 即为 \min\limits^{x}_{i=a+1} height_i(a 为 sa 中在 x 左侧且离 x 最近的 B 的后缀的排名),从右往左扫时同理,两个方向取较大值。代码大概长这样(n 为连起来后的长度,m 为 A 的长度):

// Forward sweep over the suffix array. height[i] is the common-prefix length of
// the suffixes sa[i-1] and sa[i], so minn (the running minimum of height since
// the last suffix starting after position m+1) is the common-prefix length with
// that suffix. j becomes 1 once such a suffix has been seen.
for (int i = 1,j = 0,minn = 0x3f3f3f3f;i <= n;i++){
    // sa[i] > m+1: the suffix starts after the separator, i.e. inside B.
    if (sa[i] > m+1) minn = 0x3f3f3f3f,j = 1;
    else{
        minn = min(minn,height[i]);
        // sa[i] <= m: the suffix starts inside A (position m+1 is the separator).
        if (j && sa[i] <= m) blcp[sa[i]] = minn;
    }
}
// Backward sweep: same idea using the nearest B suffix on the other side
// (height[i+1] links sa[i] and sa[i+1]); the larger of the two values is kept.
for (int i = n,j = 0,minn = 0x3f3f3f3f;i;i--){
    if (sa[i] > m+1) minn = 0x3f3f3f3f,j = 1;
    else{
        minn = min(minn,height[i+1]);
        if (j && sa[i] <= m) blcp[sa[i]] = max(blcp[sa[i]],minn);
    }
}

于是这道题就做完了

注意事项

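下面这些都是多测而且测试用例还这么多(T 可达 5 \times 10^5)惹出来的:每组数据结束后只能清空本次用到的那一段数组,不要对整个数组 memset;另外当连接后的长度 n 较小(代码中取 n \le 800)时,直接暴力排序求后缀数组。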

Code

#include <bits/stdc++.h>
#define maxn 2000005 
using namespace std;
// File-scope working state shared by every test case.
//   a, b - the two input strings read in solve()
//   str  - ' ' + a + '#' + b, built in solve(); its characters are addressed from index 1
string str,a,b;
//   sa, rnk         - sa[i] is the start of the i-th smallest suffix of str, rnk is its inverse
//                     (filled by initSA / initSABao)
//   lg              - lg[i] = floor(log2(i)), filled once in main()
//   st              - sparse table of range minima over pre (built by initST, queried by find)
//   lastSa, lastRnk - scratch arrays for the doubling rounds in initSA
//   pre             - running bracket balance over a: +1 for '(' and -1 for any other character
//   height          - height[i] = common-prefix length of the suffixes sa[i-1] and sa[i]
//   blcp            - blcp[i] for a position i of a: the larger of the common-prefix lengths with
//                     the nearest suffix of b on either side in sa order (0 if none was seen)
//   n, m            - n = str.size()-1 (combined length), m = a.size()
//   cn, cnt         - number of distinct ranks and the counting-sort buckets used by initSA
//   T               - number of test cases
//   id              - initial rank of a character: '(' = 1, ')' = 2, '#' = 3
int sa[maxn],rnk[maxn],lg[maxn],st[maxn][25],lastSa[maxn],pre[maxn],lastRnk[maxn],height[maxn],blcp[maxn],n,m,cn,cnt[maxn],T,id[256];
//   l, r, mid - binary-search bounds used in solve(); ans, maxx - per-start and overall results
int l,r,mid,ans,maxx;
// Reset the counting-sort buckets cnt[1..n].
// Only the first n entries are touched on purpose: do not memset the whole
// array, because this runs for every test case and every doubling round.
void clear(){
    std::fill_n(cnt+1,n,0);
}
// Brute-force suffix array: sort the start positions 1..n of str by comparing
// the suffixes character by character, then fill rnk as the inverse of sa.
void initSABao(){
    for (int p = n;p >= 1;--p) sa[p] = p;
    // Strict ordering of two suffixes given by their start positions.
    auto suffixLess = [&](int x,int y) -> bool {
        int p = x,q = y;
        while (p <= n && q <= n){
            if (str[p] != str[q]) return str[p] < str[q];
            ++p,++q;
        }
        // One suffix is a prefix of the other: the shorter one (the one that
        // starts later) sorts first.
        return x > y;
    };
    sort(sa+1,sa+n+1,suffixLess);
    for (int k = 1;k <= n;++k) rnk[sa[k]] = k;
}
// Build the suffix array sa[1..n] and rank array rnk[1..n] of str.
//
// Inputs with n <= 800 are handled by the brute-force sort in initSABao().
// Larger inputs use prefix doubling with a counting sort per round.
//
// Preconditions (established by solve()): cn == 0 and cnt[1..n] == 0 on entry.
//
// Notes on the multi-test setting:
//  * Only indices 1..n of the rank arrays are ever read or written here.
//    Positions past the end of the string are treated as rank 0 explicitly
//    instead of being read from the arrays, so leftovers from an earlier,
//    longer test case cannot make two different suffixes compare equal, and
//    no index above n is touched.
//  * The previous ranks are copied for 1..n only; swapping the two whole
//    arrays would cost O(maxn) per round regardless of n.
void initSA(){
    if (n <= 800){// small inputs: brute force
        initSABao();
        return;
    }
    // Round 0: rank by the first character. id[] has 256 slots, so index it
    // with the unsigned value of the character.
    for (int i = 1;i <= n;i++){
        rnk[i] = id[static_cast<unsigned char>(str[i])];
        cn = max(cn,rnk[i]);
    }
    for (int i = 1;i <= n;i++) ++cnt[rnk[i]];
    for (int i = 1;i <= cn;i++) cnt[i] += cnt[i-1];
    for (int i = n;i;i--) sa[cnt[rnk[i]]--] = i;
    for (int w = 1;w < n;w <<= 1){
        // lastSa: suffixes ordered by their second half. Those with no second
        // half (start > n-w) come first, the rest follow the current sa order.
        int cur = 0;
        for (int i = n-w+1;i <= n;i++) lastSa[++cur] = i;
        for (int i = 1;i <= n;i++){
            if (sa[i] > w) lastSa[++cur] = sa[i]-w;
        }
        // Stable counting sort of lastSa by the rank of the first half.
        clear();
        for (int i = 1;i <= n;i++) cnt[rnk[lastSa[i]]]++;
        for (int i = 1;i <= cn;i++) cnt[i] += cnt[i-1];
        for (int i = n;i;i--) sa[cnt[rnk[lastSa[i]]]--] = lastSa[i];
        // Keep the previous ranks of positions 1..n; rnk is rebuilt below.
        for (int i = 1;i <= n;i++) lastRnk[i] = rnk[i];
        cn = 0;
        for (int i = 1;i <= n;i++){
            const int curPos = sa[i];
            bool same = false;
            if (i > 1){
                const int prevPos = sa[i-1];
                // Real ranks are >= 1, so 0 stands for "past the end".
                const int curSecond = curPos+w <= n ? lastRnk[curPos+w] : 0;
                const int prevSecond = prevPos+w <= n ? lastRnk[prevPos+w] : 0;
                same = lastRnk[curPos] == lastRnk[prevPos] && curSecond == prevSecond;
            }
            rnk[curPos] = same ? cn : ++cn;
        }
        if (cn == n) break;// every suffix already has a distinct rank
    }
}
// Fill height[] from sa/rnk, then derive blcp[] for every position of a.
//
// height[rnk[p]] is the number of leading characters the suffix at p shares
// with the suffix ranked immediately before it. The match length is carried
// from one position to the next, reduced by one.
//
// blcp is obtained with two sweeps over the suffix array (one per direction):
// a running minimum of height is restarted at every suffix that starts after
// position m+1, and recorded for suffixes that start at or before m once such
// a suffix has been seen. The larger of the two sweeps is kept.
void initHeight(){
    const int kInf = 0x3f3f3f3f;
    int matched = 0;
    for (int pos = 1;pos <= n;++pos){
        if (matched > 0) --matched;
        const int prev = sa[rnk[pos]-1];
        while (str[pos+matched] == str[prev+matched]) ++matched;
        height[rnk[pos]] = matched;
    }
    // Forward sweep.
    {
        bool seenB = false;
        int run = kInf;
        for (int k = 1;k <= n;++k){
            const int start = sa[k];
            if (start > m+1){
                run = kInf;
                seenB = true;
                continue;
            }
            run = min(run,height[k]);
            if (seenB && start <= m) blcp[start] = run;
        }
    }
    // Backward sweep: height[k+1] links sa[k] with sa[k+1].
    {
        bool seenB = false;
        int run = kInf;
        for (int k = n;k >= 1;--k){
            const int start = sa[k];
            if (start > m+1){
                run = kInf;
                seenB = true;
                continue;
            }
            run = min(run,height[k+1]);
            if (seenB && start <= m) blcp[start] = max(blcp[start],run);
        }
    }
}
// Build the bracket-balance prefix sums pre[0..m] over the first m characters
// of str ('(' counts +1, anything else -1) and a sparse table of range minima
// over them: st[i][k] = min(pre[i .. i+2^k-1]).
//
// Only windows that lie entirely inside [1,m] are built. Level-0 entries past
// m are not written by this function, so they may still hold values from an
// earlier test case; bounding the loop by the combined length n would fold
// those into the table and do extra work for entries find() never needs when
// it is called with 1 <= l <= r <= m.
void initST(){
    pre[0] = st[0][0] = 0;
    for (int i = 1;i <= m;i++) pre[i] = st[i][0] = st[i-1][0]+(str[i] == '(' ? 1 : -1);
    for (int k = 1;k <= lg[m];k++){
        const int half = 1 << (k-1);
        for (int i = 1;i+2*half-1 <= m;i++) st[i][k] = min(st[i][k-1],st[i+half][k-1]);
    }
}
// Range-minimum query on the sparse table: returns the smallest st[.][0] value
// over positions l..r (inclusive), using two overlapping power-of-two windows.
int find(int l,int r){
    const int level = lg[r-l+1];
    const int leftMin = st[l][level];
    const int rightMin = st[r-(1<<level)+1][level];
    return rightMin < leftMin ? rightMin : leftMin;
}
// Handle one test case: read a and b, build the suffix array, height/blcp and
// the sparse table over the combined string, then for every start position in
// a run two binary searches and print the best length found.
void solve(){
    cin >> a >> b;
    str = ' ' + a + '#' + b;
    n = str.size()-1;
    m = a.size();
    cn = 0;
    initSA();
    initHeight();
    initST();
    maxx = 0;
    for (int start = 1;start <= m;++start){
        const int base = pre[start-1];
        ans = 0;
        // First search: last end position within the common part whose range
        // minimum (from start) does not drop below base.
        for (l = start,r = blcp[start]+start-1;l <= r;){
            mid = l + ((r - l) >> 1);
            if (find(start,mid) < base) r = mid - 1;
            else ans = mid,l = mid + 1;
        }
        // Second search: last position up to ans where the balance returns to
        // base; that position closes the longest valid string for this start.
        for (l = start,r = ans;l <= r;){
            mid = l + ((r - l) >> 1);
            if (find(mid,ans) != base) r = mid - 1;
            else{
                maxx = max(maxx,mid-start+1);
                l = mid + 1;
            }
        }
    }
    cout << maxx << "\n";
    // Reset only what this test case used.
    for (int i = m;i >= 1;--i) blcp[i] = 0;
    clear();
}
// Entry point: set up fast I/O, precompute the floor(log2) table and the
// character-to-rank map, then process T test cases.
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    const int limit = maxn-5;
    for (int v = 2;v <= limit;++v) lg[v] = lg[v>>1]+1;
    id['('] = 1;
    id[')'] = 2;
    id['#'] = 3;
    cin >> T;
    for (;T--;) solve();
    return 0;
}