题解 P6629 【[ZJOI2020] 字符串】

· · 题解

首先,根据runs的那套理论,位置不同的本原平方串(即最小整周期恰为 len \over 2 的串)是 O(n \log n) 的,显然所有平方串都是由某一个本原平方串复制若干遍得到的,然后我们先用传统的枚举周期SA/二分hash的方法求出若干的周期三元组 (x,l,r) ,代表在区间 [l,r] 中任意一个长为 x 的倍数的子串整周期均为 x ,然后如果串 [l,l+2x-1] 是循环串,则 x 不为最小整周期,那么这个周期三元组可以直接去掉。

然后我们考虑这一个周期三元组的贡献,我们记贡献组 (x,y,v) 表示它对满足 [x,y] \in [l',r'] 的询问区间 [l',r']v 的贡献。那么对于一个周期三元组 (x,l,r) ,长度为 2x 的平方串贡献就为 (a,a+2x-1,1) , a \in [l,r-2x+1](a,a+3x-1,-1) , a \in [l,r-3x+1] ,同理,长度为 2kx 的平方串贡献为 (a,a+2kx-1,1) , a \in [l,r-2kx+1](a,a+(2k+1)x-1,-1) , a \in [l,r-(2k+1)x+1] ,根据runs那套理论,这显然是不超过 O(n \log n) 的。

还有一种情况就是不同周期三元组之间可能产生相同的平方串,这种情况的处理是容易的:首先你从前到后考虑每个周期三元组,然后对每个周期三元组找出它包含的所有长度为 2x 倍数的平方串(这和位置不同的本原平方串数量相等),然后找出它在这之前最后出现在哪里(在处理前面的时候记一下即可)还有在这个周期三元组第一次出现在哪里,然后把重复的贡献减掉即可。

然后对每个询问计算贡献就是一个斜线加矩形求和的二维数点,容易实现。

复杂度 O(n \log^2 n)

代码(二分hash版本)

