题解:P10593 BZOJ2958 序列染色

· · 题解

好题。

为什么只有绿?

题意

给定一个长度为 n 的只包含 'W','X','B' 的字符串,其中 'X' 必须变成 'W' 或 'B'。

给定 k,如果一个字符串中存在长为 k 的都为 'W' 的连续段,并且这个连续段之前存在长为 k 的都为 'B' 的连续段,那么这个字符串就是合法的。

求将所有 'X' 变成 'W' 或 'B' 后合法字符串个数。

思路

先考虑 'B' 段。

想一想怎么算不会算重。

我们只在每一个长为 k 的 'B' 段第一次出现时考虑,这样就不会算重了。

于是我们设 f_i 表示考虑前 i 个位置,其中 i 这个位置是字符串中第一个长度为 k 的连续 'B' 段的结尾,这样的字符串的方案数。

怎么转移呢?

先判断 i-k+1\sim i 是否可以全变成 'B',并且要求 i-k 位置不是 'B',前 i-k 个位置中不存在长为 k 的 'B' 段。

于是设 g_i 表示考虑前 i 个位置,并且前 i 个位置没有长为 k 的 'B' 段的方案数。

于是 f 就能从 g 转移了。

如果满足上述条件并且 i-k 位置为 'X',此时这个 'X' 必须变成 'W',那么:

f_i=g_{i-k-1}

否则:

f_i=g_{i-k}

g 怎么算?

其实 g 也可以从 f 转移过来的。

我们可以算出到当前位置为止的总方案数,再减去所有合法的即可。

也就是:

g_i=2^{prex_i}-\sum_{j=1}^{i}f_j

前缀和优化一下即可。

但细想一下,其实这里会出问题的,本人就被卡了好久。

想明白了么?

于是在算前缀和的时候多处理一下就行了。 这样对于 'B' 段的dp就做完了。 同理,对于 'W' 段从后往前做一次即可。 怎么算答案? 现在设 'B' 段的 $f$ 数组为 $f1$,'W' 段的 $f$ 数组为 $f2$。 $ans=\sum_{i=1}^{n}\sum_{j=i+1}^{n}f1_i\times f2_j\times 2^{prex_j-prex_{i-1}}

这是 O(n^2) 的,但是可优化到 O(n)

枚举 j,记录一个 sum 即可。

代码

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+10,mod=1e9+7;
int n,f1[N],f2[N],g1[N],g2[N],k,s1[N],s2[N],pw[N],prex[N],lstw[N],sufx[N],nxtb[N],ans;
string c;
signed main(){
#ifndef ONLINE_JUDGE
    freopen("in.in","r",stdin);
    freopen("out.out","w",stdout);
#endif
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>k>>c;
    pw[0]=1;
    for(int i=1;i<=n;i++) pw[i]=pw[i-1]*2%mod;
    c=' '+c;
    int lst=0;
    for(int i=1;i<=n;i++){
        if(c[i]=='W') lst=i;
        prex[i]=prex[i-1]+(c[i]=='X');
        lstw[i]=lst;
    }
    for(int i=0;i<k;i++) g1[i]=pw[prex[i]];
    for(int i=k;i<=n;i++){
        s1[i]=(s1[i-1]*(c[i]=='X'?2:1))%mod;
        if(i-lstw[i]>=k){
            if(c[i-k]!='B'){
                if(c[i-k]=='X') f1[i]=g1[i-k-1];
                else f1[i]=g1[i-k];
                (s1[i]+=f1[i])%=mod;
            }
        }
        g1[i]=(pw[prex[i]]-s1[i]+mod)%mod;
    }
    int nxt=n+1;
    for(int i=n;i>=1;i--){
        if(c[i]=='B') nxt=i;
        sufx[i]=sufx[i+1]+(c[i]=='X');
        nxtb[i]=nxt;
    }
    for(int i=n+1;i>n-k+1;i--) g2[i]=pw[sufx[i]];
    for(int i=n-k+1;i>=1;i--){
        s2[i]=s2[i+1]*(c[i]=='X'?2:1)%mod;
        if(nxtb[i]-i>=k){
            if(c[i+k]!='W'){
                if(c[i+k]=='X') f2[i]=g2[i+k+1];
                else f2[i]=g2[i+k];
                (s2[i]+=f2[i])%=mod;
            }
        }
        g2[i]=(pw[sufx[i]]-s2[i]+mod)%mod;
    }
    int sum=0;
    for(int j=1;j<=n;j++){
        (ans+=sum*f2[j]%mod)%=mod;
        (sum*=(c[j]=='X'?2:1))%=mod;
        (sum+=f1[j])%=mod;
    }
    cout<<ans<<'\n';
}