题解:CF1313E Concatenation with intersection
shiruoyu114514 · · 题解
给定两个长度为
n 的字符串a,b 以及一个长度为m 的字符串s 。问有多少个四元组(l_1,r_1,l_2,r_2) 满足a[l_1:r_1]+b[l_2:r_2]=s 。n \le 5 \times 10^5,m \le 10^6 。
一个比较错误的切入点是枚举
有
n 组询问,每组询问形如:给定[l_1,r_1][l_2,r_2] ,求\sum \limits_{i=l_1}^{r_1} \sum \limits_{j=l_2}^{r_2} [|a_i-b_j| < m] 。
显然做不了!
事实上,我们需要注意到的一件事情是,
现在我们考虑怎么刻画
发现在不考虑总长度和为
考虑预处理
我们就能够发现,合法的方案数就恰好为
令
使用你喜欢的数据结构处理即可。时间复杂度
#include<bits/stdc++.h>
using namespace std;
#define int long long
constexpr int maxn=1e6;
constexpr int base=31;
constexpr int mod=998244353;
int powbase[maxn+5];
struct String{
string s;
vector<int> hsh;
int n;
void init(){
cin>>s;
n=s.length();
s=" "+s;
hsh.resize(n+1,0);
for(int i=1;i<=n;i++){
hsh[i]=(hsh[i-1]*base+s[i])%mod;
//cout<<hsh[i]<<" ";
}
//cout<<"\n";
}
int calc(int l,int r){
// cout<<"Calc "<<l<<" "<<r<<"\n";
return (hsh[r]-hsh[l-1]*powbase[r-l+1]%mod+mod)%mod;
}
char& operator [](int x){
//cerr<<"Call "<<x<<"\n";
return s[x];
}
};
int n,m;
String a,b,s;
int longa[maxn+5],longb[maxn+5];
struct bit{
int v[maxn+5];
void update(int pos,int value){
pos++;
while(pos){
v[pos]+=value;
pos-=pos&(-pos);
}
}
int quary(int pos){
pos++;
int ans=0;
while(pos<=m+1){
ans+=v[pos];
pos+=pos&(-pos);
}
return ans;
}
}ct,v;
signed main(){
ios::sync_with_stdio(0);
cin.tie(nullptr);
powbase[0]=1;
for(int i=1;i<=maxn;i++){
powbase[i]=powbase[i-1]*base%mod;
}
cin>>n>>m;
a.init();
b.init();
s.init();
for(int i=1;i<=n;i++){
if(a[i]!=s[1]){
longa[i]=0;
}
else{
int l=1,r=min(n-i+1,m);
while(l<r){
int mid=((l+r)>>1)+1;
if(a.calc(i,i+mid-1)==s.calc(1,mid)){
l=mid;
}
else{
r=mid-1;
}
}
longa[i]=l;
}
}
for(int i=1;i<=n;i++){
if(b[i]!=s[m]){
longb[i]=0;
}
else{
int l=1,r=min(i,m);
while(l<r){
int mid=((l+r)>>1)+1;
if(b.calc(i-mid+1,i)==s.calc(m-mid+1,m)){
l=mid;
}
else{
r=mid-1;
}
}
longb[i]=l;
}
}
int ans=0;
// for(int l=1;l<=n;l++){
// for(int r=1;r<=n;r++){
// if(r>=l&&r<l+m-1){
// int va=longa[l]-(longa[l]>=m);
// int vb=longb[r]-(longb[r]>=m);
// if(vb>=max(0ll,m-1-va)){
// ans+=va+vb-m+1;
// }
// }
// }
// }
vector<pair<int,int>> quary[n+5];
for(int l=n;l>=1;l--){
int va=m-1-(longa[l]-(longa[l]>=m));//注意此处 va 为题解中的 m-1-va
quary[l].push_back(make_pair(va,1));
if(l+m-1<=n){
quary[l+m-1].push_back(make_pair(va,-1));
}
}
for(int l=n;l>=1;l--){
int vb=longb[l]-(longb[l]>=m);
ct.update(vb,1);
v.update(vb,vb);
for(auto j:quary[l]){
ans+=(-ct.quary(j.first)*(j.first)+v.quary(j.first))*j.second;
}
}
cout<<ans;
}