题解 AT4169 【[ARC100D] Colorful Sequences】

· · 题解

个人认为应该是全网写的最详细的一篇了…吧?

如果题解界面格式崩了请直接来blog。

给定一个长为 m 的序列,保证每个数都 \leq k。同时定义,如果一个全部元素均 \leq x 的序列中存在一个长度为 x 子序列,恰好是 1\sim x 的排列,那么称这个序列为 「x-好序列」。

给定长度 n 。求所有长度为 nk-好序列共包含了多少个长度为 m 的子序列。

记这个给定的序列为 s .

首先发现,如果暴力计数统计的话似乎不是很简单,可能只允许状压。所以正难则反( trick1) ,考虑不加 「好序列」的限制,答案就是平凡的 k^{n-m}(n - m + 1)n-m+1 用来确定 s 放在哪个位置。

接下来需要分类讨论:

1、s 本身就是一个 k-好序列。

那么所有包含 s 的序列,一定包含 k 。此时所有情况都是合法的。

2、s 中含有一对重复元素。

此时需要注意的是,绝对不会有跨过 sk-连续段。所以想到在左右两边分别计算有多少个序列使得与 s 拼接之后,不存在 k-连续段。

具体一点。考虑记 f_{i,j} 表示已经考虑了 i 个数,末尾存在一个 j 连续段的序列排列数。发现一共两种转移:

\begin{aligned} (k-j)\cdot f_{i,j}&\to f_{i+1,j+1}\\ f_{i,j}&\to f_{i+1,p}\quad (p\leq j) \end{aligned}

第一个转移显然是找一个新的元素放进来,第二个转移则表示永远可以选一个 k-连续段中的数字来调整长度。朴素的 dpnk^2 的,考虑如何快速计算第二个转移。发现对于 i 下的一个第二维 j,根据第二个转移需要加上所有 i-1q\geq if_{i-1,q} 的和,是一个后缀和的形式。

考虑如何在刷表的时候维护后缀和(trick2):假设当前状态 (i,j) 要转移到 (i+1,j+1) ,转移了 v 。那么考虑再从 (i,j) 转移到 (i+1,j+2),转移 -v 。这样对于每个状态 (i,j),其真实值都是一个 1\to j 的累加过程,可以前缀和,也就是把第一个转移给前缀和化;对于第二个转移,考虑向 (i+1,1) 这个状态转移 \sum _jf_{i,j} ,然后在每个 (i+1,j) 处减掉 (i,j-1) 。原因是根据第二个转移,对于 i 而言,只有 \geq jf_{i,p} 才能产生贡献,所以相当于提前减掉了这一部分。

转移的时候保证不让 j=k 参与进来,那么就可以保证合法。

扯回正题。发现计算这个东西,如果一开始在 s 中确定了 f_0 的哪个状态合法,就应当对其设初值。具体的,以左半边为例,找出左半边最长的连续段长度为 l,那么 f_{0,l} 初值便应该为 1 。整个计算过程相当于从 s_1 向左延伸和从 s_m 向右延伸。最后的合并显然是一个卷积的形式。

3、s 中不含有重复元素。

一个很神仙的点。考虑不含重复元素,可以转化成随便一个长度为 n k-好序列,含有多少 m-好序列的方案数,除以 A_{k}^m。因为显然每个序列都是等价的,所以可以如是转化。这个东西的 dp 过程和 2、 中的差不多,也是「记 f_{i,j} 表示已经考虑了 i 个数,末尾存在一个 j 连续段的序列排列数」——不同点在于这个 dp 记录的是个数,那么再拿一个数组,和 f 一样转移,但是每次都需要加上 f 对应的方案数,表示计算多次。

然后就没有然后了。注意可能存在 k>n 的sb情况,判掉即可。

const int K = 410 ;
const int N = 30010 ;
const int P = 1000000007 ;

ll ans ;
ll X[N] ;
ll Y[N] ;
ll I[N] ;
ll fac[N] ;
ll g[N][K] ;
ll f[N][K] ;
int buc[N] ;
int n, k, m ;
int base[N] ;

