题解:P14363 [CSP-S 2025] 谐音替换 / replace(民间数据)

· · 题解

sqrt 做法(

能不能过官方存疑,现在过了熨斗的

upd:洛谷过了

观察到 t_1,t_2 一定是形如一段相同 + 一段不同 + 一段相同的形式,一个合法的替换一定也是这种形式,并且要满足 t_1,t_2 的不同等价于 s_1,s_2 的不同,t_1 能够分别匹配上 s_1,s_2 不同段之外的前缀和后缀。

考虑把一种相同的 替换中的不同段 视为一组替换,那么对于 t_1,t_2 的合法替换也就是在 其中正好符合 t_1,t_2 的那一组替换 中 找出能够同时匹配 t_1 替换部分之前的后缀 和 t_1 替换部分之后的前缀。

接下来的部分就是考场犯蠢。尝试把 s_1 替换部分的 之前部分子串 和 之后部分子串 分别扔进两个不同的 trie 树,并且 每一组 单独开两课这样的 trie。在 trie 的每个节点上标记结束的区间,那么答案就是 t_1 在两个 trie 树上走之后经过的 两个点集的并 的大小。这个东西暴力 vector 是 O(nq) 的,不满。2G 所以不用担心 trie 的空间开销。

考虑到 \sum L,尝试根号分治。对于 不同部分之外的前后缀部分 长度比较小的 s_1,考虑直接记录前后缀的 hash,之后在查询的时候小段暴力查询是否存在这样的 hash 值,长度大的仍然扔进 trie 里。设取的分界长度为 B,复杂度 O(qB^2+\max(\sum L,q\min(n,\frac{\sum L}{B})))。看起来并不可过,但是要除掉一些常数(中间的不同部分、前后缀长度要减半等,而且不能单开一组数据卡我这个吧再减半)相当不满,极限应该是不到 1e9 的。

球球大家都不要卡我

记得判断 |t_1|\ne |t_2|。考场没判。

::::::success[Code]

#include<bits/stdc++.h>
#define N 200005
#define S 5200005
#define P 13331
#define mod 1000000027
#define B 50
using namespace std;
struct Trie{
    int trans[S][26];
    int tot;
    int root[N];
    vector<int> ed[S];
    inline void insert(int Rt,char *s,int L,int R,int id,bool o){
        if(!root[Rt]) root[Rt]=++tot;
        int now=root[Rt];
        if(o)for(int i=L;i<=R;i++){
            if(!trans[now][s[i]-'a']) trans[now][s[i]-'a']=++tot;
            now=trans[now][s[i]-'a'];
        }
        else for(int i=R;i>=L;i--){
            if(!trans[now][s[i]-'a']) trans[now][s[i]-'a']=++tot;
            now=trans[now][s[i]-'a'];
        }
        ed[now].push_back(id);
    }
    inline void query(int Rt,char *s,int L,int R,vector<int> &v,bool o){
        int now=root[Rt];
        for(auto nx:ed[now]) v.push_back(nx);
        if(o)for(int i=L;i<=R;i++){
            now=trans[now][s[i]-'a'];
            if(!now) break;
            for(auto nx:ed[now]) v.push_back(nx);
        }
        else for(int i=R;i>=L;i--){
            now=trans[now][s[i]-'a'];
            if(!now) break;
            for(auto nx:ed[now]) v.push_back(nx);
        }
    }
}fr,ed;
map<pair<int,int>,int> mp;
int cnt;
inline int grt(int h1,int h2){
    if(mp.count({h1,h2})) return mp[{h1,h2}];
    return mp[{h1,h2}]=++cnt;
}
int bs[S];
int hsh[S];
inline int geths(int L,int R){
    if(L>R) return 0;
    return (hsh[R]-1ll*hsh[L-1]*bs[R-L+1]%mod+mod)%mod;
}
inline int ghs(char *s,int l,int r){
    int hs=0;
    for(int i=l;i<=r;i++) hs=(1ll*hs*P+s[i])%mod;
    return hs;
}
char s[S],t[S];
struct DS{
    int col[N],val[N],now;
    int &operator[](int x){
        if(col[x]!=now) col[x]=now,val[x]=0;
        return val[x];
    }
    void clear(){
        now++;
    }
}ds;
map<pair<int,int>,int> hscnt[N];
signed main(){
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    bs[0]=1;
    for(int i=1;i<S;i++) bs[i]=1ll*bs[i-1]*P%mod;
    int n,q;
    cin>>n>>q;
    for(int i=1;i<=n;i++){
        cin>>(s+1)>>(t+1);
        int m=strlen(s+1);
        int L=0,R=m+1;
        while(L+1<=m and s[L+1]==t[L+1]) L++;
        while(R-1>=1 and s[R-1]==t[R-1]) R--;
        if(L>=R) continue;
        int h1=ghs(s,L+1,R-1),h2=ghs(t,L+1,R-1);
        int id=grt(h1,h2);
        if(L<=B and R>=m-B+1){
            hscnt[id][{ghs(s,1,L),ghs(s,R,m)}]++;
            continue;
        }
        fr.insert(id,s,1,L,i,0);
        ed.insert(id,t,R,m,i,1);
    }
    while(q--){
        cin>>(s+1)>>(t+1);
        int m=strlen(s+1);
        int L=0,R=m+1;
        while(L+1<=m and s[L+1]==t[L+1]) L++;
        while(R-1>=1 and s[R-1]==t[R-1]) R--;
        int sum=0;
        if((int)strlen(t+1)!=m){
            cout<<0<<'\n';
            continue;
        }
        int h1=ghs(s,L+1,R-1),h2=ghs(t,L+1,R-1);
        if(!mp.count({h1,h2})){
            cout<<0<<'\n';
            continue;
        }
        int id=grt(h1,h2);
        for(int i=1;i<=m;i++) hsh[i]=(1ll*hsh[i-1]*P+s[i])%mod;
        for(int i=max(1,L-B+1);i<=L+1;i++){
            for(int j=R-1;j<=min(R+B-1,m);j++){
                sum+=hscnt[id][{geths(i,L),geths(R,j)}];
            }
        }
        if(m<=B){
            cout<<sum<<'\n';
            continue;
        }
        vector<int> pre,suf;
        fr.query(id,s,1,L,pre,0);
        ed.query(id,t,R,m,suf,1);
        ds.clear();
        for(auto nx:pre) ds[nx]++;
        for(auto nx:suf) sum+=ds[nx];
        cout<<sum<<'\n';
    }
    return 0;
}

::::::