【CF356E】Xenia and String Problem

· · 题解

传送门

上古时代水 3000 黑题,实际 2500 紫差不多。

如果我们设 a_i 是第 i 长的 gray 串的长度,不难得到:

\begin{cases}a_i=2a_{i-1}-1,i>1 \\ a_1=1\end{cases}

显然得 a_i=2^i-1.

也就意味着 gray 串的长度取值只有 O(\log n) 种,总的串数量为 O(n\log n) 级别的。

瞬间就会发现可以 O(n\log n) 用倍增+哈希+桶的形式计算出所有 gray 串。

考虑枚举修改的位置和修改后该位置的字符,一共有 26n 种方案。此时首先发现我们应该需要计算原串中所有跟第 i 个位置有关的 gray 串的权值之和,还需要知道当第 i 个字符换成字符 j 的时候新产生的 gray 串权值之和。分别设为 w(i),g(i,j).

显然考虑 w(i) 会简单一点:考虑 i 所在的所有 gray 串发现不是很好搞。换个方向,枚举所有 gray 串,假设 s[i...j] 是一个 gray 串。那么 g(i)g(j) 都应该加上 (j-i+1)^2. 静态区间修改上差分,于是在 O(n \log n) 的时间内解决了这部分。

然后考虑计算 g(i,j). 我们同样还是反过来算贡献。枚举 O(n\log n) 个长度为 2^k-1 的区间 [i,j],考虑如果它为 gray 串可以通过修改哪些位置得到,这些位置的对应字符答案增加。第一眼,似乎不能快速确定哪些位置修改后合法。分讨一下修改中间位置和修改两边:

显然这部分 g 的计算是 O(26n\log n) 的。

最后求 \max\{sum-w(i)+g(i,j)\} 即可,sum 是原串的美丽值,可以在开始算出所有 gray 串的时候顺带搞出来。

所以总复杂度为 O(26n\log n).

#include<bits/stdc++.h>
#define rep(i,a,b) for(ll i=(a);i<=(b);i++)
#define per(i,a,b) for(ll i=(a);i>=(b);i--)
#define op(x) ((x&1)?x+1:x-1)
#define odd(x) (x&1)
#define even(x) (!odd(x))
#define lc(x) (x<<1)
#define rc(x) (lc(x)|1)
#define lowbit(x) (x&-x)
#define Max(a,b) (a>b?a:b)
#define Min(a,b) (a<b?a:b)
#define next Cry_For_theMoon
#define il inline
#define pb(x) push_back(x)
#define is(x) insert(x)
#define sit set<int>::iterator
#define mapit map<int,int>::iterator
#define pi pair<int,int>
#define ppi pair<int,pi>
#define pp pair<pi,pi>
#define fr first
#define se second
#define vit vector<int>::iterator
#define mp(x,y) make_pair(x,y)
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef double db;
using namespace std;
const int MAXN=1e5+10,base=13331;
int n,valid[MAXN][27];
int bucket[MAXN][27]; //前缀桶 
char s[MAXN];
ull power[MAXN],preh[MAXN]; //前缀hash值
ll w[MAXN],g[MAXN][27]; //原有贡献,修改增加贡献
ll ans; 
il ull gethash(int i,int j){
    //hash(S[i..j])
    if(i>j)return 0;
    return preh[j]-preh[i-1]*power[j-i+1];
}
il int getcnt(int i,int j,char ch){
    //ch在S[i..j]中的出现次数
    if(i>j)return 0;
    return bucket[j][ch-'a'+1]-bucket[i-1][ch-'a'+1]; 
}
int main(){
    scanf("%s",s+1);n=strlen(s+1);
    power[0]=1;
    rep(i,1,n){
        preh[i]=preh[i-1]*base+(s[i]-'a'+1);
        power[i]=power[i-1]*base;
        rep(j,1,26)bucket[i][j]=bucket[i-1][j];
        bucket[i][s[i]-'a'+1]++;
    }
    //预处理所有gray串
    rep(i,1,n)valid[i][1]=1;
    rep(k,2,29){
        rep(i,1,n){
            if(i+(1<<k)-2>n)break;
            int j=i+(1<<k)-2;int mid=(i+j)/2; 
            valid[i][k]=(getcnt(i,j,s[mid])==1) && (valid[i][k-1] && valid[mid+1][k-1]) && 
            (gethash(i,mid-1)==gethash(mid+1,j));
        }
    } 
    //计算w
    rep(i,1,n){
        rep(k,1,19){
            int len=(1<<k)-1;
            int j=i+len-1;
            if(valid[i][k])ans+=(ll)len*len,w[i]+=(ll)len*len,w[j+1]-=(ll)len*len;
        }
    } 
    rep(i,1,n)w[i]+=w[i-1];
    //计算g 
    rep(i,1,n){
        rep(j,1,26){g[i][j]++;} //长度为1的字符 
        rep(k,2,19){
            int len=(1<<k)-1;int j=i+len-1;int mid=(i+j)/2;
            if(j>n)break;
            //修改s[i...j]使其变成gray串 
            if(gethash(i,mid-1)==gethash(mid+1,j)){
                //改中间 
                if(valid[i][k-1] && valid[mid+1][k-1]){
                    rep(x,1,26){
                        if(!getcnt(i,mid-1,'a'+x-1))g[mid][x]+=(ll)len*len;
                    } 
                }
            }else{
                //改两边的唯一不同字符
                //二分查找这个位置
                int L=1,R=(1<<(k-1))-1,ret=0;
                while(L<=R){
                    int MID=(L+R)>>1;
                    if(gethash(i,i+MID-1)!=gethash(mid+1,mid+MID)){
                        ret=MID;
                        R=MID-1;
                    }else{
                        L=MID+1;
                    }
                }
                //s[i+ret-1]和s[mid+ret]不同 
                if(gethash(i+ret,mid-1) != gethash(mid+ret+1,j))continue;

                //1.左改右
                if(valid[mid+1][k-1]){
                    if(!getcnt(mid+1,j,s[mid]))g[i+ret-1][s[mid+ret]-'a'+1]+=(ll)len*len;
                } 
                //2.右改左
                if(valid[i][k-1]){
                    if(!getcnt(i,mid-1,s[mid]))g[mid+ret][s[i+ret-1]-'a'+1]+=(ll)len*len;
                } 
            }
        }
    } 
    ll t=0;
    rep(i,1,n){
        ll tmp=ans;
        rep(j,1,26)tmp=Max(tmp,ans-w[i]+g[i][j]);
        t=Max(t,tmp);
    }
    ans=t;
    printf("%lld",ans);
    return 0;
}