浅谈 KMP

· · 算法·理论

本篇文章同步发表在博客园。

问题引入

给定一个字符串 s 和一个字符串 t,问 s 中有哪些子串为 t

1 \le |s|,|t| \le 1000 的时候,这就是一个特别简单的题目,我们可以暴力枚举 s 的每个长度为 |t| 的子串,然后取出来看看是否等于 t。这样的时间复杂度是 O(|s| \times |t|) 的,在这个数据范围下可谓是轻而易举。

或也可以使用 string 中自带的 find 函数,时间复杂度同样 O(|s| \times |t|)。注意,find 的时间复杂度最高会到平方级别,并不是什么线性或者甚至说是 O(1) 的哦。

但是这里讨论的都是小规模数据。当 1 \le |s|,|t| \le 10^5 的时候,这些方法就都用不了了,我们又该如何解决这个问题呢?

什么是 KMP?

刚才提到当 |s||t| 的规模到了 10^5 的情况下该怎么办。这时候刚才提到的两种朴素方法就都会超时了。

在这个情况下,KMP 就是一个非常好的选择!

是的,你可能会说 Hash 也行,没错,但是 Hash 再怎么说也存在一定的冲突概率,容易被卡,不是很好。而 KMP 却是一个非常保险的算法,当然是不可能像 Hash 一样出现什么哈希冲突的情况啦。

KMP 的时间复杂度是 O(|s|+|t|) 的,是不是非常便捷呢?

如何实现 KMP?

我们首先来看看暴力匹配的过程,能不能进行一些实现细节上的优化。

我们是一直往后找,很多时候的匹配做的都是无用功,要是可以延用之前比较出的结果来加快匹配速度就好了。

其实很好办的呐!我们依然是匹配,但不一下子匹配一整个串儿了,咱先就一个一个字符匹配——匹配到出现错误了,匹配不上了,哎,它不对应了,这个时候咱该干什么呢?按照常规套路,是不是应该把序列往后移一格,然后继续从第一个字符开始匹配呀。但是这里呢咱换种思想,我们考虑当前匹配的这部分串儿的一个 border——这东西表示最长公共前后缀,比如说字符串 \text{abcab},最长公共前后缀就是 2,也就是这个 \text{ab},瞧 {\color{red}\text{ab}}\text{c}{\color{red}\text{ab}} ,它在前后都出现啦。这个 border 要干啥呢?求出这个 border 的长度,然后把整个字符串位移到 border 在后缀出现的那部分,作为新匹配串的前缀,然后再继续往下匹配就成了。

于是现在的问题就变成了如何求出这个最长公共前后缀,也就是这所谓的 border。

nxt_i 表示 t 的前 i 个字符构成的前缀串儿的 border 的长度,当然是不能算上自己本身这个串的,不然就无意义了。

咋求?首先肯定直接从 2 开始遍历,因为长度为 1 的字符串——哈哈,就是一个字符嘛——是没有 border 的,或者说它的长度是 0,因为不能算上自己。从 2 开始遍历,首先看,如果你可以直接和上一个匹配,那咱就直接匹配,那么 nxt_i 的答案就是 nxt_{i-1} + 1 了。但要是匹配不上了,那 nxt_i 的答案又是什么呢?在这里就有一个很厉害的做法——咱去求 border 的 border,找不到就再去求 border!从 i-1 开始一个类似递归的形式,只要匹配不上,就跳转到 border 页面,然后再判,再跳,直到找到。当然也有直到最后都还找不到的情况,这个时候 nxt_i 就是 0 了。

求完这个东西就可以根据上面那个经过优化的匹配思路去匹配啦,这样就实现完了,但是当你有了求 border 这件武器之后,你又发现了一种全新的匹配方式!把 ts 拼接起来,中间插个无意义字符如 #,连成一个更长的大字符串,然后对这个大字符串求 border,得到一个 nxt 序列。然后让 i|t|+2 开始找(我这里全部从 1 开始编号),直到 |t|+|s|+1,如果哪里 nxt_i = |t|,是不是说明这里就匹配到了呢?于是这样就比像刚才那样再去费尽心思匹配要轻松多了。

这就是 KMP 啦,是不是很厉害呢?

KMP 模版代码

放个代码供参考。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 2e6+5;
string s,t,p;
int nxt[N],n,m,k;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
int main(){
    cin>>s>>t;p=t+'#'+s;
    n=s.size(),m=t.size(),k=p.size();
    s=" "+s,t=" "+t,p=" "+p;
    for(int i=2;i<=k;i++){
        int j=i-1;
        while(j>0){
            if(p[nxt[j]+1]==p[i])
                {nxt[i]=nxt[j]+1;break;}
            j=nxt[j];
        }
    }for(int i=m+2;i<=k;i++)
        if(nxt[i]==m)cout<<i-2*m<<"\n";
    for(int i=1;i<=m;i++)
        cout<<nxt[i]<<" ";cout<<"\n";
    return 0;
}

