题解CF1313E 【Concatenation with intersection】

· · 题解

题意:给定三个串 a,b,s,其中 a,b 长度为 ns 长度为 m,求出四元组 (l1,r1,l2,r2) 的个数,满足:

一道字符串+数据结构的神题。

首先,我们可以将两个串拼成 s 转化为 a 中匹配一个 s 的非空前缀,b 中匹配一个 s 的非空后缀。

我们定义 za_is_1s_2 \dots s_{m-1}a_i,a_{i+1} \dots a_n 的最长公共前缀,zb_is_2s_3 \dots s_{m}b_1 b_2 \dots b_i 的最长公共后缀。这个可以用 Z 算法求出。详情见我的 blog ,这里就不赘述了。

不难发现,由于拼成的字符串恰好为 s,因此

r1-l1+1+r2-l2+1=m

一定成立。移一下项,即

r2=l1-r1+l2-2+m

由于两区间有交点,所以 l1 \leq r2l2 \leq r1 成立,所以 -r1+l2 一定小于等于 0,式子又可以变为

r2=l1-r1+l2-2+m \leq l1-2+m

所以我们可以得到 r2 \leq l1+m-2

又根据 l1 \leq r2,所以两个区间满足条件一当且仅当 l1 \leq r2 \leq l1+m-2l2r1 就可以被我们忽略了。

我们发现,如果两个数 l1,r2 满足上面的条件,那么满足条件的方案数为 \max(za_{l1}+zb_{r2}-m+1,0)

这应该不是太难理解,举个栗子:

- $\texttt{aa}+\texttt{bab}=\texttt{aabab}

刚好就是 4+3-5+1

这样我们就可以将原本 n^5 的事情降到了 n^2,但是还是过不去啊。别急我们继续推式子。

根据上面的分析最终的答案就是:

\sum\limits_{l1=1}^n\ \sum\limits_{r2=l1}^{\min(l1+m-2,n)} \max(za_{l1}+zb_{r2}-m+1,0)

不难发现随着 l1 的增长,r2 也是单调递增的,所以我们可以用类似 two pointers 的方式维护,但是这个 \max 有点讨厌,那么我们怎么想办法把这个 \max 去掉呢?

我们发现,如果 za_{l1}+zb_{r2}-m+1<0,那么这个情况对答案没有贡献,因此我们只用考虑 za_{l1}+zb_{r2}-m+1 \geq 0,即 m-1-zb_{r2} \leq za_{l1} 的情况,我们可以把这个条件加到我们的式子中,即

\sum\limits_{l1=1}^n\ \sum\limits_{r2=l1}^{\min(l1+m-2,n)} za_{l1}+zb_{r2}-m+1\ [m-1-zb_{r2} \leq za_{l1}]

假设在区间 [l1,\min(l1+m-2,n)] 中满足 m-1-zb_{r2} \leq za_{l1}r1 的个数为 cnt,又可以变为

\sum\limits_{l1=1}^n\ cnt \times (za_{l1}-m+1)+\sum\limits_{r2=l1}^{\min(l1+m-2,n)} zb_{r2}\ [m-1-zb_{r2} \leq za_{l1}]

发现 m-1-zb_{r2} 不含 l1,那么我们可以建两个树状数组,从小到大枚举 l1,然后单调地增加 r2,每增加一个 l2,就在第一个树状数组 m-1-zb_{r2} 增加 zb_{r2},第二个树状数组 m-1-zb_{r2} 增加 1,然后第一个树状数组 \leq za_{l1} 位置上的和就是对应的 \sum\limits_{r2=l1}^{\min(l1+m-2,n)} zb_{r2},第一个树状数组 \leq za_{l1} 位置上的和就是对应的 cnt。然后加一下就可以了。

