题解 P4770 【[NOI2018]你的名字】

· · 题解

Portal

我居然能独立AC NOI的黑题,后缀数组果然可爱

我们将命名串与所有的询问串,中间插上从未出现的字符然后放在一起跑SA。则对于来自询问串的某一条后缀i,我们需要求出一个数组res_i,表示所有出现在命名串以及该询问串中的后缀,与其\operatorname{LCP}的最大值。则对于一条询问串T,它的答案即为\sum\limits_{i\in T}len_i-res_i,其中len_i为后缀i的长度。

则现在问题被转换为求出所有的res_i。我们先考虑来自命名串的贡献。设当前我们讨论的后缀为suf[i]。则我们要求的即为区间[l,r]内所有后缀suf[j],求\min\Big(\operatorname{LCP}(suf[i],suf[j]),r-j+1\Big)的最大值。

我们考虑二分这个最大值,设为mid。现在要来判断mid这个值是否合法。则只有区间[l,l+mid-1]中的\operatorname{LCP}才可能达到这么长,故只要找到其中的\max\text{LCP},如果其长度大于等于mid,则mid合法。

到现在我们已经可以构思出一个O(n\log^2n)的二分套线段树的做法了。具体思路是,因为\max\text{LCP}一定在两个后缀的rk最接近时取到,所以我们将它拆成两半,一半是rk_j\leq rk_i的,一半是rk_j\geq rk_i的,并写两颗线段树分别维护。

则我们需要按顺序将位置插入线段树并统计答案,在前一棵中查询区间中rk_i的最大值,后一棵中查询rk_i的最小值(这里的线段树是以原串位置为下标的)。在查询到这个最大值/最小值后,就可以通过ST表求出\text{LCP}了。当然,这一切都是建立在二分的基础上,即,我们每次询问的区间都是二分出来的区间[l,l+mid-1]

但是这样子得写两棵线段树,再加上ST表,太难受了。我们不如这样,直接在线段树上维护\text{LCP}长度。办法很简单,当新加入一个位置后,直接将线段树中所有\text{LCP}长度与它取\min即可。这样,ST表可以省掉了,两颗线段树要支持的操作也一致了(全局取\min,区间求\max)。

在这样写后,我们发现干脆连二分都可以省掉了(因为这里线段树的对应位置储存的值就是真实答案,所以可以省掉),直接在线段树上二分即可(这部分线段树上二分的代码比较神奇,建议看一下代码)。

(不知道大家有没有做过[HEOI2016/TJOI2016]字符串这道题,实际上思路是差不多的)。

然后就是来自该询问串内部的其它后缀的\operatorname{LCP},这个也可以通过类似手法写出来。

