回文树 / 回文自动机(PAM)入门笔记

· · 题解

题解 P5496 【模板】回文树 / 回文自动机(PAM)暨 PAM 入门教程

更新日志:

本文简要解释了如何理解 PAM 的时空复杂度,如果您不理解可以参考。

有任何问题欢迎提问。

尽管本题只需要最基本的回文自动机即可完成,但作为模板题,还是多写一点来的合适,毕竟很多同学来这是学算法的。所以本文包括:

注:为了方便区分字符和字符变量,本文中字符及字符串一律使用 \mathtt 字体。

引入

做到 CF2164H 发现需要 PAM 作为前置然后我不会,于是就来记笔记了。

关于回文的问题,我们已经有一个 Manacher 算法来处理了。Manacher 求解了以每个位置为中心的最长回文子串半径,而回文自动机则是以更大的空间开销为代价,接受一个字符串所有回文子串的自动机。

罗勇军、郭卫斌《算法竞赛》上总结说,PAM 的算法思想概括来说是“奇偶字典树+后缀链跳跃”。私以为这个概括十分精辟准确,所以接下来就分成“建 trie”和“求 fail 指针”两步来讲。

『奇偶字典树』

在正式建树之前,我们需要了解 trans 指针的定义。

PAM 使用类似于 trie 的结构来存储所有本质不同的回文串。普通的 trie 结构类似于这样:

但 PAM 中 trie 是用来存储回文串的,所以只有回文串配出现在它的结点上。怎么保证结点上一定是回文串呢?

方法很简单:trie 上的边不再表示向后添加边上的字符,改为表示向前后各添加一个该边上的字符。按照这个含义,上述 trie 各点对应的字符串就如下图所示了:

但是这样的话,似乎不能存储所有本质不同的回文串:上图中只有长度为偶数的。

所以对于长度为奇数的我们还要额外建一棵 trie。

这就是『奇偶字典树』的思想,上一张图即奇字典树,而再上一张图即偶字典树。而所谓的 trans 指针其实就是树边。也就是说 trans_{x,c}=\overline{cxc}x 是字符串,c 是字符),或者说在 x 两边各加一个 c,例如 trans_{\mathtt{ababa},\mathtt{c}}=\mathtt{cababac}。当然,在编写的过程中,字符串都会用 trie 上的点编号来表示的。

『后缀链跳跃』

联赛的字符串算法一共有两个需要背板子:KMP 和马拉车。

这两种算法在算法思路上的共性特点之一是:知道了 S[1,i-1] 的答案,求解 S_i 的答案。

PAM 也采用相同的增量构造方式。所以我们考虑:知道了以 i-1 结尾的最长回文串(设为 lst),求解以 i 结尾的最长回文串。

如果运气不错,lst 前一个字符正好和 S_i 相同,则有下图所示的情况:

此时 lst' 直接就是 lst 两边各加上一个 S_i,即 lst'=trans_{lst,S_i}

但更多的时候运气不会这么好。

我们考虑,我们刚才直接通过判断 lst 两侧字符是否相等来确定是否能拓展回文串长度的底气是,我们知道 lst 本身是回文的,所以在其两侧同时添加相同的字符当然还是回文的。

为了保持这个性质,我们需要从长到短不断枚举以 i-1 结尾的回文串,直到枚举到某一个时两边字符恰好一致,或者枚举到长为 -1 的串了。为了加快枚举的过程,我们引入了 fail 指针:fail_x 的定义就是 x 除其自身外最长的回文后缀,例如 fail_{\mathtt{ababa}}=\mathtt{aba}

于是『后缀链跳跃』的含义也很清楚了:对于 lst 两边字符不相等的情况,我们只需要不断地使 lst 变为 fail_{lst} 直至两边字符一致或者 lst 已经到奇字典树的根上了(其实此时由于我们定义奇字典树根结点 len=-1 从而有 i=i-len-1,所以必然会判断为两边相等),然后拓展到 lst',如图。

目前为止,我们求出了 lst',但这还不足以建出 PAM,因为我们需要知道新点的 fail

如果求出的 lst' 已经在树上了,则其所有的信息都是确定的,当然不需要额外求解;但如果需要新建这个点,那就需要额外求出它的 fail

