题解:P9576 「TAOI-2」Ciallo~(∠・ω< )⌒★
xtzqhy
·
·
题解
柚子厨差不多得了。
设字符串 s 的长度为 n,t 的长度为 m。
我们先对 [l,r] 和 [l',r'] 是否相交进行讨论。
若不交,则说明删除没有任何影响。那么我们找出 s 中的所有 t,设匹配右端点为 i,则方案数为 \binom{i-m}{2}+\binom{n-i}{2}。
若相交,则合法情况一定类似下图:
图中 x_i,y_i 分别表示 \operatorname{LCP}(s(i,n),t),\operatorname{LCS}(s(1,i),t),红色部分表示重复部分。
重复部分的长度为 x_{l'}+y_{r'}-m,则方案数为 x_{l'}+y_{r'}-m+1。
可以发现,这部分贡献的 [l',r'] 需要满足:
r'-l'\ge m
\\
x_{l'}+y_{r'} \ge m
\\
x_{l'} \neq m
\\
y_{r'} \neq m
不能等于 m 是因为这部分贡献已经在前面算过了,不能重复计算。
可以发现,这是一个二维数点。
因为不能等于 m,所以在数点时还要把这部分的贡献减去。
```cpp
#include"bits/stdc++.h"
#define re register
#define int long long
#define lb(x) (x&(-x))
using namespace std;
const int maxn=4e5+10,base=131,mod=1e9+7;
int n,m,ans;
char s1[maxn],s2[maxn];
int pw[maxn];
int h1[maxn],h2[maxn];
int x[maxn],y[maxn];
struct BIT{
int tr[maxn];
//因为会用到 0 这个位置,所以需要整体加个偏移量。
inline void add(int x,int val){++x;while(x<maxn) tr[x]+=val,x+=lb(x);}
inline int query(int x){++x;int res=0;while(x>0) res+=tr[x],x-=lb(x);return res;}
}a,b;
inline int calc1(int l,int r){return ((h1[r]-h1[l-1]*pw[r-l+1]%mod)%mod+mod)%mod;}
inline int calc2(int l,int r){return ((h2[r]-h2[l-1]*pw[r-l+1]%mod)%mod+mod)%mod;}
inline bool check1(int pos,int len){
if(pos+len-1>n) return 0;
return calc1(pos,pos+len-1)==calc2(1,len);
}
inline bool check2(int pos,int len){
if(pos-len+1<1) return 0;
return calc1(pos-len+1,pos)==calc2(m-len+1,m);
}
inline void init(){
pw[0]=1;
for(re int i=1;i<maxn;++i) pw[i]=pw[i-1]*base%mod;
for(re int i=1;i<=n;++i) h1[i]=(h1[i-1]*base+s1[i])%mod;
for(re int i=1;i<=m;++i) h2[i]=(h2[i-1]*base+s2[i])%mod;
for(re int i=1,l,r,mid,res;i<=n;++i){
l=0,r=m,res=0;
while(l<=r){
mid=(l+r)>>1;
if(check1(i,mid)) res=mid,l=mid+1;
else r=mid-1;
}
x[i]=res;
l=0,r=m,res=0;
while(l<=r){
mid=(l+r)>>1;
if(check2(i,mid)) res=mid,l=mid+1;
else r=mid-1;
}
y[i]=res;
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>(s1+1)>>(s2+1);n=strlen(s1+1),m=strlen(s2+1);
init();
for(re int i=1;i<=n;++i) if(y[i]==m) ans+=((i-m)*(i-m+1)/2+(n-i)*(n-i+1)/2);
for(re int i=m+1;i<=n;++i){
a.add(m-x[i-m],1),b.add(m-x[i-m],x[i-m]);
ans+=b.query(y[i])+(y[i]-m+1)*a.query(y[i]);
if(y[i]==m) ans-=a.query(y[i]);
ans-=a.query(0);
}
cout<<ans;
return 0;
}
```