题解：AT_abc403_e [ABC403E] Forbidden Prefix

· · 题解

我们考虑对集合 $Y$ 中的每个元素，计算它在哪些时刻会对答案产生贡献。这些时刻构成一个区间，区间的右端点（不含）是该元素的某个前缀最早出现在 $X$ 中的时间。

我们将询问离线：对每个加入 $X$ 的字符串，只记录它最早出现的时间（这里可以用哈希维护）。

然后对于每个加入 $Y$ 的元素（设其出现时间为 $l$），我们枚举它的所有前缀，取这些前缀在 $X$ 中最早出现时间的最小值 $r$（若没有任何前缀出现过，则令 $r=Q+1$）。那么区间 $[l,r)$ 中的每个时刻都会产生贡献，用差分数组统计即可。（注意特判 $r<l$ 的情况：此时该元素加入时已经有前缀在 $X$ 中，不产生任何贡献。）

这样时间复杂度是 $O\left(\sum_{i=1}^{Q}|S_i|\log\sum_{i=1}^{Q}|S_i|\right)$ 的（$\log$ 来自 `map` 的查询），可以通过。

（建议使用双模数哈希；考场上我写的单模数哈希被卡了。）

Code

#include <bits/stdc++.h>
// Shorthand type aliases used throughout the solution.
using pii = std::pair<int,int>;
using ll = long long;
// Token shorthands; these must stay macros because they stand in for
// member names / function templates rather than types.
#define pb emplace_back
#define mk make_pair
#define se second
#define fi first
using namespace std;
// Address marker declared before the large global arrays; together with
// Med (declared after them) it is used for a rough memory estimate.
bool Mst;
constexpr int Max=200010;      // array capacity: 2e5 queries plus slack
constexpr int mod=998244353;
constexpr int inf=1000000010;

// Reads one decimal integer from stdin via getchar().
// Any non-digit characters before the number are skipped; the number is
// negative only if the character immediately before its first digit is '-'.
// The character that terminates the digits is consumed.
// Returns 0 if end of input is reached before any digit is found.
inline int read(){
    int res=0,v=1;
    // Keep getchar()'s result as int so EOF stays distinguishable from a data byte.
    int c=getchar();
    while(c<'0'||c>'9'){
        // Without this check, truncated input made the loop spin forever on EOF.
        if(c==EOF)return 0;
        v=(c=='-'?-1:1);
        c=getchar();
    }
    while(c>='0'&&c<='9'){res=res*10+(c-'0');c=getchar();}
    return res*v;
}

// Integer modulo the compile-time constant `mod` (must fit in int).
// Invariant: 0 <= val < mod after every constructor and after +, -, *, /
// and unary minus. The bitwise helpers (<<=, >>=, ^=) work on the raw
// representative; ^= in particular may leave val >= mod.
template <int mod>
struct modint{

    int val;

    // Maps a value in (-mod, mod) into [0, mod).
    static int norm(const long long &x){return static_cast<int>(x<0?x+mod:x);}
    // Maps a non-negative value into [0, mod). Takes long long so that
    // val + x.val is not narrowed to int before reduction (matters for mod > INT_MAX/2).
    static int Norm(const long long &x){return static_cast<int>(x>=mod?x%mod:x);}

    // Modular inverse via the extended Euclidean algorithm.
    // Only meaningful when gcd(val, mod) == 1.
    modint inv()const{
        int a=val,b=mod,u=1,v=0,t;
        while(b>0)t=a/b,std::swap(a-=t*b,b),std::swap(u-=t*v,v);
        return modint(u);
    }