void add(ll &x, ll y){
    (x += y) %= P ;
}
void dec(ll &x, ll y){
    (x -= y) %= P ;
    if (x < 0) x += P ;
}
ll expow(ll a, ll b){
    ll res = 1 ;
    while (b){
        if (b & 1)
            (res *= a) %= P ;
        (a *= a) %= P ; b >>= 1 ;
    }
    return res ;
}
bool check(){
    int now = 0, j = 1 ;
    for (int i = 1 ; i <= m ; ++ i){
        if (!buc[base[i]]) now ++ ; buc[base[i]] ++ ;
        while (buc[base[j]] > 1) -- buc[base[j ++]] ;
        if (i - j + 1 == k && now == k) return 1 ;
    }
    return 0 ;
}
bool check2(){
    memset(buc, 0, sizeof(buc)) ;
    for (int i = 1 ; i <= m ; ++ i)
        if (buc[base[i]]) return 1 ; else ++ buc[base[i]] ;
    return 0 ;
}
ll dp1(){
    g[0][0] = 1 ; ll ret = 0 ;
    for (int i = 0 ; i < n ; ++ i){
        for (int p, q, j = 0 ; j < k ; ++ j){
            p = f[i][j] * (k - j) % P ;
            q = g[i][j] * (k - j) % P ;
            if (j + 1 < k){
                add(f[i + 1][j + 1], p) ;
                add(g[i + 1][j + 1], q) ;
                dec(f[i + 1][j + 2], p) ;
                dec(g[i + 1][j + 2], q) ;
            }
            add(f[i + 1][1], f[i][j]) ;
            add(g[i + 1][1], g[i][j]) ;
            dec(f[i + 1][j + 1], f[i][j]) ;
            dec(g[i + 1][j + 1], g[i][j]) ;
        }
        for (int j = 1 ; j < k ; ++ j){
            add(g[i + 1][j], g[i + 1][j - 1]) ;
            add(f[i + 1][j], f[i + 1][j - 1]) ;
        }
        for (int j = m ; j < k ; ++ j)
            add(f[i + 1][j], g[i + 1][j]) ;
    }
    for (int i = 1 ; i < k ; ++ i) add(ret, f[n][i]) ;
    return ret ;
}
void dp2(ll res[N], ll s[N][K], int mk){
    memset(buc, 0, sizeof(buc)) ;
    if (!mk){
        for (int i = 1 ; i <= m ; ++ i)
            if (!buc[base[i]]) buc[base[i]] = 1 ;
            else { s[0][i - 1] = 1 ; break ; }
    }
    else {
        for (int i = m ; i >= 1 ; -- i)
            if (!buc[base[i]]) buc[base[i]] = 1 ;
            else { s[0][m - i] = 1 ; break ; }
    }
    for (int i = 0 ; i < n - m ; ++ i){
        for (int t, j = 1 ; j < k ; ++ j){
            t = s[i][j] * (k - j) % P ;
            if (j + 1 < k){
                add(s[i + 1][j + 1], t) ;
                dec(s[i + 1][j + 2], t) ;
            }
            add(s[i + 1][1], s[i][j]) ;
            dec(s[i + 1][j + 1], s[i][j]) ;
        }
        for (int j = 1 ; j < k ; ++ j){
            add(s[i + 1][j], s[i + 1][j - 1]) ;
            add(res[i + 1], s[i + 1][j]) ;
        }
    }
}
int main(){
    fac[0] = I[0] = 1 ;
    cin >> n >> k >> m ;
    int mx = max(n, max(k, m)) ;
    for (int i = 1 ; i <= m ; ++ i)
        scanf("%d", &base[i]) ;
    for (int i = 1 ; i <= mx + 1 ; ++ i)
        fac[i] = fac[i - 1] * (ll)i % P ;
    I[mx + 1] = expow(fac[mx + 1], P - 2) ;
    for (int i = mx ; i >= 1 ; -- i)
        I[i] = (ll)(i + 1)* I[i + 1] % P ;
    ans = expow(k, n - m) * (ll)(n - m + 1) % P ;
    //cout << ans << endl ;
    if (check()) return cout << ans << endl, 0 ;
    if (check2()){ //there is a same pair of number
        X[0] = Y[0] = 1 ;
        dp2(X, f, 0) ; dp2(Y, g, 1) ;
        for (int i = 0 ; i <= n - m ; ++ i)
            dec(ans, X[i] * Y[n - m - i] % P) ;
    }
    else dec(ans, dp1() * I[k] % P * fac[k - m] % P) ;
    cout << ans << endl ; return 0 ;
}