题解:P7943 「Wdcfr-1」CONsecutive and CONcat (hard version)

· · 题解

提供一个简单做法。

我们把贡献分为两种:自己内部的贡献和拼接产生的贡献。

第一部分是好求的。

只需要找到一个串内的所有长度 \ge k 的连续段,它们一定能贡献到所有排列里,所以贡献为 (len-k+1)\times n!

然后考虑第二部分。

算重后再减去太麻烦了,能不能直接计算?

为了不重不漏,我们计算以每个位置为起点,向后长度为 k 的区间的贡献

那么这个区间一定是形如一个后缀散块+若干整块+一个前缀散块的形式。

设后缀长度为 len,该区间对应字符的整块个数为 num,区间内整块数量为 num_1,该区间对应字符的长度为 len' 的前缀的块(且该块不能全为一种字符)的数量为 cnt_{len'}

我们需要对前缀散块分类讨论:

  1. 没有前缀散块。

此时方案数为 A_{num}^{num_1}\times (n-num_1)!

  1. 前缀散块所在的整块内的字符种类 >1

此时方案数为 A_{num}^{num_1}\times cnt_{k-len-num_1\times m}\times (n-num_1-1)!

  1. 前缀散块所在的整块内的字符种类 = 1

此时方案数为 A_{num}^{num_1+1}\times (n-num_1-1)!

注意:在算当前块作为后缀散块时,一定要把当前块从桶里去掉,否则会算上自己对自己的贡献,这是错误的。

复杂度 O(nm)

#include"bits/stdc++.h"
#define re register
#define int long long
using namespace std;
const int maxn=1e6+10,maxm=110,maxv=30,mod=998244353;
int n,m,k,ans;
int fac[maxn],inv[maxn];
int buc[maxv],num[maxv][maxm];
string s[maxn];
inline int qpow(int a,int b){
    int res=1;
    while(b){
        if(b&1) res=res*a%mod;
        b>>=1;
        a=a*a%mod;
    }
    return res;
}
inline int Inv(int x){return qpow(x,mod-2);}
inline void init(){
    fac[0]=1;
    for(re int i=1;i<=n*m;++i) fac[i]=fac[i-1]*i%mod;
    inv[n*m]=Inv(fac[n*m]);
    for(re int i=n*m-1;i>=0;--i) inv[i]=inv[i+1]*(i+1)%mod;
}
inline int A(int n,int m){if(n<m) return 0;return fac[n]%mod*inv[n-m]%mod;}
inline void add(string s){
    int cnt=1;char c=s[0];
    for(re int i=1;i<m;++i) cnt+=(s[i]==c);
    if(cnt==m) ++buc[c-'a'];
    else{
        int pre=1;
        for(re int i=1;i<m;++i){
            if(s[i]!=c) break;
            else ++pre;
        }
        for(re int i=1;i<=pre;++i) ++num[c-'a'][i];
    } 
}
inline void del(string s){
    int cnt=1;char c=s[0];
    for(re int i=1;i<m;++i) cnt+=(s[i]==c);
    if(cnt==m) --buc[c-'a'];
    else{
        int pre=1;
        for(re int i=1;i<m;++i){
            if(s[i]!=c) break;
            else ++pre;
        }
        for(re int i=1;i<=pre;++i) --num[c-'a'][i];
    } 
}
inline void solve(string s){
    int len=0;char lst=0;
    for(re int i=0;i<m;++i){
        if(s[i]!=lst){
            if(len>=k) ans=(ans+(len-k+1)*fac[n]%mod)%mod;
            lst=s[i],len=1;
        }
        else ++len;
    }
    if(len>=k) ans=(ans+(len-k+1)*fac[n]%mod)%mod;
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
#endif
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>n>>m>>k;
    init();
    for(re int i=1;i<=n;++i){
        cin>>s[i];
        add(s[i]);solve(s[i]);
    }
    for(re int i=1;i<=n;++i){
        del(s[i]);
        char c=s[i][m-1];
        for(re int j=m-1,len,num1,len1;j>=0;--j){
            if(s[i][j]!=c) break;
            if(m-j>=k) continue;
            len=k-(m-j);num1=len/m;len1=len-num1*m;
            if(!len1) ans=(ans+A(buc[c-'a'],num1)%mod*fac[n-num1]%mod)%mod;
            else ans=(ans+A(buc[c-'a'],num1)%mod*num[c-'a'][len1]%mod*fac[n-num1-1]%mod+A(buc[c-'a'],num1+1)%mod*fac[n-num1-1]%mod)%mod;
        }
        add(s[i]);
    }
    cout<<ans;
    return 0;
}