//Coded by tzc_wk
/*
数据不清空,爆零两行泪。
多测不读完,爆零两行泪。
边界不特判,爆零两行泪。
贪心不证明,爆零两行泪。
D P 顺序错,爆零两行泪。
大小少等号,爆零两行泪。
变量不统一,爆零两行泪。
越界不判断,爆零两行泪。
调试不注释,爆零两行泪。
溢出不 l l,爆零两行泪。
*/
#include <bits/stdc++.h>
using namespace std;
#define fi          first
#define se          second
#define fz(i,a,b)   for(int i=a;i<=b;i++)
#define fd(i,a,b)   for(int i=a;i>=b;i--)
#define foreach(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define all(a)      a.begin(),a.end()
#define giveup(...) return printf(__VA_ARGS__),0;
#define fill0(a)    memset(a,0,sizeof(a))
#define fill1(a)    memset(a,-1,sizeof(a))
#define fillbig(a)  memset(a,0x3f,sizeof(a))
#define fillsmall(a) memset(a,0xcf,sizeof(a))
#define mask(a)     (1ll<<(a))
#define maskx(a,x)  ((a)<<(x))
#define _bit(a,x)   (((a)>>(x))&1)
#define _sz(a)      ((int)(a).size())
#define filei(a)    freopen(a,"r",stdin);
#define fileo(a)    freopen(a,"w",stdout);
#define fileio(a)   freopen(a".in","r",stdin);freopen(a".out","w",stdout)
#define eprintf(...) fprintf(stderr,__VA_ARGS__)
#define put(x)      putchar(x)
#define eoln        put('\n')
#define space       put(' ')
#define y1          y_chenxiaoyan_1
#define y0          y_chenxiaoyan_0
#define int long long
typedef pair<int,int> pii;
inline int read(){
    int x=0,neg=1;char c=getchar();
    while(!isdigit(c)){
        if(c=='-')  neg=-1;
        c=getchar();
    }
    while(isdigit(c))   x=x*10+c-'0',c=getchar();
    return x*neg;
}
inline void print(int x){
    if(x<0){
        putchar('-');
        print(abs(x));
        return;
    }
    if(x<=9)    putchar(x+'0');
    else{
        print(x/10);
        putchar(x%10+'0');
    }
}
inline int qpow(int x,int e,int _MOD){
    int ans=1;
    while(e){
        if(e&1) ans=ans*x%_MOD;
        x=x*x%_MOD;
        e>>=1;
    }
    return ans;
}
int n=read(),m=read();
char a[500005],b[500005],s[1000005];
char buf[1500005];
int cnt=0,za[1500005],zb[1500005],z1[500005],z2[500005];
struct Fenwick_Tree{
    int bit[1500005];
    inline int lowbit(int x){
        return x&(-x);
    }
    inline void add(int x,int val){
        x=max(x,1ll);
        for(int i=x;i<=1000000;i+=lowbit(i))    bit[i]+=val;
    }
    inline int query(int x){
        int ans=0;
        for(int i=x;i;i-=lowbit(i)) ans+=bit[i];
        return ans;
    }
} t1,t2;
signed main(){
    scanf("%s%s%s",a+1,b+1,s+1);
    fz(i,1,m-1) buf[cnt++]=s[i];
    buf[cnt++]='#';
    fz(i,1,n)   buf[cnt++]=a[i];
    int l=0,r=0;
    for(int i=1;i<cnt;i++){
        if(i<=r)    za[i]=min(za[i-l],r-i);
        while(i+za[i]<cnt&&buf[za[i]]==buf[i+za[i]])    za[i]++;
        if(i+za[i]>r){
            l=i;
            r=i+za[i];
        }
    }
    fill0(buf);cnt=0;
    fd(i,m,2)   buf[cnt++]=s[i];
    buf[cnt++]='#';
    fd(i,n,1)   buf[cnt++]=b[i];
    l=0,r=0;
    for(int i=1;i<cnt;i++){
        if(i<=r)    zb[i]=min(zb[i-l],r-i);
        while(i+zb[i]<cnt&&buf[zb[i]]==buf[i+zb[i]])    zb[i]++;
        if(i+zb[i]>r){
            l=i;
            r=i+zb[i];
        }
    }
    fz(i,1,n)   z1[i]=za[i+m-1];
    fz(i,1,n)   z2[i]=zb[i+m-1];
    reverse(z2+1,z2+n+1); 
    int r2=1,ans=0;
    fz(l1,1,n){
        while(r2<=min(l1+m-2,n)){
            t1.add(m-1-z2[r2],z2[r2]);
            t2.add(m-1-z2[r2],1);
            r2++;
        }
        ans+=(z1[l1]-m+1)*t2.query(z1[l1])+t1.query(z1[l1]);
        t1.add(m-1-z2[l1],-z2[l1]);
        t2.add(m-1-z2[l1],-1);
    }
    cout<<ans<<endl;
    return 0;
}

UPD:2020.2.26:公式中有个地方写错了,应该是 za_{l1}+zb_{r2}-m+1