题解:CF1313E Concatenation with intersection

· · 题解

给定两个长度为 n 的字符串 a,b 以及一个长度为 m 的字符串 s。问有多少个四元组 (l_1,r_1,l_2,r_2) 满足 a[l_1:r_1]+b[l_2:r_2]=sn \le 5 \times 10^5,m \le 10^6

一个比较错误的切入点是枚举 a 贡献的前缀。如果你想要使用后缀数组的话,可以处理出来 s 对应前缀所能匹配的后缀数组区间以及 s 后缀的反串在 b 反串所能匹配的后缀数组区间,然后就转化为了以下问题:

n 组询问,每组询问形如:给定 [l_1,r_1][l_2,r_2],求 \sum \limits_{i=l_1}^{r_1} \sum \limits_{j=l_2}^{r_2} [|a_i-b_j| < m]

显然做不了!

事实上,我们需要注意到的一件事情是,[l_1,r_1][l_2,r_2] 的区间长度总和为 m。也就是说,我们只需要保证 r_2<l_1+m-1 并且 r_2 \ge l_1 即可。

现在我们考虑怎么刻画 r_1l_2 的计数。

发现在不考虑总长度和为 m 的情况下,r_1l_2 都呈一个区间。

考虑预处理 A_ia[i:n]s 的 LCP,B_ib[1:i]s 的 LCS。

我们就能够发现,合法的方案数就恰好为 \max(0,(A_{l_1}-(A_{l_1} \ge m))+(B_{r_1}-(B_{r_1} \ge m))-(m-1))。有这么一个特判是因为 ab 的贡献不能为空。

va_i=A_i-(A_i \ge m),vb_i 同理,则所求为:

\sum \limits_{r \ge l,r < l+m-1} \max(0,va+vb-m+1)

使用你喜欢的数据结构处理即可。时间复杂度 O(n \log n)

#include<bits/stdc++.h>

using namespace std;

#define int long long

constexpr int maxn=1e6;
constexpr int base=31;
constexpr int mod=998244353;

int powbase[maxn+5];

struct String{
    string s;
    vector<int> hsh;
    int n;

    void init(){
        cin>>s;
        n=s.length();
        s=" "+s;
        hsh.resize(n+1,0);
        for(int i=1;i<=n;i++){
            hsh[i]=(hsh[i-1]*base+s[i])%mod;
            //cout<<hsh[i]<<" ";
        }
        //cout<<"\n";
    }

    int calc(int l,int r){
    //  cout<<"Calc "<<l<<" "<<r<<"\n";
        return (hsh[r]-hsh[l-1]*powbase[r-l+1]%mod+mod)%mod;
    }

    char& operator [](int x){
        //cerr<<"Call "<<x<<"\n";
        return s[x];
    }
};

int n,m;
String a,b,s;

int longa[maxn+5],longb[maxn+5];

struct bit{
    int v[maxn+5];

    void update(int pos,int value){
        pos++;
        while(pos){
            v[pos]+=value;
            pos-=pos&(-pos);
        }
    }

    int quary(int pos){
        pos++;
        int ans=0;
        while(pos<=m+1){
            ans+=v[pos];
            pos+=pos&(-pos);
        }
        return ans;
    }
}ct,v;

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    powbase[0]=1;
    for(int i=1;i<=maxn;i++){
        powbase[i]=powbase[i-1]*base%mod;
    }
    cin>>n>>m;
    a.init();
    b.init();
    s.init();
    for(int i=1;i<=n;i++){
        if(a[i]!=s[1]){
            longa[i]=0;
        }
        else{
            int l=1,r=min(n-i+1,m);
            while(l<r){
                int mid=((l+r)>>1)+1;
                if(a.calc(i,i+mid-1)==s.calc(1,mid)){
                    l=mid;
                }
                else{
                    r=mid-1;
                }
            }
            longa[i]=l;
        }
    }
    for(int i=1;i<=n;i++){
        if(b[i]!=s[m]){
            longb[i]=0;
        }
        else{
            int l=1,r=min(i,m);
            while(l<r){
                int mid=((l+r)>>1)+1;
                if(b.calc(i-mid+1,i)==s.calc(m-mid+1,m)){
                    l=mid;
                }
                else{
                    r=mid-1;
                }
            }
            longb[i]=l;
        }
    }
    int ans=0;
//  for(int l=1;l<=n;l++){
//      for(int r=1;r<=n;r++){
//          if(r>=l&&r<l+m-1){
//              int va=longa[l]-(longa[l]>=m);
//              int vb=longb[r]-(longb[r]>=m);
//              if(vb>=max(0ll,m-1-va)){
//                  ans+=va+vb-m+1;
//              }
//          }
//      }
//  }
    vector<pair<int,int>> quary[n+5];
    for(int l=n;l>=1;l--){
        int va=m-1-(longa[l]-(longa[l]>=m));//注意此处 va 为题解中的 m-1-va
        quary[l].push_back(make_pair(va,1));
        if(l+m-1<=n){
            quary[l+m-1].push_back(make_pair(va,-1));
        }
    }
    for(int l=n;l>=1;l--){
        int vb=longb[l]-(longb[l]>=m);
        ct.update(vb,1);
        v.update(vb,vb);
        for(auto j:quary[l]){
            ans+=(-ct.quary(j.first)*(j.first)+v.quary(j.first))*j.second;
        }
    }
    cout<<ans;
}