P9453 [ZSHOI-R1] 有效打击 题解
P9453 [ZSHOI-R1] 有效打击 题解
题目分析
-
提供一种使用 kmp 实现的线性做法。
-
对于此类查询
A 串中B 串出现了多少次的题目,可以想到使用 kmp 解决。 -
容易发现只有每一段颜色的数量和颜色的顺序是重要的,所以把原串缩写成
a_1,a_2,a_3,\dots,b_1,b_2,b_3,\dots 的形式,其中每一个a_i 或b_i 记录每个极长同色连续段的颜色和数量。 -
题目要求有
\frac{a_1}{b_1}=\frac{a_2}{b_2}=\frac{a_3}{b_3}=\dots=k\ , k>0 ,移项发现两序列相似的充分必要条件是\sum_{i=2}^{n} \frac{a_i}{a_{i-1}} =\frac{b_i}{b_{i-1}} 且每一段a_i 和b_i 的颜色相同。按照这个条件判断 kmp 的两字符相等就可以了。 -
需要注意的是这里的相等需要特判
A 和B 匹配时的起始段和B 的末尾段,这些段对前后两段之间的数量比并没有要求。 -
同样的,当
B 串缩写后长度为1 或者2 时,可能会存在缩写后的段并不是匹配的最小单位的情况,需要分类讨论。当长度为1 时显然为A 串与b_1 颜色相同段的字串个数,长度为2 时同样找出颜色相同段,计算B 串最多出现几次即可。 -
子串个数是
n^2 级别的,记得开 long long.代码实现
#include<iostream>
#include<cmath>
#define int long long
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch<='9'&&ch>='0'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
const int maxn=5e6+100;
struct node {
int x,v;
double y;
}a[maxn],b[maxn];
double h=1e-6;
int n,m;
bool cal(node e,node f,int ai,int bi){//比较两段是否相等
if((ai==1||bi==1||bi==m)&&e.x==f.x) return 1;
else if(e.x==f.x&&abs(e.y-f.y)<=h) return 1;
else return 0;
}
int gcd(int a,int b){
if(a>b)swap(a,b);
if(!a||!b) return a+b;
return gcd(b%a,a);
}
int kmp[maxn];
int j;
int ans;
signed main(){
int nn=read(),mm=read();
for(int i=1;i<=nn;i++) {
int p=read();
if(a[n].x!=p) a[++n].x=p,a[n].v=1;
else a[n].v++;
}
for(int i=1;i<=mm;i++) {
int p=read();
if(b[m].x!=p) b[++m].x=p,b[m].v=1;
else b[m].v++;
}
if(m==1){
for(int i=1;i<=n;i++){
if(a[i].x==b[m].x) ans+=a[i].v*(a[i].v+1)/2;
}
cout<<ans<<endl;
}
else if(m==2){
int k=gcd(b[1].v,b[2].v);
b[1].v=b[1].v/k,b[2].v=b[2].v/k;
for(int i=1;i<=n;i++){
if(a[i].x==b[2].x&&a[i-1].x==b[1].x)ans+=min(a[i].v/b[2].v,a[i-1].v/b[1].v);
}
cout<<ans<<endl;
}
else {
for(int i=2;i<=n;i++) a[i].y=double(a[i].v)/double(a[i-1].v);
for(int i=2;i<=m;i++) b[i].y=double(b[i].v)/double(b[i-1].v);
int la=n,lb=m;
for(int i=2;i<=lb;i++){
while(j&&!cal(b[i],b[j+1],i,j+1)){
j=kmp[j];
}
if(cal(b[i],b[j+1],i,j+1)){
j++;
kmp[i]=j;
}
}
j=0;
for(int i=1;i<=la;i++){
while(j&&!cal(a[i],b[j+1],i,j+1)){
j=kmp[j];
}
if(cal(a[i],b[j+1],i,j+1)){
j++;
}
if(j==lb){
ans++;
j=kmp[j];
}
}
cout<<ans<<endl;
}
}