例题选讲

这边选择几道比较经典的例题进行讲解。

P4391 Radio Transmission 无线传输

结论题,答案就是 n-nxt_n,原理不用过多解释,因为就是一个很直觉性的,随便画个图就能知道。和板题代码几乎一样。

P9606 ABB

这个也是比较板子的题目,只需要取原字符串翻转后的结果,让它和原字符串拼接在一起,中间加个无意义字符,做 border,最后让 n 减掉那个 nxt_L 即可(L 表示拼接后的字符串的长度)。

CF1137B Camp Schedule

跟 KMP 关系不大,重点是要能熟练运用 border 的求解方式。

由于这个题目是给定你一个字符串 s 然后要求你重排它,使得里面出现的 t 的次数尽可能多。而且它保证了一个特别重要的性质,那就是不论是 s 还是 t 都只由 01 组成。

首先可以考虑把 s 拆一下,拆成 s00s11,到时候直接重组便可。当然 t 也是要拆的,拆成 t00t11

这个时候,我们要先对 t 求一趟 border,得出这个对应的 nxt 数组。拿到这个 nxt 数组之后,我们取出 nxt_{|t|} 的值,并让 |t| 减去它,这就是在 t 后为再产生一个 t 而所需要花费的总长度。当然了,要从 t 中真的提取出这一段字符串内容,同样也对其进行拆分,拆成 w00w11,方便后面的判断。

想想看,现在我们什么都有了,究竟要怎么构造才是最优的?显而易见,当 s01 的个数都是够用的情况下,我们首先拼出一个完整的 t,接着不断按照 border 求解之后得到的情况进行拼接,尽量凑出尽可能多的 t。直到某个时候 0 或者 1 有哪个不够了,那么也就没法子再多拼出一个 t 了,这个时候把剩余的 01 随意拼接到串末即可。

这样是不是就结束了?是不是很简单呢?