而由于求的还是以 i 结尾的回文串,所以思路其实和前面是完全一致的:接着使 lst 变为 fail_{lst} 直至 lst 两边再次相等。这里特判 lst 已经在奇字典树根结点处的情况,因为没法给该点设置 fail 了:这种情况求出的 lst' 长度已经是 1 了,也就不可能有其它非空的回文后缀了,所以 fail 直接置为偶字典树根即可。

至此,我们已经讲完了 PAM 的所有内容

当然,结点上还需要维护题目要求的其他信息,例如本题需要维护后缀链长度(即跳 fail 跳到根所需次数)dep

也许你会很疑惑:其他教材都把 PAM 画出来了,你这还没开始画呢就讲完了?不是说好的多图吗?

别急,这就给你画。

算法流程图解

下面将第三组样例稍加修改,使其出现多次加入到同一点上的情况,即 \mathtt{aabaabbaab} 为例画图。其中实线箭头表示 trans,虚线箭头表示 faillst 所在的后缀链已经高亮显示。

如果你能按照顺序在不提前看的情况下找准每一个 lst 箭头的正确位置,或者干脆能重新画一遍这张图,那么恭喜你,你已经掌握了 PAM!

复杂度

时间

朴素 PAM 建树的时间复杂度是 O(n)

有些同学不理解为什么回文树不断跳 fail 指针而时间复杂度依旧是线性的。这可能是因为你想成了时间复杂度与树高相关。但实际上,跳 fail 的次数显然只与后缀链(fail 链)的长度有关,而每次加入点只会使 lst 所在的 fail 链长增加至多 1,总的跳 fail 次数当然就是 O(n) 了。

空间

朴素 PAM 的空间复杂度是 O(n|\Sigma|),其中 \Sigma 表示字符集。

根据 trie 的思想,PAM 只会为本质不同的回文串创建点,而这样的回文串只有 O(n) 个,单个点的瓶颈在 trans 上,所以总共是 O(n|\Sigma|)

模板代码

PAM 代码的要点绝对不是背下来,因为它代码十分简单,而且没有易错的编码细节。它的重点在于理解算法过程,所以在编写的过程中脑子里可以想着上面的图来强化理解。

再次重复算法要点:

只要把握住上面两点,编写 PAM 其实很简单!

struct pam{
    ll trans[30],fail,len,dep;//dep 表示后缀链长度,也就是本题的答案 
};
string s;
ll lstans,lst=1,p=2;
pam a[500010];
void ins(ll now,ll pos){
    //在 now 后试图加入 s[pos] 字符 
    ll ch=s[pos]-'a';//方便后面书写 
    while(s[pos]!=s[pos-a[now].len-1]){
        now=a[now].fail;//不断跳 fail 直至 now 两侧字符相同 
    }
    if(a[now].trans[ch]){
        //该点已经存在,所有信息都有了,直接退出 
        lstans=a[a[now].trans[ch]].dep;
        lst=a[now].trans[ch];
        return ;
    }
    //该点原来不存在,需要求出 fail,len,dep
    //求 fail 
    if(now==1){//如果当前已经是单字符了 
        a[p].fail=0;//则 fail 是空串 
    }
    else{
        //否则接着跳 fail 
        ll tmp=a[now].fail;
        while(s[pos]!=s[pos-a[tmp].len-1]){
            tmp=a[tmp].fail;
        }
        a[p].fail=a[tmp].trans[ch];
    }

    a[p].len=a[now].len+2;//len 直接由父亲的 len 加 2 
    a[p].dep=a[a[p].fail].dep+1;//dep 直接由后缀链上一个点的 dep 加 1 
    a[now].trans[ch]=(p++);//创建树上的边,这一行放到最后是为了方便前面直接用 p 
    lstans=a[a[now].trans[ch]].dep;
    lst=a[now].trans[ch];
}
int main(){
    ios::sync_with_stdio(0);
    cin>>s;
    a[0].fail=1;
    a[1].len=-1;
    for(int i=0;i<s.length();i++){
        s[i]=(s[i]+lstans-'a')%26+'a';
        ins(lst,i);
        cout<<lstans<<' ';
    }
    return 0;
}

一些例题

本来还想放一个刚结束的 CCPC【数据删除】站的题的,根据小青鱼通知 Ucup 可能会用所以暂不公开。

P3649 [APIO2014] 回文串