则总复杂度O(n\log n)

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2001000;
int n,m,q,t,id[N],res[N],len[N];
ll ans[N];
int x[N],y[N],buc[N],sa[N],ht[N],rk[N],s[N];
char str[N];
bool mat(int a,int b,int k){
    if(y[a]!=y[b])return false;
    if((a+k<n)^(b+k<n))return false;
    if((a+k<n)&&(b+k<n))return y[a+k]==y[b+k];
    return true;
}
void SA(){
    for(int i=0;i<n;i++)buc[x[i]=s[i]]++;
    for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
    for(int i=n-1;i>=0;i--)sa[--buc[x[i]]]=i;
    for(int k=1;k<n;k<<=1){
        int num=0;
        for(int i=n-k;i<n;i++)y[num++]=i;
        for(int i=0;i<n;i++)if(sa[i]>=k)y[num++]=sa[i]-k;
        for(int i=0;i<=m;i++)buc[i]=0;
        for(int i=0;i<n;i++)buc[x[y[i]]]++;
        for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
        for(int i=n-1;i>=0;i--)sa[--buc[x[y[i]]]]=y[i];
        swap(x,y);
        x[sa[0]]=num=0;
        for(int i=1;i<n;i++)x[sa[i]]=mat(sa[i],sa[i-1],k)?num:++num;
        if(num>=n-1)break;
        m=num;
    }
    for(int i=0;i<n;i++)rk[sa[i]]=i;
    for(int i=0,k=0;i<n;i++){
        if(!rk[i])continue;
        if(k)k--;
        int j=sa[rk[i]-1];
        while(i+k<n&&j+k<n&&s[i+k]==s[j+k])k++;
        ht[rk[i]]=k;
    }
}
#define lson x<<1
#define rson x<<1|1
#define mid ((l+r)>>1)
#define change(x,y) seg[x].mn=min(seg[x].mn,y),seg[x].tag=min(seg[x].tag,y)
struct SegTree{
    int tag,mn;
}seg[N<<2];
void build(int x,int l,int r){
    seg[x].tag=0x3f3f3f3f,seg[x].mn=0;
    if(l!=r)build(lson,l,mid),build(rson,mid+1,r);
}
void pushdown(int x){
    change(lson,seg[x].tag),change(rson,seg[x].tag),seg[x].tag=0x3f3f3f3f;
}
void turnon(int x,int l,int r,int P,int val){
    if(l>P||r<P)return;
    seg[x].mn=max(seg[x].mn,val);
    pushdown(x);
    if(l!=r)turnon(lson,l,mid,P,val),turnon(rson,mid+1,r,P,val);
}
int getans(int x,int l,int r,int L,int R){
    if(l==r)return min(seg[x].mn,R-r+1);
    pushdown(x);
    if(seg[lson].mn>=R-mid+1)return getans(lson,l,mid,L,R);
    else return max(seg[lson].mn,getans(rson,mid+1,r,L,R));
}
int query(int x,int l,int r,int L,int R,bool &findans){
    if(l>R||r<L)return -1;
    if(L<=l&&r<=R){
        if(seg[x].mn>=R-r+1){findans=true;return getans(x,l,r,L,R);}
        return seg[x].mn;
    }
    pushdown(x);
    int tmp=query(lson,l,mid,L,R,findans);
    if(findans)return tmp;
    return max(tmp,query(rson,mid+1,r,L,R,findans));
}
int pointask(int x,int l,int r,int P){
    if(l>P||r<P)return 0;
    if(l==r)return seg[x].mn;
    pushdown(x);
    return pointask(lson,l,mid,P)+pointask(rson,mid+1,r,P);
}
pair<int,int>p[500100];
void read(int &x){
    x=0;
    char c=getchar();
    while(c>'9'||c<'0')c=getchar();
    while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar(); 
}
int main(){
    scanf("%s",str),t=strlen(str);for(int i=0;i<t;i++)s[n++]=str[i]-'a'+1;
    read(q);
    for(int i=1;i<=q;i++){
        scanf("%s",str),read(p[i].first),read(p[i].second),m=strlen(str),p[i].first--,p[i].second--;
        id[n]=-1,s[n]=i+26,n++;
        for(int j=0;j<m;j++)id[n]=i,s[n]=str[j]-'a'+1,len[n]=m-j,n++;
    }
    m=q+26;
    SA();
    bool tmp;
    build(1,0,t-1);
    for(int i=1;i<n;i++){
        change(1,ht[i]);
        if(!id[sa[i-1]])turnon(1,0,t-1,sa[i-1],ht[i]);
        if(id[sa[i]]>=1)tmp=false,res[sa[i]]=max(res[sa[i]],query(1,0,t-1,p[id[sa[i]]].first,p[id[sa[i]]].second,tmp));
    }
    build(1,0,t-1);
    for(int i=n-1;i>=0;i--){
        if(id[sa[i]]>=1)tmp=false,res[sa[i]]=max(res[sa[i]],query(1,0,t-1,p[id[sa[i]]].first,p[id[sa[i]]].second,tmp));
        change(1,ht[i]);
        if(!id[sa[i]])turnon(1,0,t-1,sa[i],ht[i]);
    }
    build(1,1,q);
    for(int i=1;i<n;i++){
        change(1,ht[i]);
        if(id[sa[i-1]]>=1)turnon(1,1,q,id[sa[i-1]],ht[i]);
        if(id[sa[i]]>=1)res[sa[i]]=max(res[sa[i]],pointask(1,1,q,id[sa[i]]));
    }
    for(int i=0;i<n;i++)if(id[i]>=1)ans[id[i]]+=len[i]-res[i];
    for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
    return 0;
}