题解:P8147 [JRKSJ R4] Salieri

· · 题解

前言

因为 csp 被 AC 自动机干爆所以决定做一下模板题

思路

首先我们会一个很简单的暴力我们将 s_i 插入自动机中,然后对于每个询问的 S 串先在 AC 自动机上走,然后我们可以发现对于 S 走到的每一个点都把 x\to rt 的路径上的点都加一(因为出现次数多了一次)然后对于一个 s_i 我们在每次加完之后只需要查询这个串末尾节点的值即可,这样我们就会了 O(Lq\log{L}) 的做法,考虑优化。

通过观察我们发现对于一次修改我们得到的 cnt 一共只有 |S| 个然后我们又有 \sum |S|\leq 5\times 10^5,通过这个不难想到虚树,我们发现对于修改一条链的时候中间那些点的 cnt 都是一样的所以对于他们来说需要达到的下限也是一样的,所以我们可以考虑对 S 走到的点建立虚树,然后二分答案 x 对于每一条链我们都可以求出所需的下限 val 于是这个问题就转化成了查询一条链上 w_i\geq val 的有多少个,这个用主席树不难维护,然后就做完了,记得清空就行了。

代码

细节见代码。

#include <bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace std;
#define pb push_back
#define rep(i,x,y) for(register int i=x;i<=y;i++)
#define rep1(i,x,y) for(register int i=x;i>=y;--i)
#define int long long
#define fire signed
#define il inline
template<class T> il void print(T x) {
    if(x<0) printf("-"),x=-x;
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
template<class T> il void in(T &x) {
    x = 0; char ch = getchar();
    int f = 1;
    while (ch < '0' || ch > '9') {if(ch=='-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
int T=1;
const int N=5e5+10,M=1e3;
struct node{
    int fl;
    int s[5];
}tr[N];
int idx=1;
vector<int>v[N];
int rt[N];
struct nn{
    int l,r;
    int cnt;
}t[N*20];
int tot;
vector<int>ve[N];
void up(int x) {
    t[x].cnt=t[t[x].l].cnt+t[t[x].r].cnt;
}
int modify(int u,int l,int r,int k) {
    int p=++tot;
    t[p]=t[u];
    if(l==r) {
        t[p].cnt++;
        return p;
    }
    int mid=l+r>>1;
    if(mid>=k) t[p].l=modify(t[u].l,l,mid,k);
    else t[p].r=modify(t[u].r,mid+1,r,k);
    up(p);
    return p;
}
int Ans(int u,int l,int r,int k) {
    if(k>r) return false;
    if(k<=l) return t[u].cnt;
    if(!u) return false;
    if(t[u].l==t[u].r) return t[u].cnt;
    int mid=l+r>>1;
    if(mid>=k) return t[t[u].r].cnt+Ans(t[u].l,l,mid,k);
    return Ans(t[u].r,mid+1,r,k);
}
int w[N];
void modify(string s,int w) {
    int p=1;
    for(auto to:s) {
        if(!tr[p].s[to-'a']) tr[p].s[to-'a']=++idx;
        p=tr[p].s[to-'a'];
    }
    ve[p].pb(w);
}
void get() {
    queue<int>q;
    rep(i,0,3) tr[0].s[i]=1;
    q.push(1);
    while(q.size()) {
        int x=q.front(),f=tr[x].fl;
        q.pop();
        rep(i,0,3) {
            int &v=tr[x].s[i];
            if(v) {
                tr[v].fl=tr[f].s[i];
                q.push(v);
            }else v=tr[f].s[i];
        }
    }
}
int dfn[N],tt;
int f[N][20];
void init() {
    rep(j,1,19) {
        rep(i,1,idx) {
            f[i][j]=f[f[i][j-1]][j-1];
        }
    }
}
int dep[N];
int lca(int x,int y) {
    if(dep[x]>dep[y]) swap(x,y);
    rep1(i,19,0) if(dep[f[y][i]]>=dep[x]) y=f[y][i];
    if(x==y) return x;
    rep1(i,19,0) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
void dfs(int x,int fa) {
    dfn[x]=++tt;
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(auto to:v[x]) {
        if(to==fa) continue;
        rt[to]=rt[x];
        for(auto to1:ve[to]) rt[to]=modify(rt[to],1,M,to1);
        dfs(to,x);
    }
}
bool cmp(int x,int y) {
    return dfn[x]<dfn[y];
}
vector<int>g[N];
int tag[N];
void dfs1(int x,int fa) {
    for(auto to:g[x]) {
        if(to==fa) continue;
        dfs1(to,x);
        tag[x]+=tag[to];
    }
}
int cnt=0;
void dfs2(int x,int fa,int mid) {
    for(auto to:g[x]) {
        if(to==fa) continue;
        dfs2(to,x,mid);
        int now=(tag[to]+mid-1)/tag[to];
        cnt+=Ans(rt[to],1,M,now)-Ans(rt[x],1,M,now);
    }
}
int check(int x) {
    cnt=0;
    dfs2(1,0,x);
    return cnt;
}
int n,m;
void solve() {
    in(n),in(m);
    rep(i,1,n) {
        string s;
        cin>>s;
        int w;
        in(w);
        modify(s,w);
    }
    get();
    rep(i,1,idx) v[tr[i].fl].pb(i);
    dfs(1,0);
    init();
    while(m--) {
        string s;
        cin>>s;
        int k;
        in(k);
        int p=1;
        vector<int>arr;
        arr.pb(1);
        for(auto to:s) {
            p=tr[p].s[to-'a'];
            arr.pb(p);
            tag[p]++;
        }
        sort(arr.begin(),arr.end(),cmp);
        int len=arr.size()-1;
        rep(i,0,len-1) {
            arr.pb(lca(arr[i],arr[i+1]));
        }
        sort(arr.begin(),arr.end());
        arr.erase(unique(arr.begin(),arr.end()),arr.end());
        sort(arr.begin(),arr.end(),cmp);
        len=arr.size()-1;
        for(auto to:arr) g[to].clear();
        rep(i,0,len-1) {
            int dis=lca(arr[i],arr[i+1]);
            g[dis].clear();
        }
        rep(i,0,len-1) {
            int dis=lca(arr[i],arr[i+1]);
            g[dis].pb(arr[i+1]);
        }
        dfs1(1,0);
        int l=1,r=5e8,res=0;
        while(l<=r) {
            int mid=l+r>>1;
            if(check(mid)>=k) l=mid+1,res=mid;
            else r=mid-1;
        }
        for(auto to:arr) tag[to]=false;
        printf("%lld\n",res);
    }
}
fire main() {
    while(T--) {
        solve();
    }
    return false;
}

能不能让我过 csp t3 啊,求你了。