这个和板子差不多,仍然把考察点放在 fail 树上。

记录每个结点作为 lst 的次数,然后每个结点实际出现的次数就是 fail 树上子树和,建 PAM 时把树一起建出来然后 DFS 一下即可。

:::success[代码]

struct pam{
    ll trans[30],fail,len,cnt;
};
string s;
ll lst=1,p=2,ans;
pam a[300010];
vector<ll> f[300010];
void dfs(ll now){
    for(int i=0;i<f[now].size();i++){
        dfs(f[now][i]);
        a[now].cnt+=a[f[now][i]].cnt;
    }
    ans=max(ans,a[now].cnt*a[now].len);
}
void ins(ll now,ll pos){
    ll ch=s[pos]-'a';
    while(s[pos]!=s[pos-a[now].len-1]){
        now=a[now].fail;
    }
    if(a[now].trans[ch]){
        lst=a[now].trans[ch];
        a[lst].cnt++;
        return ;
    }
    if(now==1){
        a[p].fail=0;
    }
    else{
        ll tmp=a[now].fail;
        while(s[pos]!=s[pos-a[tmp].len-1]){
            tmp=a[tmp].fail;
        }
        a[p].fail=a[tmp].trans[ch];
    }
    f[a[p].fail].push_back(p);
    a[p].len=a[now].len+2;
    a[p].cnt++;
    a[now].trans[ch]=(p++);
    lst=a[now].trans[ch];
}
int main(){
    ios::sync_with_stdio(0);
    cin>>s;
    a[0].fail=1;
    a[1].len=-1;
    for(int i=0;i<s.length();i++){
        ins(lst,i);
    }
    dfs(0);
    cout<<ans;
    return 0;
}

:::

P4287 [SHOI2011] 双倍回文

这可以算作下一题的前置。

这里首先容易想到 PAM 中偶字典树的相关内容。但肯定不是直接扫偶字典树,而是要求存在一个长度恰好为自身一半的回文后缀。所以本题最终还是个后缀链的题。

我们在建 PAM 的过程中对每个回文串处理出其长度小于其自身长度的一半的回文后缀中最长的那个,然后看该回文后缀长度是否恰好为其自身一半即可。维护的方式是对点记录 slink 表示上述后缀,加入新点时从父亲的 slink 开始跳 fail 直至符合条件即可。

稍微有点细节,记得每个可能的不合法情况都判到。

:::success[代码]

struct pam{
    ll trans[30],fail,len,slink;
};
string s;
ll lst=1,p=2,ans;
pam a[500010];
void ins(ll now,ll pos){
    ll ch=s[pos]-'a';
    while(s[pos]!=s[pos-a[now].len-1]){
        now=a[now].fail;
    }
    if(a[now].trans[ch]){
        lst=a[now].trans[ch];
        return ;
    }
    if(now==1){
        a[p].fail=0;
    }
    else{
        ll tmp=a[now].fail;
        while(s[pos]!=s[pos-a[tmp].len-1]){
            tmp=a[tmp].fail;
        }
        a[p].fail=a[tmp].trans[ch];
    }
    a[p].len=a[now].len+2;
    if((a[a[p].fail].len<<1)<=a[p].len){
        a[p].slink=a[p].fail;
    }
    else{
        ll tmp=a[now].slink;
        while((a[tmp].len+2<<1)>a[p].len||s[pos]!=s[pos-a[tmp].len-1]){
            tmp=a[tmp].fail;
        }
        a[p].slink=a[tmp].trans[ch];
    }
    a[now].trans[ch]=(p++);
    lst=a[now].trans[ch];
}
int main(){
    ios::sync_with_stdio(0);
    cin>>s>>s;
    a[0].fail=1;
    a[1].len=-1;
    for(int i=0;i<s.length();i++){
        ins(lst,i);
        if(a[lst].len%4==0&&(a[a[lst].slink].len<<1)==a[lst].len){
            ans=max(ans,a[lst].len);
        }
    }
    cout<<ans;
    return 0;
}

:::

CF2164H PalindromePalindrome

定义一个字符串的幽默度为其中出现至少两次的最长的回文串的长度。

给定长度为 n 的字符串 sq 次询问,每次询问 s 的一段子串的幽默度。

数据范围:n,q\le5\times10^5