    modint():val(0){}
    // Reduce before normalizing: previously values >= mod or <= -mod were
    // stored out of range (e.g. modint<7>(12).val was 12).
    modint(const int &m):val(norm(m%mod)){}
    modint(const long long &m):val(norm(m%mod)){}
    modint operator -()const{return modint(norm(-val));}
    // Comparisons are const so they also work on const operands.
    bool operator ==(const modint &x)const{return val==x.val;}
    bool operator !=(const modint &x)const{return val!=x.val;}
    bool operator <=(const modint &x)const{return val<=x.val;}
    bool operator >=(const modint &x)const{return val>=x.val;}
    bool operator >(const modint &x)const{return val>x.val;}
    bool operator <(const modint &x)const{return val<x.val;}
    modint& operator *=(const modint &x){val=static_cast<int>(1ll*val*x.val%mod);return *this;}
    // NOTE: only well-defined while val << x.val fits in a signed 64-bit integer.
    modint& operator <<=(const modint &x){val=static_cast<int>((1ll*val<<x.val)%mod);return *this;}
    modint& operator +=(const modint &x){val=Norm(1ll*val+x.val);return *this;}
    modint& operator -=(const modint &x){val=norm(1ll*val-x.val);return *this;}
    modint& operator >>=(const modint &x){val>>=x.val;return *this;}
    modint& operator ^=(const modint &x){val^=x.val;return *this;}
    modint& operator /=(const modint &x){return *this*=x.inv();}
    // Binary operators are const so expressions over const modints compile.
    modint operator >>(const modint &x)const{return modint(*this)>>=x;}
    modint operator <<(const modint &x)const{return modint(*this)<<=x;}
    modint operator +(const modint &x)const{return modint(*this)+=x;}
    modint operator -(const modint &x)const{return modint(*this)-=x;}
    modint operator *(const modint &x)const{return modint(*this)*=x;}
    modint operator /(const modint &x)const{return modint(*this)/=x;}
    modint operator ^(const modint &x)const{return modint(*this)^=x;}
    friend std::ostream& operator<<(std::ostream& os,const modint &a){return os<<a.val;}
    // Reads a plain integer and reduces it, so the invariant holds after input too.
    friend std::istream& operator>>(std::istream& is,modint &a){
        long long x=0;
        is>>x;
        a=modint(x);
        return is;
    }
};

// The two moduli used for double hashing.
using m17 = modint<1000000007>;
using m98 = modint<998244353>;

m17 Bas1(97);  // polynomial base for the mod 1e9+7 component
m98 Bas2(83);  // polynomial base for the mod 998244353 component

// Double hash of a string -> earliest query index at which it was added to X.
map<pii,int>m;
// Difference array over query indices; its prefix sums are the printed answers.
int sum[Max];

// One offline query: opt selects the operation (main handles 1 and 2), s is the string.
struct Que{
    int opt;
    string s;
};
Que b[Max];

// Returns the double polynomial hash of s as
// (hash mod 1e9+7 with base Bas1, hash mod 998244353 with base Bas2),
// mapping each character c to c-'a'+1 (1..26 for lowercase input).
// Takes the string by const reference to avoid copying it on every call.
pii get(const string &s){
    m17 Ans1=0;m98 Ans2=0;
    for(const char ch:s){
        const int code=ch-'a'+1;
        Ans1=Ans1*Bas1+code;
        Ans2=Ans2*Bas2+code;
    }
    return mk(Ans1.val,Ans2.val);
}

bool Med;
signed main(){
    int q=read();
    for(int i=1;i<=q;++i){
        cin>>b[i].opt>>b[i].s;
        if(b[i].opt==1) {
            pii Res=get(b[i].s);
            if(m.find(Res)==m.end())m[Res]=i;
        }
    }

    for(int i=1;i<=q;++i){
        if(b[i].opt==2){
            int len=b[i].s.size();
            int pos=q+1;
            m17 Ans1=0;m98 Ans2=0;
            for(int j=0;j<len;++j){
                Ans1=Ans1*Bas1+(b[i].s[j]-'a'+1); 
                Ans2=Ans2*Bas2+(b[i].s[j]-'a'+1);
                pii Res=mk(Ans1.val,Ans2.val) ;
                if(m.find(Res)!=m.end())pos=min(pos,m[Res]);
            }
            if(pos>=i){
                sum[i]++;sum[pos]--;
            }
        }
    }
    for(int i=1;i<=q;++i){
        sum[i]+=sum[i-1];
        cout << sum[i] << "\n";
    }

    cerr<< "Time: "<<clock()/1000.0 << "s\n";
    cerr<< "Memory: " << (&Mst-&Med)/1000000.0 << "MB\n";
    return 0;
}