放个代码。后面的拼接部分其实也可以除法求出最多能拼多少个,不过我这里用的是循环暴力枚举的一个方式,也是一样的啦。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
int n,m,s0,s1,t0,t1,w0,w1,nxt[N];
string s,t,sub,ans;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
int main(){
    cin>>s>>t;
    n=s.size(),m=t.size();
    s=" "+s,t=" "+t;
    for(int i=1;i<=n;i++)
        if(s[i]=='1')s1++;else s0++;
    for(int i=1;i<=m;i++)
        if(t[i]=='1')t1++;else t0++;
    for(int i=2;i<=m;i++){
        int h=i-1;
        while(h>0){
            if(t[nxt[h]+1]==t[i])
                {nxt[i]=nxt[h]+1;break;}
            h=nxt[h];
        }
    }for(int i=1;i<=nxt[m];i++)
        if(t[i]=='1')w1++;else w0++;
    if(m>n){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
    if(s0<t0||s1<t1){for(int i=1;i<=n;i++)cout<<s[i];return 0;}
    ans=t;s0-=t0,s1-=t1;
    for(int i=nxt[m]+1;i<=m;i++)sub+=t[i];
    for(int i=m+1;i<=n;i++){
        if(s0<t0-w0||s1<t1-w1)break;
        s0-=(t0-w0),s1-=(t1-w1),ans+=sub;
    }while(s0--)ans+="0";while(s1--)ans+="1";
    for(int i=1;i<=n;i++)cout<<ans[i];cout<<"\n";
    return 0;
}

CF1200E Compress Words

很明显这个东西也是要进行拼凑,但是要把重复的地方去掉。

就是要找到两个字符串重复的地方嘛,靠前的字符串的后缀,以及靠后的字符串的前缀。这不也可以用 border 实现吗?翻转一下,拼接,不就可以求了吗?

呃,但是这个东西好像直接干会超时……是的没错,因为这样子的时间复杂度是 O(n \sum |s|) 的!大概是 10^{11} 的级别,绝对接受不了。实在是因为这个 ans 太长了啊,让它一直不停地去和各个 s_i 算 border,一次又一次,不超时才怪呢!

怎么办,怎么办呢?注意到不论这个 ans 有多长,最极端的情况也莫过于这个 s_i 完全融合进原来的 ans,也就是说这个 border,换句话说求出来的这个 nxt_k(其中 k 表示 anss_i 的拼接字符串的总长度)为 |s_i| 嘛!这是最大的情况了!既然 border 最大也就这样,那我们的 ans 还给这么多干嘛?去吃闲饭的吗?多浪费时间呐!于是咱就干脆只截 ans 的后 |s_i| 位去给做 border 匹配,这样的话这速度就快多了,只剩下个 O(\sum 2|s|) 了,一点点常数有什么问题嘛!

于是就结束啦,代码特别简单。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e5+5;
const int M = 1e6+5;
int n,nxt[M];
string s[N],ans;
int read(){
    int su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
string border(string s,string t){
    string p=" "+t+"#"+s;
    int k=p.size()-1;
    for(int i=1;i<=k;i++)nxt[i]=0;
    for(int i=2;i<=k;i++){
        int j=i-1;
        while(j>0){
            if(p[nxt[j]+1]==p[i])
                {nxt[i]=nxt[j]+1;break;}
            j=nxt[j];
        }
    }string tmp="";
    for(int i=nxt[k];i<t.size();i++)
        tmp+=t[i];return tmp;
}
int main(){
    n=read();
    for(int i=1;i<=n;i++)cin>>s[i];
    ans=s[1];
    for(int i=2;i<=n;i++){
        string tmp="";
        int x=ans.size(),y=s[i].size();
        for(int j=max(0,x-y);j<ans.size();j++)
            tmp+=ans[j];
        ans+=border(tmp,s[i]);
    }
    cout<<ans<<"\n";
    return 0;
}

CF631D Messenger

发现这个东西和往常的 KMP 匹配不一样呐,这东西它多了个奇奇怪怪的封装——也难怪,按照这个规模算一下,拉长之后岂不有 2 \times 10^{10} 长?谁存的下,谁又做得来呢?不说别的,都读不进来呢!

于是只能在这个封装的基础上进行 KMP 的操作。当然了,我们要先把它弄成最简封装,换句话说,封装中不能存在相邻两个字符相同,否则就要把它们合并起来!

我们发现,封装之后的 KMP 和之前不一样了,之前是全都要相同,现在只要中间一坨全部相同,左右端点只需要字符匹配,并且个数足够即可。那么 m=1m=2(所有提到的 nm 均是最简封装情况下的)的需要你去特判一下,因为这俩玩意儿不存在中间一坨,玩不了 KMP。

搞完这两种特殊情况就来真的了,依然 KMP,依然求 border,不过这个时候那个 t 别给全塞进去了,头和尾塞不得,因为这是最后要判断的。搞完之后看谁的 nxt 值是 m-2,是的话再判断下头和尾行不行,如果行就多一种情况啦,就可以更新答案啦!最后输出就行了。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 4e5+5;
struct node{LL x;char c;}a[N],b[N],p[N];
LL n,m,k,cn,cm,Ans,nxt[N];
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
bool operator == (const node &A , const node &B){
    return (A.x==B.x&&A.c==B.c);
}
bool operator <= (const node &A , const node &B){
    return (A.x<=B.x&&A.c==B.c);
}
int main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++){
        LL p=read();char h;cin>>h;
        if(h==a[cn].c)a[cn].x+=p;else a[++cn]={p,h};
    }n=cn;
    for(int i=1;i<=m;i++){
        LL p=read();char h;cin>>h;
        if(h==b[cm].c)b[cm].x+=p;else b[++cm]={p,h};
    }m=cm;
    if(m==1){
        for(int i=1;i<=n;i++)
            if(b[1]<=a[i])Ans+=a[i].x-b[1].x+1;
    }else if(m==2){
        for(int i=1;i<n;i++)
            if(b[1]<=a[i]&&b[2]<=a[i+1])Ans++;
    }else{
        for(int i=2;i<m;i++)p[++k]=b[i];
        for(int i=0;i<=n;i++)p[++k]=a[i];
        for(int i=2;i<=k;i++){
            int j=i-1;
            while(j>0){
                if(p[nxt[j]+1]==p[i])
                    {nxt[i]=nxt[j]+1;break;}
                j=nxt[j];
            }
        }for(int i=m;i<k;i++)
            if(nxt[i]==m-2&&b[1]<=p[i-m+2]&&b[m]<=p[i+1])Ans++;
    }cout<<Ans<<"\n";
    return 0;
}

简单总结

KMP,它通常用来处理字符串匹配问题,可以方便快速地查找到一个字符串在另一个字符串中的出现情况。和它一起出现的是 border,它可以快速求出字符串的任意前缀的最长公共前后缀,延伸运用多用于求解两个字符串的最长公共前后缀、回文串匹配情况等,运用多种多样,是非常好用的线性算法。

码这么多字也不容易,还麻烦你留个赞支持一下,真是太感谢啦!