在阅读下文时,一旦出现不理解的地方,一定要提醒自己,这些串都是回文串,它们具有回文串的性质。

考虑回文子串两次出现位置的关系。显然可以简单地分为相交和不相交。

两次出现范围相交

设两次出现的区间分别是 [l_1,r_1][l_2,r_2],则有 l_1<l_2\le r_1<r_2,又因为 [l_1,r_1][l_2,r_2] 是回文的,所以 [l_1,r_2] 也必然是回文的。所以这里我们只需要对每个回文的 [l_1,r_2] 找到其最大的边界即可。考虑设 S[l,l'] 为以 l 开头的最长前缀,S[r',r] 为以 r 结尾的最长后缀,则只有下面三种情况是做贡献的:

  1. 回文中心在 S[l,l']S[r',r] 的回文中心(取边界)之间的所有回文子串。

第一条可以马拉车预处理,后两条可以上 PAM 找祖先,查询的时候直接上线段树即可。

两次出现范围不相交

对于这种情况,考虑以 r 结尾的所有回文后缀。令 B=\overline{AC},其中 A,C 为回文串,记 B^k=\overline{\begin{matrix}\underbrace{BBB\dots B}\\k\ \text{个}\ B\end{matrix}},则上述所有后缀均可以表示成 B^kA(k\in\mathbb{N}) 的形式。所以,对于两个后缀 S_1,S_2,若 \dfrac{|S_2|}{2}<|S_1|<|S_2|,则 S_1S_2 中必然会相交地出现至少两次,这会被第一种情况计算。所以实际有效的串的个数只有 O(n\log n) 级别。

我们在构建 PAM 时预处理出每个串的长度不超过其一半的最长回文后缀,每次跳这个后缀就可以只遍历我们需要的串了。查询还是上线段树即可。

总复杂度 O((n+q)\log^2 n)

:::success[代码 By @ljw0102]

#include<bits/stdc++.h>
#define N 500005
#define ll long long
#define fi first
#define se second
using namespace std;
template<typename T> void read(T &x){
    x=0;int f=0;char c=getchar();
    for(;c<'0'||c>'9';c=getchar()) f=(c=='-');
    for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
    if(f) x=-x;
}

int n,q;
char s[N];

struct node{
    int l,r,k,op;
}b[N*20];
int cnt;

vector<int> ver[N];
int siz[N],dfn[N],cd;
int f[N][20];
void dfs(int x,int p){
    dfn[x]=++cd;siz[x]=1;
    f[x][0]=p;
    for(int j=1;j<20;j++) f[x][j]=f[f[x][j-1]][j-1];
    for(int y:ver[x]) dfs(y,x),siz[x]+=siz[y];
}
int zkw[N*4],M=524288;
void change(int x,int k){
    x+=M;
    zkw[x]=max(zkw[x],k);
    for(x>>=1;x;x>>=1) zkw[x]=max(zkw[x*2],zkw[x*2+1]);
}
int ask(int l,int r){
    l+=M,r+=M;
    if(l==r) return zkw[l];
    int ans=max(zkw[l],zkw[r]);
    while(l+1!=r){
        if((l&1)==0) ans=max(ans,zkw[l+1]);
        if(r&1) ans=max(ans,zkw[r-1]);
        l>>=1,r>>=1;
    }
    return ans;
}

