题解:P13812 [CERC 2022] Insertions

· · 题解

由于问题要求的很多,其中包含最大数量,出现次数,最左可行位置和最右可行位置。所以考虑对于每个可以插入的位置,都求一下出现次数。

显然在一个位置插入后 p 的出现次数可以分成若干种情况讨论。不妨假设在这个位置插入后分成三段,分别为 A,t,B。其中 A+B=s。那么在 A,t,B 中一段的出现次数是可以预处理,在线性时间内求出所有位置的答案的。跨两段的,如果是 A,t,由于 t 是固定的,所以就考虑用 t 的一段前缀去匹配 p 的一端后缀,然后再检查 A 的后缀能否与 p 的剩余部分匹配。这样暴力是非常慢的,所以考虑充分利用已知信息。注意到,如果失配是发生在 t 内,那么直接把失配的情况预处理;如果失配不是发生在 t 内,那么就不需要考虑了,这个根本没有意义。实际上,预处理相当于在失配树上做点到根的路径和。这个是简单的。然后 t,B 的同理,考虑把所有串都反过来就和 A,t 是一样的。

最难处理的是 A,t,B 三段的部分。这时候发现 t 就是 p 的子串,所以可以先用 tp 匹配,然后考虑剩余部分怎么处理。因为可以知道 tp 匹配的位置在哪,所以就给前半段剩下了一个前缀,给后半段剩下了一个后缀。实际上可以对每个位置预处理以这个位置分割,前半段能匹配前缀到哪里,后半段能匹配后缀到哪里,然后就变成了要求实际匹配的前缀和后缀是失配树上最远匹配的前缀和后缀的祖先。不难发现这个问题是一个二维数点问题,所以可以直接离线下来扫描线套树状数组。

不妨设三个字符串的长度都是 O(n) 的,时间复杂度为 O(n\log n)

#include<bits/stdc++.h>
using namespace std;
#define __MY_TEST__ 0
#define int long long
inline int read()
{
    int re=0,f=1;
    char ch=getchar();
    while(!isdigit(ch)){if(ch=='-') f=-1; ch=getchar();}
    while( isdigit(ch)) re=(re<<3)+(re<<1)+(ch^48),ch=getchar();
    return re*f;
}
const int N=1e5+5,bse=37;
char s[N],t[N],p[N];
int pw[N];
int ls,lt,lp;
struct HT
{
    int hsh[N];
    void build(char *s,int len)
    {
        for(int i=1;i<=len;i++)
        {
            hsh[i]=hsh[i-1]*bse+s[i]-'a'+1;
        }
    }
    int get(int l,int r)
    {
        return hsh[r]-hsh[l-1]*pw[r-l+1];
    }
}hs,ht,hp;
struct FT
{
    char S[N];
    int val[N],fail[N],dfn[N],tot,low[N];
    vector<int>gra[N];
    void dfs(int u)
    {
        dfn[u]=++tot;
        for(auto v:gra[u])
        {
            val[v]+=val[u];
            dfs(v);
        }
        low[u]=tot;
    }
    void build(char *s,int len)
    {
        for(int i=1;i<=len;i++) S[i]=s[i];
        fail[0]=fail[1]=0;
        int pos=0;
        for(int i=2;i<=len;i++)
        {
            while(pos&&S[pos+1]!=S[i]) pos=fail[pos];
            if(S[pos+1]==S[i]) fail[i]=++pos;
            else fail[i]=0;
        }
        for(int i=1;i<=len;i++) gra[fail[i]].push_back(i);
        dfs(0);
    }
    int get_nxt(int p,char c)
    {
        while(p&&S[p+1]!=c) p=fail[p];
        return p+(S[p+1]==c);
    }
}t1,t2;
struct BIT
{

int c[N];
int lowbit(int x)
{
    return x&(-x);
}
void add(int x,int xx)
{
    while(x<=lp+1) c[x]+=xx,x+=lowbit(x);
}
int ask(int x)
{
    int re=0;
    while(x) re+=c[x],x-=lowbit(x);
    return re;
}

}T;
int vt,pre[N],suf[N],ans[N],fp[N];
vector<pair<int,int> >vec[N],smx[N];
signed main(){
#if __MY_TEST__
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
#endif
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    pw[0]=1;
    for(int i=1;i<=100000;i++) pw[i]=pw[i-1]*bse;
    cin>>s+1>>t+1>>p+1;
    ls=strlen(s+1),lt=strlen(t+1),lp=strlen(p+1);
    hs.build(s,ls),ht.build(t,lt),hp.build(p,lp);
    for(int i=1;i<=min(lt,lp-1);i++)
    {
        if(ht.get(1,i)==hp.get(lp-i+1,lp)) t1.val[lp-i]++;
        if(hp.get(1,i)==ht.get(lt-i+1,lt)) t2.val[lp-i]++;
    }
    for(int i=lp;i<=lt;i++) if(ht.get(i-lp+1,i)==hp.get(1,lp)) vt++;
    t1.build(p,lp);
    reverse(p+1,p+lp+1);
    t2.build(p,lp);
    reverse(p+1,p+lp+1);
    for(int i=lp;i<=ls;i++) if(hs.get(i-lp+1,i)==hp.get(1,lp)) pre[i]=1;
    for(int i=ls-lp+1;i>0;i--) if(hs.get(i,i+lp-1)==hp.get(1,lp)) suf[i]=1;
    for(int i=1;i<=ls;i++) pre[i]+=pre[i-1];
    for(int i=ls;i;i--) suf[i]+=suf[i+1];
    for(int i=0;i<=ls;i++) ans[i]=pre[i]+suf[i+1]+vt;
    for(int i=ls;i;i--) fp[i]=t2.get_nxt(fp[i+1],s[i]);
    int pos=0;
    for(int i=0;i<=ls;i++)
    {
        if(i) pos=t1.get_nxt(pos,s[i]);
        ans[i]+=t1.val[pos]+t2.val[fp[i+1]];
        vec[t1.dfn[pos]].emplace_back(t2.dfn[fp[i+1]],i);
    }
    for(int i=2;i+lt-1<lp;i++)
    {
        if(ht.get(1,lt)==hp.get(i,i+lt-1))
            smx[t1.dfn[i-1]].emplace_back(lp-i-lt+1,1),
            smx[t1.low[i-1]+1].emplace_back(lp-i-lt+1,-1);
    }
    for(int i=1;i<=lp+1;i++)
    {
        for(auto x:smx[i]) T.add(t2.dfn[x.first],x.second),T.add(t2.low[x.first]+1,-x.second);
        for(auto x:vec[i]) ans[x.second]+=T.ask(x.first);
    }
    int maxn=0,sum=0;
    for(int i=0;i<=ls;i++)
    {
        if(ans[i]==maxn) sum++;
        else if(ans[i]>maxn) maxn=ans[i],sum=1;
    }
    cout<<maxn<<' '<<sum<<' ';
    for(int i=0;i<=ls;i++) if(ans[i]==maxn)
    {
        cout<<i<<' ';
        break;
    }
    for(int i=ls;~i;i--) if(ans[i]==maxn)
    {
        cout<<i<<' ';
        return 0;
    }
}