题解 AT4169 【[ARC100D] Colorful Sequences】
个人认为应该是全网写的最详细的一篇了…吧?
如果题解界面格式崩了请直接来blog。
给定一个长为
m 的序列,保证每个数都\leq k 。同时定义,如果一个全部元素均\leq x 的序列中存在一个长度为x 子序列,恰好是1\sim x 的排列,那么称这个序列为 「x- 好序列」。给定长度
n 。求所有长度为n 的k- 好序列共包含了多少个长度为m 的子序列。
记这个给定的序列为
首先发现,如果暴力计数统计的话似乎不是很简单,可能只允许状压。所以正难则反( trick1) ,考虑不加 「好序列」的限制,答案就是平凡的
接下来需要分类讨论:
1、
那么所有包含
2、
此时需要注意的是,绝对不会有跨过
具体一点。考虑记
第一个转移显然是找一个新的元素放进来,第二个转移则表示永远可以选一个
考虑如何在刷表的时候维护后缀和(trick2):假设当前状态
转移的时候保证不让
扯回正题。发现计算这个东西,如果一开始在
3、
一个很神仙的点。考虑不含重复元素,可以转化成随便一个长度为 2、 中的差不多,也是「记
然后就没有然后了。注意可能存在
const int K = 410 ;
const int N = 30010 ;
const int P = 1000000007 ;
ll ans ;
ll X[N] ;
ll Y[N] ;
ll I[N] ;
ll fac[N] ;
ll g[N][K] ;
ll f[N][K] ;
int buc[N] ;
int n, k, m ;
int base[N] ;
void add(ll &x, ll y){
(x += y) %= P ;
}
void dec(ll &x, ll y){
(x -= y) %= P ;
if (x < 0) x += P ;
}
ll expow(ll a, ll b){
ll res = 1 ;
while (b){
if (b & 1)
(res *= a) %= P ;
(a *= a) %= P ; b >>= 1 ;
}
return res ;
}
bool check(){
int now = 0, j = 1 ;
for (int i = 1 ; i <= m ; ++ i){
if (!buc[base[i]]) now ++ ; buc[base[i]] ++ ;
while (buc[base[j]] > 1) -- buc[base[j ++]] ;
if (i - j + 1 == k && now == k) return 1 ;
}
return 0 ;
}
bool check2(){
memset(buc, 0, sizeof(buc)) ;
for (int i = 1 ; i <= m ; ++ i)
if (buc[base[i]]) return 1 ; else ++ buc[base[i]] ;
return 0 ;
}
ll dp1(){
g[0][0] = 1 ; ll ret = 0 ;
for (int i = 0 ; i < n ; ++ i){
for (int p, q, j = 0 ; j < k ; ++ j){
p = f[i][j] * (k - j) % P ;
q = g[i][j] * (k - j) % P ;
if (j + 1 < k){
add(f[i + 1][j + 1], p) ;
add(g[i + 1][j + 1], q) ;
dec(f[i + 1][j + 2], p) ;
dec(g[i + 1][j + 2], q) ;
}
add(f[i + 1][1], f[i][j]) ;
add(g[i + 1][1], g[i][j]) ;
dec(f[i + 1][j + 1], f[i][j]) ;
dec(g[i + 1][j + 1], g[i][j]) ;
}
for (int j = 1 ; j < k ; ++ j){
add(g[i + 1][j], g[i + 1][j - 1]) ;
add(f[i + 1][j], f[i + 1][j - 1]) ;
}
for (int j = m ; j < k ; ++ j)
add(f[i + 1][j], g[i + 1][j]) ;
}
for (int i = 1 ; i < k ; ++ i) add(ret, f[n][i]) ;
return ret ;
}
void dp2(ll res[N], ll s[N][K], int mk){
memset(buc, 0, sizeof(buc)) ;
if (!mk){
for (int i = 1 ; i <= m ; ++ i)
if (!buc[base[i]]) buc[base[i]] = 1 ;
else { s[0][i - 1] = 1 ; break ; }
}
else {
for (int i = m ; i >= 1 ; -- i)
if (!buc[base[i]]) buc[base[i]] = 1 ;
else { s[0][m - i] = 1 ; break ; }
}
for (int i = 0 ; i < n - m ; ++ i){
for (int t, j = 1 ; j < k ; ++ j){
t = s[i][j] * (k - j) % P ;
if (j + 1 < k){
add(s[i + 1][j + 1], t) ;
dec(s[i + 1][j + 2], t) ;
}
add(s[i + 1][1], s[i][j]) ;
dec(s[i + 1][j + 1], s[i][j]) ;
}
for (int j = 1 ; j < k ; ++ j){
add(s[i + 1][j], s[i + 1][j - 1]) ;
add(res[i + 1], s[i + 1][j]) ;
}
}
}
int main(){
fac[0] = I[0] = 1 ;
cin >> n >> k >> m ;
int mx = max(n, max(k, m)) ;
for (int i = 1 ; i <= m ; ++ i)
scanf("%d", &base[i]) ;
for (int i = 1 ; i <= mx + 1 ; ++ i)
fac[i] = fac[i - 1] * (ll)i % P ;
I[mx + 1] = expow(fac[mx + 1], P - 2) ;
for (int i = mx ; i >= 1 ; -- i)
I[i] = (ll)(i + 1)* I[i + 1] % P ;
ans = expow(k, n - m) * (ll)(n - m + 1) % P ;
//cout << ans << endl ;
if (check()) return cout << ans << endl, 0 ;
if (check2()){ //there is a same pair of number
X[0] = Y[0] = 1 ;
dp2(X, f, 0) ; dp2(Y, g, 1) ;
for (int i = 0 ; i <= n - m ; ++ i)
dec(ans, X[i] * Y[n - m - i] % P) ;
}
else dec(ans, dp1() * I[k] % P * fac[k - m] % P) ;
cout << ans << endl ; return 0 ;
}