题解:P9576 「TAOI-2」Ciallo~(∠・ω< )⌒★

· · 题解

柚子厨差不多得了。

设字符串 s 的长度为 nt 的长度为 m

我们先对 [l,r][l',r'] 是否相交进行讨论。

若不交,则说明删除没有任何影响。那么我们找出 s 中的所有 t,设匹配右端点为 i,则方案数为 \binom{i-m}{2}+\binom{n-i}{2}

若相交,则合法情况一定类似下图:

图中 x_i,y_i 分别表示 \operatorname{LCP}(s(i,n),t),\operatorname{LCS}(s(1,i),t),红色部分表示重复部分。

重复部分的长度为 x_{l'}+y_{r'}-m,则方案数为 x_{l'}+y_{r'}-m+1

可以发现,这部分贡献的 [l',r'] 需要满足:

r'-l'\ge m \\ x_{l'}+y_{r'} \ge m \\ x_{l'} \neq m \\ y_{r'} \neq m

不能等于 m 是因为这部分贡献已经在前面算过了,不能重复计算。

可以发现,这是一个二维数点。

因为不能等于 m,所以在数点时还要把这部分的贡献减去。

```cpp #include"bits/stdc++.h" #define re register #define int long long #define lb(x) (x&(-x)) using namespace std; const int maxn=4e5+10,base=131,mod=1e9+7; int n,m,ans; char s1[maxn],s2[maxn]; int pw[maxn]; int h1[maxn],h2[maxn]; int x[maxn],y[maxn]; struct BIT{ int tr[maxn]; //因为会用到 0 这个位置,所以需要整体加个偏移量。 inline void add(int x,int val){++x;while(x<maxn) tr[x]+=val,x+=lb(x);} inline int query(int x){++x;int res=0;while(x>0) res+=tr[x],x-=lb(x);return res;} }a,b; inline int calc1(int l,int r){return ((h1[r]-h1[l-1]*pw[r-l+1]%mod)%mod+mod)%mod;} inline int calc2(int l,int r){return ((h2[r]-h2[l-1]*pw[r-l+1]%mod)%mod+mod)%mod;} inline bool check1(int pos,int len){ if(pos+len-1>n) return 0; return calc1(pos,pos+len-1)==calc2(1,len); } inline bool check2(int pos,int len){ if(pos-len+1<1) return 0; return calc1(pos-len+1,pos)==calc2(m-len+1,m); } inline void init(){ pw[0]=1; for(re int i=1;i<maxn;++i) pw[i]=pw[i-1]*base%mod; for(re int i=1;i<=n;++i) h1[i]=(h1[i-1]*base+s1[i])%mod; for(re int i=1;i<=m;++i) h2[i]=(h2[i-1]*base+s2[i])%mod; for(re int i=1,l,r,mid,res;i<=n;++i){ l=0,r=m,res=0; while(l<=r){ mid=(l+r)>>1; if(check1(i,mid)) res=mid,l=mid+1; else r=mid-1; } x[i]=res; l=0,r=m,res=0; while(l<=r){ mid=(l+r)>>1; if(check2(i,mid)) res=mid,l=mid+1; else r=mid-1; } y[i]=res; } } signed main(){ #ifndef ONLINE_JUDGE freopen("1.in","r",stdin); freopen("1.out","w",stdout); #endif ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); cin>>(s1+1)>>(s2+1);n=strlen(s1+1),m=strlen(s2+1); init(); for(re int i=1;i<=n;++i) if(y[i]==m) ans+=((i-m)*(i-m+1)/2+(n-i)*(n-i+1)/2); for(re int i=m+1;i<=n;++i){ a.add(m-x[i-m],1),b.add(m-x[i-m],x[i-m]); ans+=b.query(y[i])+(y[i]-m+1)*a.query(y[i]); if(y[i]==m) ans-=a.query(y[i]); ans-=a.query(0); } cout<<ans; return 0; } ```