题解 P4324 【[JSOI2016]扭动的回文串】

2018-11-04 16:38:45


这题网上的题解都是manacher+Hash+二分,复杂度为$n\log n$的。(如果只是大力Hash+二分不也是$n\log n$的吗

1、网上其他题解的思路:

对$A$、$B$串各跑一遍manacher,求出第1、2类扭动回文串的最大长度。

考虑第三类的扭动回文串$S(i,j,k)$,一定可以表示为$A(i,l)+A(l+1,j)+B(j,k)$或$A(i,j)+B(j,l)+B(l+1,k)$,其中,第一段与第三段对称(第一段正着Hash和第三段反着Hash相同),第二段是一个回文子串,三段都可以是空串。

在$A$、$B$上枚举扭动的回文串的中心$mid$,不难发现,以其在原串上能扩展出的最长回文子串为第二段进行扩展一定最优。对每个$mid$,以最长回文子串作为第二段,在$A$、$B$上二分第一、三段的长度,其结果一定最优。

枚举中心$O(n)$,二分扩展$O(\log n)$,时间复杂度$O(n\log n)$。

2、不用manacher

其实是因为我总记不住manacher怎么写,而且manacher经常可以用Hash(比manacher多个$log$)或回文树替代

求以mid为中心的最长回文子串时,用二分+Hash替代manacher。常数一下就大了好多,不开O2光荣TLE,开O2就能过了。

#include<cstdio>
#include<utility>
#include<algorithm>
#define LL long long
#define PLL std::pair<LL,LL>
#define BASE 131
#define MOD0 1000000007
#define MOD1 1000000009
#define MODP PLL(MOD0,MOD1)
#define PBASE PLL(BASE,BASE)

PLL operator * (PLL a,PLL b)
{
    return PLL(a.first*b.first,a.second*b.second);
}

PLL operator + (PLL a,PLL b)
{
    return PLL(a.first+b.first,a.second+b.second);
}

PLL operator - (PLL a,PLL b)
{
    return PLL(a.first-b.first,a.second-b.second);
}

PLL operator % (PLL a,PLL b)
{
    return PLL(a.first%b.first,a.second%b.second);
}

PLL base_pow[200005];

void init_base_pow(int len)
{
    base_pow[0].first=base_pow[0].second=1;
    for(int i=1; i<=len; ++i)
    {
        base_pow[i]=base_pow[i-1]*PBASE%MODP;
    }
}

char input[100005];

struct str
{
    char s[200005];
    PLL hash[2][200005];
    int pal_len[200005];
    void insert_$(char T[],int &len)
	{
		s[1]='$';
        for(int i=1; i<=len; ++i)
        {
            s[i<<1]=T[i];
            s[i<<1|1]='$';
		}
		len=len<<1|1;
	}

	void get_hash(int len)
	{
		for(int i=1; i<=len; ++i)
		{
			hash[0][i]=(hash[0][i-1]*PBASE+PLL(s[i]+1,s[i]+1))%MODP;
		}
		for(int i=len; i>=1; --i)
		{
			hash[1][i]=(hash[1][i+1]*PBASE+PLL(s[i]+1,s[i]+1))%MODP;
		}
	}

	PLL seg_hash(int l,int r)
	{
		if(l>=r)
		{
			return ((hash[1][r]-hash[1][l+1]*base_pow[l-r+1])%MODP+MODP)%MODP;
		}
		else
		{
			return ((hash[0][r]-hash[0][l-1]*base_pow[r-l+1])%MODP+MODP)%MODP;
		}
	}

	void find_pal(int len)
	{
		for(int i=1; i<=len; ++i)
		{
			int l_len=i,r_len=len-i+1;
			int l=1,r=std::min(l_len,r_len)+1,mid=(l+r)>>1;
			while(l<r)
			{
				if(seg_hash(i,i+mid-1)!=seg_hash(i,i-mid+1))
				{
					r=mid;
				}
				else
				{
					l=mid+1;
				}
				mid=(l+r)>>1;
			}
			pal_len[i]=(mid-1)*2-1;
		}
	}
	
	void init(int len)
	{
		scanf("%s",input+1);
		insert_$(input,len);
        get_hash(len);
        find_pal(len);
    }
} A,B;

int main()
{
    int n;
    scanf("%d",&n);
    init_base_pow(n<<1|1);
    A.init(n);
    B.init(n);
    int ans=0;
    n=n<<1|1;
    for(int i=1;i<=n;++i)
    {
        int l_st=i-(A.pal_len[i]>>1)-1;
        int r_st=i+(A.pal_len[i]>>1)-1;
        int l=1,r=std::min(l_st,n-r_st+1)+1,mid=(l+r)>>1;
        while(l<r)
        {
            if(A.seg_hash(l_st,l_st-mid+1)!=B.seg_hash(r_st,r_st+mid-1))
            {
                r=mid;
            }
            else
            {
                l=mid+1;
            }
            mid=(l+r)>>1;
        }
        ans=std::max(ans,mid-1+(A.pal_len[i]>>1));
    }
    for(int i=1;i<=n;++i)
    {
        int l_st=i-(B.pal_len[i]>>1)+1;
        int r_st=i+(B.pal_len[i]>>1)+1;
        int l=1,r=std::min(l_st,n-r_st+1)+1,mid=(l+r)>>1;
        while(l<r)
        {
            if(A.seg_hash(l_st,l_st-mid+1)!=B.seg_hash(r_st,r_st+mid-1))
            {
                r=mid;
            }
            else
            {
                l=mid+1;
            }
            mid=(l+r)>>1;
        }
        ans=std::max(ans,mid-1+(B.pal_len[i]>>1));
    }
    printf("%d",ans);
    return 0;
}