P4600题解

· · 题解

前置知识

失配树和最近公共祖先。

思路与做法

将原问题转换为在失配树上寻找最近公共祖先,主要是因为以下两个原因:

理解了上面的,就比较简单了,给所有编码的每个前缀构建字典树,然后跑一遍构建失配树,最后直接输出根到该最近公共祖先的值(注意是原图上的值),样例如下图:

代码

由于遭受了 MLE 的重创,迫不得已将部分 vector 换成了数组,可能使可读性下降。vector 邻接表与链式前向星有内存性能上的差异,因为 vector 扩充时是默认多申请 2 倍空间,所以像这道变态的题目可能会卡内存只能用链式前向星的写法写。

#include<bits/stdc++.h>
#define maxn 1000006
#define mod 1000000007
using namespace std;
int ch[maxn][26],idx=0,fail[maxn];
long long va[maxn];
int Id[maxn*2],len=0,Idx=0,Len[maxn];
// 注意输入的字符串总长度可能超过 tot!
struct Tree{int ne,to;}tr[maxn*2]; // 失配树。
// 原先 Id[i][j] 表示第 i 个表示的第 j 个地区对应的字典树的值,由于 vector 炸了,就用 Id[i]+Len[i] 代替了。
int k=0,head[maxn];
void add(int u,int v){k++,tr[k].to=v,tr[k].ne=head[u],head[u]=k;}
void ins(int id,string s){
    int u=0;
    for(int i=0;i<s.length();i++){
        if(!ch[u][s[i]-'a']) ch[u][s[i]-'a']=++idx;
        va[ch[u][s[i]-'a']]=(va[u]*26+(s[i]-'a'))%mod, // 计算根到该节点的值。
        u=ch[u][s[i]-'a'],Id[++Idx]=u; // 存储每个地区对应字典树上的编号。
    }
}
void bul(){ // 建立失配数组和失配树。
    queue<int> q;
    for(int i=0;i<26;i++) if(ch[0][i])
        q.push(ch[0][i]),add(0,ch[0][i]),add(ch[0][i],0);
    while(!q.empty()){
        int u=q.front(); q.pop();
        for(int i=0;i<26;i++)
            if(ch[u][i])
                fail[ch[u][i]]=ch[fail[u]][i],q.push(ch[u][i]),
                add(ch[u][i],fail[ch[u][i]]),
                add(fail[ch[u][i]],ch[u][i]);
            else ch[u][i]=ch[fail[u]][i];
    }
}
int dfn[maxn],tim=0,st[20][maxn]; // dfs 序 +ST 表找最近公共祖先。
// 详见 https://www.luogu.com.cn/article/pu52m9ue。
int get(int x,int y){
    return dfn[x]<dfn[y]?x:y;
}
struct Node{int x,f;};
void dfs(int u,int fa){ // 用迭代替代了递归。
    stack<Node> stk;
    stk.push({0,0});
    while (!stk.empty()){
        int u=stk.top().x,fa=stk.top().f; stk.pop();
        if(dfn[u]==0){
            tim++,dfn[u]=tim,st[0][tim]=fa;
            for(int i=head[u];i;i=tr[i].ne)
                if(tr[i].to!=fa) stk.push({tr[i].to, u});
        }
    }
}
int lca(int u,int v){
    if(u==v) return u;
    u=dfn[u],v=dfn[v]; if(u>v) swap(u,v);
    int dis=__lg(v-u); u++;
    return get(st[dis][u],st[dis][v-(1<<dis)+1]);
}
int main(){
    int n; cin>>n; string s;
    for(int i=1;i<=n;i++) cin>>s,ins(i,s),len+=s.length(),Len[i]=len;
    // Len[i] 表示前 i 行的总地区数。
    bul(),dfs(0,0);
    for(int i=1;i<=__lg(idx+1);i++)
        for(int j=1;j+(1<<i)-1<=idx+1;j++)
            st[i][j]=get(st[i-1][j],st[i-1][j+(1<<i-1)]);
    int m,i,j,k,l,u,v; cin>>m; while(m--)
        cin>>i>>j>>k>>l,u=Id[Len[i-1]+j],v=Id[Len[k-1]+l],cout<<va[lca(u,v)]<<"\n";
    return 0;
}