struct PAM{
char s[N];
int t[N][26],fail[N],len[N],lst=1,tot=1,pos[N],ans[N];
int diff[N],slink[N];
void insert(int i){
    int p=lst,c=s[i]-'a';
    while(s[i-len[p]-1]!=s[i]) p=fail[p];
    if(t[p][c]) lst=t[p][c];
    else{
        int u=++tot,q=fail[p];
        while(s[i-len[q]-1]!=s[i]) q=fail[q];
        fail[u]=t[q][c];
        t[p][c]=u;len[u]=len[p]+2;
        diff[u]=len[u]-len[fail[u]];
        if(diff[u]==diff[fail[u]]) slink[u]=slink[fail[u]];
        else slink[u]=fail[u];
        lst=u;
    }
    pos[i]=lst;
}
int find(int l,int r){
    int p=pos[r];
    if(len[p]<=r-l+1) return len[p];
    while(len[slink[p]]>r-l+1) p=slink[p];
    return len[slink[p]]+(r-l+1-len[slink[p]])/diff[p]*diff[p];
}
void calc(){
    for(int i=1;i<=n;i++){
        int p=pos[i];
        int lst=ask(dfn[p],dfn[p]+siz[p]-1);
        if(lst) b[++cnt]={lst-len[p]+1,i,len[p],1};
        while(p){
            if(len[fail[p]]*2<=len[p]) {
                int lst=ask(dfn[fail[p]],dfn[fail[p]]+siz[fail[p]]-1);
                if(lst) b[++cnt]={lst-len[fail[p]]+1,i,len[fail[p]],1};
            }
            p=slink[p];
        }
        change(dfn[pos[i]],i);
    }
    for(int i=2;i<=tot;i++){
        ans[i]=max(ans[i],len[fail[i]]);
        for(int c=0;c<26;c++) if(t[i][c]) ans[t[i][c]]=max(ans[t[i][c]],ans[i]);
    }
}
int find2(int l,int r){
    int p=pos[r];
    if(len[p]==r-l+1) return ans[p];
    for(int i=19;i>=0;i--) if(len[f[p][i]]>r-l+1) p=f[p][i];
    return ans[fail[p]];
}
}pam1,pam2;

int zkw2[N*8];
void change2(int x,int k){
    x+=M*2;
    zkw2[x]=max(zkw2[x],k);
    for(x>>=1;x;x>>=1) zkw2[x]=max(zkw2[x*2],zkw2[x*2+1]);
}
int ask2(int l,int r){
    l+=M*2,r+=M*2;
    if(l==r) return zkw2[l];
    int ans=max(zkw2[l],zkw2[r]);
    while(l+1!=r){
        if((l&1)==0) ans=max(ans,zkw2[l+1]);
        if(r&1) ans=max(ans,zkw2[r-1]);
        l>>=1,r>>=1;
    }
    return ans;
}
char s2[N*2];
int p[N*2];
void manacher(){
    s2[0]='%',s2[1]='#';
    for(int i=1;i<=n;i++) s2[i*2]=s[i],s2[i*2+1]='#';
    s2[n*2+2]='$';
    int r=0,c=0;
    for(int i=1;i<=n*2+1;i++){
        if(i<r) p[i]=min(p[c*2-i],r-i);
        while(s2[i+p[i]+1]==s2[i-p[i]-1]) p[i]++;
        if(i+p[i]>r) c=i,r=i+p[i];
        if(i%2==0) change2(i,pam1.find2(i/2-p[i]/2,i/2+p[i]/2));
        else change2(i,pam1.find2(i/2-p[i]/2+1,i/2+p[i]/2));
    }
}
int ans[N];

int main(){
    read(n),read(q);
    scanf("%s",s+1);
    pam1.len[1]=-1,pam1.fail[0]=1;
    pam2.len[1]=-1,pam2.fail[0]=1;
    for(int i=1;i<=n;i++){
        pam1.s[i]=s[i];
        pam1.insert(i);
        pam2.s[i]=s[n-i+1];
        pam2.insert(i);
    }
    for(int i=2;i<=pam1.tot;i++) ver[pam1.fail[i]].push_back(i);
    ver[1].push_back(0);
    dfs(1,0);
    pam1.calc();
    manacher();
    for(int i=1;i<=q;i++){
        int l,r;
        read(l),read(r);
        int val1=pam2.find(n-r+1,n-l+1),val2=pam1.find(l,r);
        if(val1==r-l+1){
            ans[i]=pam1.find2(l,r);
            continue;
        }
        int l2=l+val1-1,r2=r-val2+1;
        ans[i]=max(max(pam1.find2(l,l2),pam1.find2(r2,r)),ask2(l+l2+1,r+r2-1));
        b[++cnt]={l,r,i,2};
    }
    memset(zkw,0,sizeof(zkw));
    sort(b+1,b+cnt+1,[&](node a,node b){return a.r<b.r||a.r==b.r&&a.op<b.op;});
    for(int i=1;i<=cnt;i++){
        if(b[i].op==1) change(b[i].l,b[i].k);
        else ans[b[i].k]=max(ans[b[i].k],ask(b[i].l,b[i].r));
    }
    for(int i=1;i<=q;i++) cout<<ans[i]<<"\n";
    return 0;
}

:::