回文树 / 回文自动机(PAM)入门笔记
题解 P5496 【模板】回文树 / 回文自动机(PAM)暨 PAM 入门教程
更新日志:
- 【本版,管理大大可以只看这一版】22/12/2025:根据@Xycxxx 神犇的建议,修正了
fail_x 定义的错误。
本文简要解释了如何理解 PAM 的时空复杂度,如果您不理解可以参考。
有任何问题欢迎提问。
尽管本题只需要最基本的回文自动机即可完成,但作为模板题,还是多写一点来的合适,毕竟很多同学来这是学算法的。所以本文包括:
- 用大量的图文来展开朴素 PAM 的算法流程,又用了精简的文字来解释编写 PAM 时需要保持的基本思想。
- 对朴素 PAM 时空复杂度的简单解释。
- 几个例题(个别题难度略高,
毕竟我连 Global 的压轴题都给塞进去了)。
注:为了方便区分字符和字符变量,本文中字符及字符串一律使用 \mathtt 字体。
引入
做到 CF2164H 发现需要 PAM 作为前置然后我不会,于是就来记笔记了。
关于回文的问题,我们已经有一个 Manacher 算法来处理了。Manacher 求解了以每个位置为中心的最长回文子串半径,而回文自动机则是以更大的空间开销为代价,接受一个字符串所有回文子串的自动机。
罗勇军、郭卫斌《算法竞赛》上总结说,PAM 的算法思想概括来说是“奇偶字典树+后缀链跳跃”。私以为这个概括十分精辟准确,所以接下来就分成“建 trie”和“求
『奇偶字典树』
在正式建树之前,我们需要了解
PAM 使用类似于 trie 的结构来存储所有本质不同的回文串。普通的 trie 结构类似于这样:
但 PAM 中 trie 是用来存储回文串的,所以只有回文串配出现在它的结点上。怎么保证结点上一定是回文串呢?
方法很简单:trie 上的边不再表示向后添加边上的字符,改为表示向前后各添加一个该边上的字符。按照这个含义,上述 trie 各点对应的字符串就如下图所示了:
但是这样的话,似乎不能存储所有本质不同的回文串:上图中只有长度为偶数的。
所以对于长度为奇数的我们还要额外建一棵 trie。
这就是『奇偶字典树』的思想,上一张图即奇字典树,而再上一张图即偶字典树。而所谓的
『后缀链跳跃』
联赛的字符串算法一共有两个需要背板子:KMP 和马拉车。
这两种算法在算法思路上的共性特点之一是:知道了
PAM 也采用相同的增量构造方式。所以我们考虑:知道了以
如果运气不错,
此时
但更多的时候运气不会这么好。
我们考虑,我们刚才直接通过判断
为了保持这个性质,我们需要从长到短不断枚举以
于是『后缀链跳跃』的含义也很清楚了:对于
目前为止,我们求出了
如果求出的
而由于求的还是以
至此,我们已经讲完了 PAM 的所有内容。
当然,结点上还需要维护题目要求的其他信息,例如本题需要维护后缀链长度(即跳
也许你会很疑惑:其他教材都把 PAM 画出来了,你这还没开始画呢就讲完了?不是说好的多图吗?
别急,这就给你画。
算法流程图解
下面将第三组样例稍加修改,使其出现多次加入到同一点上的情况,即
如果你能按照顺序在不提前看的情况下找准每一个
复杂度
时间
朴素 PAM 建树的时间复杂度是
有些同学不理解为什么回文树不断跳
空间
朴素 PAM 的空间复杂度是
根据 trie 的思想,PAM 只会为本质不同的回文串创建点,而这样的回文串只有
模板代码
PAM 代码的要点绝对不是背下来,因为它代码十分简单,而且没有易错的编码细节。它的重点在于理解算法过程,所以在编写的过程中脑子里可以想着上面的图来强化理解。
再次重复算法要点:
- 将各个本质不同的回文串按长度的奇偶性分为,并在节点上维护
trans,fail,len 及题目要求维护的其他信息。 - 通过不断跳原
lst 的fail 指针的方式计算新lst 及新lst 的fail 指针。
只要把握住上面两点,编写 PAM 其实很简单!
struct pam{
ll trans[30],fail,len,dep;//dep 表示后缀链长度,也就是本题的答案
};
string s;
ll lstans,lst=1,p=2;
pam a[500010];
void ins(ll now,ll pos){
//在 now 后试图加入 s[pos] 字符
ll ch=s[pos]-'a';//方便后面书写
while(s[pos]!=s[pos-a[now].len-1]){
now=a[now].fail;//不断跳 fail 直至 now 两侧字符相同
}
if(a[now].trans[ch]){
//该点已经存在,所有信息都有了,直接退出
lstans=a[a[now].trans[ch]].dep;
lst=a[now].trans[ch];
return ;
}
//该点原来不存在,需要求出 fail,len,dep
//求 fail
if(now==1){//如果当前已经是单字符了
a[p].fail=0;//则 fail 是空串
}
else{
//否则接着跳 fail
ll tmp=a[now].fail;
while(s[pos]!=s[pos-a[tmp].len-1]){
tmp=a[tmp].fail;
}
a[p].fail=a[tmp].trans[ch];
}
a[p].len=a[now].len+2;//len 直接由父亲的 len 加 2
a[p].dep=a[a[p].fail].dep+1;//dep 直接由后缀链上一个点的 dep 加 1
a[now].trans[ch]=(p++);//创建树上的边,这一行放到最后是为了方便前面直接用 p
lstans=a[a[now].trans[ch]].dep;
lst=a[now].trans[ch];
}
int main(){
ios::sync_with_stdio(0);
cin>>s;
a[0].fail=1;
a[1].len=-1;
for(int i=0;i<s.length();i++){
s[i]=(s[i]+lstans-'a')%26+'a';
ins(lst,i);
cout<<lstans<<' ';
}
return 0;
}
一些例题
本来还想放一个刚结束的 CCPC【数据删除】站的题的,根据小青鱼通知 Ucup 可能会用所以暂不公开。
P3649 [APIO2014] 回文串
这个和板子差不多,仍然把考察点放在
记录每个结点作为
:::success[代码]
struct pam{
ll trans[30],fail,len,cnt;
};
string s;
ll lst=1,p=2,ans;
pam a[300010];
vector<ll> f[300010];
void dfs(ll now){
for(int i=0;i<f[now].size();i++){
dfs(f[now][i]);
a[now].cnt+=a[f[now][i]].cnt;
}
ans=max(ans,a[now].cnt*a[now].len);
}
void ins(ll now,ll pos){
ll ch=s[pos]-'a';
while(s[pos]!=s[pos-a[now].len-1]){
now=a[now].fail;
}
if(a[now].trans[ch]){
lst=a[now].trans[ch];
a[lst].cnt++;
return ;
}
if(now==1){
a[p].fail=0;
}
else{
ll tmp=a[now].fail;
while(s[pos]!=s[pos-a[tmp].len-1]){
tmp=a[tmp].fail;
}
a[p].fail=a[tmp].trans[ch];
}
f[a[p].fail].push_back(p);
a[p].len=a[now].len+2;
a[p].cnt++;
a[now].trans[ch]=(p++);
lst=a[now].trans[ch];
}
int main(){
ios::sync_with_stdio(0);
cin>>s;
a[0].fail=1;
a[1].len=-1;
for(int i=0;i<s.length();i++){
ins(lst,i);
}
dfs(0);
cout<<ans;
return 0;
}
:::
P4287 [SHOI2011] 双倍回文
这可以算作下一题的前置。
这里首先容易想到 PAM 中偶字典树的相关内容。但肯定不是直接扫偶字典树,而是要求存在一个长度恰好为自身一半的回文后缀。所以本题最终还是个后缀链的题。
我们在建 PAM 的过程中对每个回文串处理出其长度小于其自身长度的一半的回文后缀中最长的那个,然后看该回文后缀长度是否恰好为其自身一半即可。维护的方式是对点记录
稍微有点细节,记得每个可能的不合法情况都判到。
:::success[代码]
struct pam{
ll trans[30],fail,len,slink;
};
string s;
ll lst=1,p=2,ans;
pam a[500010];
void ins(ll now,ll pos){
ll ch=s[pos]-'a';
while(s[pos]!=s[pos-a[now].len-1]){
now=a[now].fail;
}
if(a[now].trans[ch]){
lst=a[now].trans[ch];
return ;
}
if(now==1){
a[p].fail=0;
}
else{
ll tmp=a[now].fail;
while(s[pos]!=s[pos-a[tmp].len-1]){
tmp=a[tmp].fail;
}
a[p].fail=a[tmp].trans[ch];
}
a[p].len=a[now].len+2;
if((a[a[p].fail].len<<1)<=a[p].len){
a[p].slink=a[p].fail;
}
else{
ll tmp=a[now].slink;
while((a[tmp].len+2<<1)>a[p].len||s[pos]!=s[pos-a[tmp].len-1]){
tmp=a[tmp].fail;
}
a[p].slink=a[tmp].trans[ch];
}
a[now].trans[ch]=(p++);
lst=a[now].trans[ch];
}
int main(){
ios::sync_with_stdio(0);
cin>>s>>s;
a[0].fail=1;
a[1].len=-1;
for(int i=0;i<s.length();i++){
ins(lst,i);
if(a[lst].len%4==0&&(a[a[lst].slink].len<<1)==a[lst].len){
ans=max(ans,a[lst].len);
}
}
cout<<ans;
return 0;
}
:::
CF2164H PalindromePalindrome
定义一个字符串的幽默度为其中出现至少两次的最长的回文串的长度。
给定长度为
数据范围:
在阅读下文时,一旦出现不理解的地方,一定要提醒自己,这些串都是回文串,它们具有回文串的性质。
考虑回文子串两次出现位置的关系。显然可以简单地分为相交和不相交。
两次出现范围相交
设两次出现的区间分别是
- 回文中心在
S[l,l'] 和S[r',r] 的回文中心(取边界)之间的所有回文子串。 -
-
第一条可以马拉车预处理,后两条可以上 PAM 找祖先,查询的时候直接上线段树即可。
两次出现范围不相交
对于这种情况,考虑以
我们在构建 PAM 时预处理出每个串的长度不超过其一半的最长回文后缀,每次跳这个后缀就可以只遍历我们需要的串了。查询还是上线段树即可。
总复杂度
:::success[代码 By @ljw0102]
#include<bits/stdc++.h>
#define N 500005
#define ll long long
#define fi first
#define se second
using namespace std;
template<typename T> void read(T &x){
x=0;int f=0;char c=getchar();
for(;c<'0'||c>'9';c=getchar()) f=(c=='-');
for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
if(f) x=-x;
}
int n,q;
char s[N];
struct node{
int l,r,k,op;
}b[N*20];
int cnt;
vector<int> ver[N];
int siz[N],dfn[N],cd;
int f[N][20];
void dfs(int x,int p){
dfn[x]=++cd;siz[x]=1;
f[x][0]=p;
for(int j=1;j<20;j++) f[x][j]=f[f[x][j-1]][j-1];
for(int y:ver[x]) dfs(y,x),siz[x]+=siz[y];
}
int zkw[N*4],M=524288;
void change(int x,int k){
x+=M;
zkw[x]=max(zkw[x],k);
for(x>>=1;x;x>>=1) zkw[x]=max(zkw[x*2],zkw[x*2+1]);
}
int ask(int l,int r){
l+=M,r+=M;
if(l==r) return zkw[l];
int ans=max(zkw[l],zkw[r]);
while(l+1!=r){
if((l&1)==0) ans=max(ans,zkw[l+1]);
if(r&1) ans=max(ans,zkw[r-1]);
l>>=1,r>>=1;
}
return ans;
}
struct PAM{
char s[N];
int t[N][26],fail[N],len[N],lst=1,tot=1,pos[N],ans[N];
int diff[N],slink[N];
void insert(int i){
int p=lst,c=s[i]-'a';
while(s[i-len[p]-1]!=s[i]) p=fail[p];
if(t[p][c]) lst=t[p][c];
else{
int u=++tot,q=fail[p];
while(s[i-len[q]-1]!=s[i]) q=fail[q];
fail[u]=t[q][c];
t[p][c]=u;len[u]=len[p]+2;
diff[u]=len[u]-len[fail[u]];
if(diff[u]==diff[fail[u]]) slink[u]=slink[fail[u]];
else slink[u]=fail[u];
lst=u;
}
pos[i]=lst;
}
int find(int l,int r){
int p=pos[r];
if(len[p]<=r-l+1) return len[p];
while(len[slink[p]]>r-l+1) p=slink[p];
return len[slink[p]]+(r-l+1-len[slink[p]])/diff[p]*diff[p];
}
void calc(){
for(int i=1;i<=n;i++){
int p=pos[i];
int lst=ask(dfn[p],dfn[p]+siz[p]-1);
if(lst) b[++cnt]={lst-len[p]+1,i,len[p],1};
while(p){
if(len[fail[p]]*2<=len[p]) {
int lst=ask(dfn[fail[p]],dfn[fail[p]]+siz[fail[p]]-1);
if(lst) b[++cnt]={lst-len[fail[p]]+1,i,len[fail[p]],1};
}
p=slink[p];
}
change(dfn[pos[i]],i);
}
for(int i=2;i<=tot;i++){
ans[i]=max(ans[i],len[fail[i]]);
for(int c=0;c<26;c++) if(t[i][c]) ans[t[i][c]]=max(ans[t[i][c]],ans[i]);
}
}
int find2(int l,int r){
int p=pos[r];
if(len[p]==r-l+1) return ans[p];
for(int i=19;i>=0;i--) if(len[f[p][i]]>r-l+1) p=f[p][i];
return ans[fail[p]];
}
}pam1,pam2;
int zkw2[N*8];
void change2(int x,int k){
x+=M*2;
zkw2[x]=max(zkw2[x],k);
for(x>>=1;x;x>>=1) zkw2[x]=max(zkw2[x*2],zkw2[x*2+1]);
}
int ask2(int l,int r){
l+=M*2,r+=M*2;
if(l==r) return zkw2[l];
int ans=max(zkw2[l],zkw2[r]);
while(l+1!=r){
if((l&1)==0) ans=max(ans,zkw2[l+1]);
if(r&1) ans=max(ans,zkw2[r-1]);
l>>=1,r>>=1;
}
return ans;
}
char s2[N*2];
int p[N*2];
void manacher(){
s2[0]='%',s2[1]='#';
for(int i=1;i<=n;i++) s2[i*2]=s[i],s2[i*2+1]='#';
s2[n*2+2]='$';
int r=0,c=0;
for(int i=1;i<=n*2+1;i++){
if(i<r) p[i]=min(p[c*2-i],r-i);
while(s2[i+p[i]+1]==s2[i-p[i]-1]) p[i]++;
if(i+p[i]>r) c=i,r=i+p[i];
if(i%2==0) change2(i,pam1.find2(i/2-p[i]/2,i/2+p[i]/2));
else change2(i,pam1.find2(i/2-p[i]/2+1,i/2+p[i]/2));
}
}
int ans[N];
int main(){
read(n),read(q);
scanf("%s",s+1);
pam1.len[1]=-1,pam1.fail[0]=1;
pam2.len[1]=-1,pam2.fail[0]=1;
for(int i=1;i<=n;i++){
pam1.s[i]=s[i];
pam1.insert(i);
pam2.s[i]=s[n-i+1];
pam2.insert(i);
}
for(int i=2;i<=pam1.tot;i++) ver[pam1.fail[i]].push_back(i);
ver[1].push_back(0);
dfs(1,0);
pam1.calc();
manacher();
for(int i=1;i<=q;i++){
int l,r;
read(l),read(r);
int val1=pam2.find(n-r+1,n-l+1),val2=pam1.find(l,r);
if(val1==r-l+1){
ans[i]=pam1.find2(l,r);
continue;
}
int l2=l+val1-1,r2=r-val2+1;
ans[i]=max(max(pam1.find2(l,l2),pam1.find2(r2,r)),ask2(l+l2+1,r+r2-1));
b[++cnt]={l,r,i,2};
}
memset(zkw,0,sizeof(zkw));
sort(b+1,b+cnt+1,[&](node a,node b){return a.r<b.r||a.r==b.r&&a.op<b.op;});
for(int i=1;i<=cnt;i++){
if(b[i].op==1) change(b[i].l,b[i].k);
else ans[b[i].k]=max(ans[b[i].k],ask(b[i].l,b[i].r));
}
for(int i=1;i<=q;i++) cout<<ans[i]<<"\n";
return 0;
}
:::