题解 CF710F 【String Set Queries】

super蒟蒻

2020-07-23 21:45:40

Solution

解法:二进制分组 + AC自动机 啥是二进制分组? 就是把所有的操作按二进制分组,每一组里面就开个数据结构维护。 怎么分组呢? 因为是在线算法,所以这个东西有点像一个叫2048的游戏。 假设之前已经分好两组,每组的大小为 4,2(大小就是里面包含了多少个操作)。 加一个操作,那么就有三组 4,2,1;如果再加一个操作,就是 4,2,1,1。 此时有两个 1,就要把那两组暴力重构,合并成一组,得到 4,2,2;再合并 ...... 最终剩余一组,大小为 8 。 可以得到,如果有 n 个操作,那么至多会分为 $\log n$ 组,因为它分得的组数恰好就是 n 在二进制的表示下 1 的个数。(例如上面的 7 个操作被分成了 3 组) 那么暴力重构的时间复杂度呢? 考虑**每个点**会被重构多少次,如果把合并的过程画出来的话(如下图),你会发现这很像是一棵线段树,那么每个点至多被重构 $\log n$ 次。 ![](https://cdn.luogu.com.cn/upload/image_hosting/l242p5zc.png) 最后乘上用数据结构维护对应的复杂度就是重构的复杂度了。 回到这题,首先得会 `AC自动机` 并 A 了这道模板题:[【模板】AC自动机(二次加强版)](https://www.luogu.com.cn/problem/P5357) 然后你会发现这题反而简单点的,模板题是求每个模式串的贡献,这题只需求所有模式串的总贡献。 模板题并没有每次暴力跳 $fail$ 边而是打了个标记,那我们也无需暴力求,像弄个前缀和一样在建树时沿着 $fail$ 边累加就行。 对于插入操作,对它进行二进制分组,**每个组建一个AC自动机**,到时候查询就是把 $\log n$ 个AC自动机 $search$ 一遍就好了。 最后就是删除操作,可以发现贡献**具有可减性**。 什么意思? 就是我们可以求出所有串的贡献,用它减去已删除串的贡献,得到的就是答案。 这样一来只要对插入,删除的字符串**各自**进行二进制分组,然后相减即可。 AC自动机的建树是 $O(n)$ 的,所以时间复杂度是 $O(len \log n)$ , $len$ 表示输入的总串长。 第一次写写丑了,常数有点大。 如果用了 $printf$ 的话在后面记得加上这句话:`fflush(stdout);` ~~我也不知道为什么~~ $code:$ ```cpp // 二进制分组 + AC自动机 #include <cstdio> #include <string> #include <cstring> #include <queue> using namespace std; const int maxn = 300005; struct Group { string data[maxn]; int tot,ch[maxn][26],cnt[maxn],n; void insert(int r,char *s) { int d=s[1]-'a',len=strlen(s+1),_r=r; for(int i=1;i<=len;d=s[++i]-'a') { if(ch[r][d]==_r) { ch[r][d]=++tot; for(int j=0;j<26;j++) ch[tot][j]=_r; } r=ch[r][d]; } ++cnt[r]; } int fail[maxn]; void build(int r) { fail[r]=r; queue<int>q; for(int i=0;i<26;i++) if(ch[r][i]>r) q.push(ch[r][i]),fail[ch[r][i]]=r; while(!q.empty()) { int u=q.front(); q.pop(); for(int i=0;i<26;i++) { int v=ch[u][i]; if(v==r) ch[u][i]=ch[fail[u]][i]; else { fail[v]=ch[fail[u]][i]; cnt[v]+=cnt[ch[fail[u]][i]]; q.push(v); } } } } int search(int r,char *s) { int d=s[1]-'a',len=strlen(s+1),res=0; for(int i=1;i<=len;d=s[++i]-'a') r=ch[r][d],res+=cnt[r]; return res; } int stk[maxn],fr[maxn],size[maxn],N; void merge() { --N,size[N]<<=1; for(int i=stk[N];i<=tot;i++) { cnt[i]=fail[i]=0; for(int j=0;j<26;j++) ch[i][j]=0; } tot=stk[N]; for(int i=0;i<26;i++) ch[tot][i]=tot; for(int L=fr[N];L<=n;L++) insert(stk[N],&data[L][0]); build(stk[N]); } void Insert(char *s) { stk[++N]=++tot; for(int i=0;i<26;i++) ch[tot][i]=tot; size[N]=1; insert(tot,s),build(stk[N]); int len=strlen(s+1); data[fr[N]=++n]=" "; for(int i=1;i<=len;i++) data[n]+=s[i]; while(size[N]==size[N-1]) merge(); } int Count(char *s) { int ans=0; for(int i=1;i<=N;i++) ans+=search(stk[i],s); return ans; } }Add,Sub; char tmp[maxn]; int main() { int T; scanf("%d",&T); for(int i=1;i<=T;i++) { int opt; scanf("%d%s",&opt,tmp+1); if(opt==1) Add.Insert(tmp); else if(opt==2) Sub.Insert(tmp); else printf("%d\n",Add.Count(tmp)-Sub.Count(tmp)); fflush(stdout); } return 0; } ```