#include<bits/stdc++.h>
#define re register
#define ull unsigned long long
char ss[200100];
const ull base=61,inv=5745707170499696405llu;
ull hs[200100],ex[200100],rex[200100];
int n;
inline ull get(re int l,re int r)
{
    return hs[r]-hs[l-1]*ex[r-l+1];
}
inline bool equ(re int a,re int b,re int len)
{
    return get(a,a+len-1)==get(b,b+len-1);
}
int LCP(re int a,re int b)
{
    re int l=0,r=std::min(n-a+1,n-b+1);
    for(re int mid=(l+r+1)/2;l<r;mid=(l+r+1)/2)
    {
        if(equ(a,b,mid))l=mid;
        else r=mid-1;
    }
    return l;
}
int LCS(re int a,re int b)
{
    re int l=0,r=std::min(a,b);
    for(re int mid=(l+r+1)/2;l<r;mid=(l+r+1)/2)
    {
        if(equ(a-mid+1,b-mid+1,mid))l=mid;
        else r=mid-1;
    }
    return l;
}
std::vector<std::pair<int,int> >vc[200100],mc[200100];
std::vector<int>qry[400100];
int cnt,l[200100],r[200100],f[200100];
long long ans[200100];
std::unordered_map<ull,int>mp;
void add(re int x,re int ad)
{
    for(;x;x-=x&-x)f[x]+=ad;
}
struct par
{
    int x,y,v;
}pc1[10001000],pc2[10001000];
inline bool cmp(const par&A,const par&B){return A.x<B.x;}
int cct1,cct2; 
int qr(re int x)
{
    re int ans=0;
    for(;x<=n;x+=x&-x)ans+=f[x];
    return ans;
}
int ans1;
void adp(re int x1,re int x2,re int y1,re int y2,re int v)
{
    pc1[++cct1]={y1-x1,x1,v};
    pc1[++cct1]={y1-x1,x2+1,-v};
    pc2[++cct2]={y1-x1,y1,v};
    pc2[++cct2]={y1-x1,y2+1,-v};
}
int f1[200100],f2[200100];
void add1(re int a,re int ad)
{
    re int ad1=ad*a;
    for(;a<=n;a+=(a&-a))f1[a]+=ad,f2[a]+=ad1;//printf("**ad**%d\n",n);
}
long long sum(re int a)
{
    re int ans=0;
    re int a1=a;
    for(;a;a-=(a&-a))ans+=(a1+1)*f1[a]-f2[a];
    return ans;
}
int main()
{
    re int q,x,y;
    scanf("%d%d",&n,&q);
    scanf("%s",ss+1);
    ex[0]=rex[0]=1;
    for(re int i=1;i<=n;i++)hs[i]=hs[i-1]*base+(ss[i]-'a'+1),ex[i]=ex[i-1]*base,rex[i]=rex[i-1]*inv;
    for(re int i=1;i<=n;i++)
    {
        for(re int j=i;j+i<=n;j+=i)
        {
            re int ls=LCS(j,j+i),lp=LCP(j+1,j+i+1);
            if(ls==0||ls>i||ls+lp<i)continue;
            re int bg=j-ls+1,ed=j+i+lp;
            for(auto x:vc[bg])if((!(i%x.first))&&x.second>=i){ls=0;break;}
            if(ls==0)continue;
            for(re int i1=bg;i1+2*i-1<=ed;i1++)
            {
                vc[i1].push_back(std::make_pair(i,ed-i1+1));
                re ull nw=get(i1,i1+2*i-1);
                re int nv;
                if(!mp[nw])mp[nw]=++cnt;
                nv=mp[nw];
                re int len=(i1-bg)/i;
                if(!(len&1))
                {
                    len=len/2;
                    if(mc[nv].size()>len)
                    {
                        adp(mc[nv][len].first,mc[nv][len].first,i1+2*i-1,i1+2*i-1,-1);
                        mc[nv][len]=std::make_pair(i1,i1+2*i-1);
                    }
                }
            }
            for(re int i1=ed-2*i+1;i1>=bg;i1--)
            {
                re ull nw=get(i1,i1+2*i-1);
                re int nv;
                if(!mp[nw])mp[nw]=++cnt;
                nv=mp[nw];
                re int len=(ed-2*i+1-i1)/i;
                if(!(len&1))
                {
                    len=len/2;
                    if(mc[nv].size()>len)
                    {
                        mc[nv][len]=std::make_pair(i1,i1+2*i-1);
                    }
                    else mc[nv].push_back(std::make_pair(i1,i1+2*i-1));
                }
            }
            for(re int i1=0,vv=1;bg+i1+2*i-1<=ed;i1+=i,vv=-vv)
            {
                adp(bg,ed-i1-2*i+1,bg+i1+2*i-1,ed,vv);
            }
        }
    }
    for(re int i=1;i<=q;i++)
    {
        scanf("%d%d",&l[i],&r[i]);
        qry[r[i]-l[i]].push_back(i);
    }
    std::sort(pc1+1,pc1+cct1+1,cmp);
    std::sort(pc2+1,pc2+cct2+1,cmp);
    re int ta=1;
    for(re int i=0;i<=200000;i++)
    {
        while(ta<=cct1&&pc1[ta].x==i)
        {
        //  printf("**a**%d %d %d\n",i,pc1[ta].y,pc1[ta].v);
            add1(pc1[ta].y,pc1[ta].v);
            ta++;
        }
        for(auto x:qry[i])
        {
            ans[x]-=sum(l[x]-1);
        }
    }
    ta=1;
    memset(f1,0,sizeof(f1));
    memset(f2,0,sizeof(f2));
    for(re int i=0;i<=200000;i++)
    {
        while(ta<=cct2&&pc2[ta].x==i)
        {
            add1(pc2[ta].y,pc2[ta].v);
            ta++;
        }
        for(auto x:qry[i])
        {
            ans[x]+=sum(r[x]);
        }
    }
    for(re int i=1;i<=q;i++)printf("%lld\n",ans[i]);
}