题解:P7943 「Wdcfr-1」CONsecutive and CONcat (hard version)
提供一个简单做法。
我们把贡献分为两种:自己内部的贡献和拼接产生的贡献。
第一部分是好求的。
只需要找到一个串内的所有长度
然后考虑第二部分。
算重后再减去太麻烦了,能不能直接计算?
为了不重不漏,我们计算以每个位置为起点,向后长度为
那么这个区间一定是形如一个后缀散块+若干整块+一个前缀散块的形式。
设后缀长度为
我们需要对前缀散块分类讨论:
- 没有前缀散块。
此时方案数为
- 前缀散块所在的整块内的字符种类
>1
此时方案数为
- 前缀散块所在的整块内的字符种类
= 1
此时方案数为
注意:在算当前块作为后缀散块时,一定要把当前块从桶里去掉,否则会算上自己对自己的贡献,这是错误的。
复杂度
#include"bits/stdc++.h"
#define re register
#define int long long
using namespace std;
const int maxn=1e6+10,maxm=110,maxv=30,mod=998244353;
int n,m,k,ans;
int fac[maxn],inv[maxn];
int buc[maxv],num[maxv][maxm];
string s[maxn];
inline int qpow(int a,int b){
int res=1;
while(b){
if(b&1) res=res*a%mod;
b>>=1;
a=a*a%mod;
}
return res;
}
inline int Inv(int x){return qpow(x,mod-2);}
inline void init(){
fac[0]=1;
for(re int i=1;i<=n*m;++i) fac[i]=fac[i-1]*i%mod;
inv[n*m]=Inv(fac[n*m]);
for(re int i=n*m-1;i>=0;--i) inv[i]=inv[i+1]*(i+1)%mod;
}
inline int A(int n,int m){if(n<m) return 0;return fac[n]%mod*inv[n-m]%mod;}
inline void add(string s){
int cnt=1;char c=s[0];
for(re int i=1;i<m;++i) cnt+=(s[i]==c);
if(cnt==m) ++buc[c-'a'];
else{
int pre=1;
for(re int i=1;i<m;++i){
if(s[i]!=c) break;
else ++pre;
}
for(re int i=1;i<=pre;++i) ++num[c-'a'][i];
}
}
inline void del(string s){
int cnt=1;char c=s[0];
for(re int i=1;i<m;++i) cnt+=(s[i]==c);
if(cnt==m) --buc[c-'a'];
else{
int pre=1;
for(re int i=1;i<m;++i){
if(s[i]!=c) break;
else ++pre;
}
for(re int i=1;i<=pre;++i) --num[c-'a'][i];
}
}
inline void solve(string s){
int len=0;char lst=0;
for(re int i=0;i<m;++i){
if(s[i]!=lst){
if(len>=k) ans=(ans+(len-k+1)*fac[n]%mod)%mod;
lst=s[i],len=1;
}
else ++len;
}
if(len>=k) ans=(ans+(len-k+1)*fac[n]%mod)%mod;
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n>>m>>k;
init();
for(re int i=1;i<=n;++i){
cin>>s[i];
add(s[i]);solve(s[i]);
}
for(re int i=1;i<=n;++i){
del(s[i]);
char c=s[i][m-1];
for(re int j=m-1,len,num1,len1;j>=0;--j){
if(s[i][j]!=c) break;
if(m-j>=k) continue;
len=k-(m-j);num1=len/m;len1=len-num1*m;
if(!len1) ans=(ans+A(buc[c-'a'],num1)%mod*fac[n-num1]%mod)%mod;
else ans=(ans+A(buc[c-'a'],num1)%mod*num[c-'a'][len1]%mod*fac[n-num1-1]%mod+A(buc[c-'a'],num1+1)%mod*fac[n-num1-1]%mod)%mod;
}
add(s[i]);
}
cout<<ans;
return 0;
}