题解:P7361 「JZOI-1」拜神

· · 题解

可能更好的阅读体验

题目很好,考察了 SA 中的并查集思想以及启发式合并,而且实现也不难,值得一做。

形式化题面如下:

给定一个长为 n 的字符串,询问次数为 q,多次询问区间 [l,r] 内最长重复子串的长度。

没有形式化题面感觉都想不出来怎么做 www。

肯定没有那么菜啦,首先考虑二分长度,问题转化为区间内是否存在一个长为 mid 的最长重复子串。

接下来我们考虑这个最长重复子串怎么求,一个比较明显的想法就是后缀数组的 LCP 功能,原命题询问的实质就是是否存在 i,j \in [l,r-mid+1],\operatorname{LCP}(i,j)\ge mid。看到后面这个式子,回忆起品酒大会的思路:从大到小将 Height 数组插入,若仅考虑 \ge L 的 Height,将 sa_{i-1},sa_{i} 之间连边,那么若 p,q 在同一联通块里,表明 \operatorname{LCP}(p,q)\ge L。我们通过并查集和启发式合并就可以做到 O(\log n) 的优秀复杂度啦。

但是有点问题啊,如果我们直接这么做我们并没有考虑区间位置,也就是说在两个联通块启发式合并的时候我们必须要记录区间的位置。我们不妨考虑对于联通块内每一个位置,我们维护它在当前联通块内上一个元素的位置,记作 pre_{i},那么区间限制转化为 \max\limits_{i\in set(L),i\in [l,r-L+1]} pre_{i}\ge l。我们可以通过对每一个联通块开主席树来辅助查询,这样就能够做到优秀的 O(q \log^2 n) 的查询啦,其中两个 \log 由二分和主席树查询贡献。

问题转化为如何维护 pre 的合并。首先,唯一确定一个联通块的信息就是所对应的 LCP 长度 L(具体见上面品酒大会思路),根据品酒大会启发式合并的思路,一次启发式 pre 的变化最多只有 O(\log n) 个,考虑用 set 把联通块内的元素存下来,启发式合并的时候暴力单点修改 pre,这样处理的复杂度是 O(n \log^2 n) 的,可以过。故总时间复杂度为 O(q\log^2 n + n \log^2 n)

请注意二分的实现:

#include<bits/stdc++.h>
#define pir pair<int,int>
using namespace std;
constexpr int MN=5e4+15;
int n,q,pre[MN];
vector<int> vht[MN];
set<int> st[MN];
string s;

struct Segment{
#define ls t[p].lson
#define rs t[p].rson

    struct Node{
        int lson,rson,val;
    }t[MN<<9];
    int tot,rt[MN];

    void pushup(int p){
        t[p].val=max(t[ls].val,t[rs].val);
    }

    void modfiy(int &p,int lst,int l,int r,int pos,int v){
        p=++tot;
        t[p]=t[lst];
        if(l==r){
            t[p].val=max(t[p].val,v);
            return;
        }
        int mid=(l+r)>>1;
        if(mid>=pos) modfiy(ls,t[lst].lson,l,mid,pos,v);
        else modfiy(rs,t[lst].rson,mid+1,r,pos,v);
        pushup(p);
    }

    int query(int p,int l,int r,int fl,int fr){
        if(l>=fl&&r<=fr){
            return t[p].val;
        }
        int mid=(l+r)>>1,ret=0;
        if(mid>=fl) ret=max(ret,query(ls,l,mid,fl,fr));
        if(mid<fr) ret=max(ret,query(rs,mid+1,r,fl,fr));
        return ret;
    }

#undef ls
#undef rs
}sg;

namespace SA{
    int len,sa[MN],x[MN],y[MN],rk[MN],c[MN],ht[MN],ST[30][MN];

    // 接受 string 和 vector_int 输入,其他输入不保证正确性
    // ST表需要手动初始化调用initst函数
    template<typename vct>
    void getsa(vct &s){
        int m=400000;
        len=s.size();
        s.insert(s.begin(),' ');
        for(int i=1;i<=len;i++){
            x[i]=s[i];
            ++c[x[i]];
        }
        for(int i=2;i<=m;i++) c[i]+=c[i-1];
        for(int i=len;i>=1;i--) sa[c[x[i]]--]=i;
        for(int k=1;k<=len;k<<=1){
            int num=0;
            for(int i=len-k+1;i<=len;i++) y[++num]=i;
            for(int i=1;i<=len;i++){
                if(sa[i]>k) y[++num]=sa[i]-k;
            }
            for(int i=1;i<=m;i++) c[i]=0;
            for(int i=1;i<=len;i++) c[x[i]]++;
            for(int i=2;i<=m;i++) c[i]+=c[i-1];
            for(int i=len;i>=1;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
            swap(x,y);
            num=1,x[sa[1]]=1;
            for(int i=2;i<=len;i++){
                if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]) x[sa[i]]=num;
                else x[sa[i]]=++num;
            }
            if(num==len) break;
            m=num;
        }
        for(int i=1;i<=len;i++) rk[sa[i]]=i;
        for(int i=1,k=0;i<=len;i++){
            if(rk[i]==1) continue;
            if(k) k--;
            int j=sa[rk[i]-1];
            while(i+k<=len&&j+k<=len&&s[i+k]==s[j+k]) k++;
            ht[rk[i]]=ST[0][rk[i]]=k;
        }
    }
}using namespace SA;

int root(int x){
    if(pre[x]==x) return pre[x];
    else return pre[x]=root(pre[x]);
  // 这里用这种合并方式而不是按秩合并
  // 是因为并查集维护的是联通块所属的集合,不用考虑形态变化。
}

void merge(int x,int y,int L){
    int rx=root(x),ry=root(y);
    if(rx==ry) return;
    if(st[rx].size()<st[ry].size()) swap(rx,ry);
    pre[ry]=rx;
    for(auto p:st[ry]){
        auto it=st[rx].lower_bound(p);
        if(it!=st[rx].end()){
            sg.modfiy(sg.rt[L],sg.rt[L],1,n,*it,p);
        }
        if(it!=st[rx].begin()){
            it--;
            sg.modfiy(sg.rt[L],sg.rt[L],1,n,p,*it);
        }
    }
    for(auto p:st[ry]) st[rx].insert(p);
}

int main(){
    cin>>n>>q>>s;
    getsa(s);
    for(int i=2;i<=n;i++){
        vht[ht[i]].push_back(i);
    }
    for(int i=1;i<=n;i++){
        pre[i]=i;
        st[i].insert(i);
    }
    for(int i=n;i>=1;i--){
        sg.rt[i]=sg.rt[i+1];
        for(auto p:vht[i]){
            merge(sa[p],sa[p-1],i);
        }
    }
    while(q--){
        int L,R;
        cin>>L>>R;
        int l=0,r=R-L+1;
        while(l+1<r){
            int mid=(l+r)>>1;
            if(sg.query(sg.rt[mid],1,n,L,R-mid+1)>=L){
                l=mid;
            }else r=mid;
        }
        cout<<l<<'\n';
    }
    return 0;
}