题解 P6095 【[JSOI2015]串分割】

· · 题解

后缀数组一个很好的作用就是可以用它来按照字典序二分字符串~~~

首先,我们发现,一组切割的结果只会有长度为\left\lceil\dfrac{n}{k}\right\rceil\left\lceil\dfrac{n}{k}\right\rceil-1的串——这是很显然的,因为只有这样才会让最大数尽量小。并且,最大数的长度必定是\left\lceil\dfrac{n}{k}\right\rceil。设A=\left\lceil\dfrac{n}{k}\right\rceil

假设我们确定了最大数是p,那我们应该如何判断p是否合法呢?

显然我们可以随便找一个位置断环成链。我们设一个指针tmp指向当前最后一个子串结尾的位置(初始为断点)。如果tmp开头的后缀的字典序大于等于后缀p的字典序,显然这时如果我们选取tmp开头的A个字符的话,p就不会是最大的那个数了,故这时不得不选取A-1个数;否则,我们可以贪心地选取A个字符截成一个子串。之后,更新tmp到新子串的末尾。

我们思考一下为什么这个贪心是正确的:假如因为这边截了A个字符,下一次不得不截A-1个字符,则总共少截了一个字符,同这边截A-1个字符的最优情况(这次A-1个,下次A个)相同;而如果下次居然还能截出A个,那么就肯定更优辣;综上,贪心是正确的。

因此我们只需要枚举这个断开的位置即可。看上去这是O(nk)的,但是因为每A个字符中必定会切一刀,故实际上只需要枚举前A个位置即可。A个位置,check每个位置都要用k次,共O(Ak)=O(n)

至于这个最大值吗,显然它具有单调性,可以二分;二分一个串,就是在后缀数组中二分rank。

因此总复杂度O(n\log n)

代码:

#include<bits/stdc++.h>
using namespace std;
const int N=400100;
int n,m,all,A;
int x[N],y[N],buc[N],sa[N],rk[N];
char s[N];
bool mat(int a,int b,int k){
    if(y[a]!=y[b])return false;
    if((a+k<n)^(b+k<n))return false;
    if((a+k<n)&&(b+k<n))return y[a+k]==y[b+k];
    return true;
}
void SA(){
    for(int i=0;i<n;i++)buc[x[i]=s[i]]++;
    for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
    for(int i=n-1;i>=0;i--)sa[--buc[x[i]]]=i;
    for(int k=1;k<n;k<<=1){
        int num=0;
        for(int i=n-k;i<n;i++)y[num++]=i;
        for(int i=0;i<n;i++)if(sa[i]>=k)y[num++]=sa[i]-k;
        for(int i=0;i<=m;i++)buc[i]=0;
        for(int i=0;i<n;i++)buc[x[y[i]]]++;
        for(int i=1;i<=m;i++)buc[i]+=buc[i-1];
        for(int i=n-1;i>=0;i--)sa[--buc[x[y[i]]]]=y[i];
        swap(x,y);
        x[sa[0]]=num=0;
        for(int i=1;i<n;i++)x[sa[i]]=mat(sa[i],sa[i-1],k)?num:++num;
        if(num>=n-1)break;
        m=num;
    }
    for(int i=0;i<n;i++)rk[sa[i]]=i;
}
bool che(int ip){
    for(int i=0;i<A;i++){
        int tmp=i;
        for(int j=0;j<all;j++){tmp+=A-(rk[tmp]>ip);if(tmp-i>=n)return true;}
    }
    return false;
}
int main(){
    scanf("%d%d",&n,&all),A=n/all+(n%all!=0);
    scanf("%s",s);
    for(int i=0;i<n;i++)s[n+i]=s[i];
    n<<=1,m='9';
    SA();
//  printf("%s\n",s);
//  for(int i=0;i<n;i++)printf("%d ",sa[i]);puts("");
    n>>=1;
//  for(int i=0;i<n;i++)printf("%d ",qwq[i]);puts("");
    int l=0,r=(n<<1)-1;
    while(l<r){
        int mid=(l+r)>>1;
        if(che(mid))r=mid;
        else l=mid+1;
    }
    for(int j=0;j<(n<<1);j++)if(rk[j]==r)for(int i=0;i<A;i++)putchar(s[j+i]);
    return 0;
}