笨蛋花的小窝qwq

笨蛋花的小窝qwq

“我要再和生活死磕几年。要么我就毁灭,要么我就注定铸就辉煌。如果有一天,你发现我在平庸面前低了头,请向我开炮。”

题解 AT4169 【[ARC100D] Colorful Sequences】

posted on 2020-03-09 00:34:36 | under 题解 |

个人认为应该是全网写的最详细的一篇了…吧?

如果题解界面格式崩了请直接来blog。

给定一个长为 $m$ 的序列,保证每个数都 $\leq k$。同时定义,如果一个全部元素均 $\leq x$ 的序列中存在一个长度为 $x$ 子序列,恰好是 $1\sim x$ 的排列,那么称这个序列为 「$x-$好序列」。

给定长度 $n$ 。求所有长度为 $n$ 的 $k-$好序列共包含了多少个长度为 $m$ 的子序列。

$1\leq n\leq 25000,1\leq k\leq 400$ 。

记这个给定的序列为 $s$ .

首先发现,如果暴力计数统计的话似乎不是很简单,可能只允许状压。所以正难则反( trick1) ,考虑不加 「好序列」的限制,答案就是平凡的 $k^{n-m}(n - m + 1)$,$n-m+1$ 用来确定 $s$ 放在哪个位置。

接下来需要分类讨论:

1、$s$ 本身就是一个 $k-$好序列。

那么所有包含 $s$ 的序列,一定包含 $k$ 。此时所有情况都是合法的。

2、$s$ 中含有一对重复元素。

此时需要注意的是,绝对不会有跨过 $s$ 的 $k-$连续段。所以想到在左右两边分别计算有多少个序列使得与 $s$ 拼接之后,不存在 $k-$连续段。

具体一点。考虑记 $f_{i,j}$ 表示已经考虑了 $i$ 个数,末尾存在一个 $j$ 连续段的序列排列数。发现一共两种转移:

$$ \begin{aligned} (k-j)\cdot f_{i,j}&\to f_{i+1,j+1}\\ f_{i,j}&\to f_{i+1,p}\quad (p\leq j) \end{aligned} $$

第一个转移显然是找一个新的元素放进来,第二个转移则表示永远可以选一个 $k-$连续段中的数字来调整长度。朴素的 $dp$ 是 $nk^2$ 的,考虑如何快速计算第二个转移。发现对于 $i$ 下的一个第二维 $j$,根据第二个转移需要加上所有 $i-1$ 且 $q\geq i$ 的 $f_{i-1,q}$ 的和,是一个后缀和的形式。

考虑如何在刷表的时候维护后缀和(trick2):假设当前状态 $(i,j)$ 要转移到 $(i+1,j+1)$ ,转移了 $v$ 。那么考虑再从 $(i,j)$ 转移到 $(i+1,j+2)$,转移 $-v$ 。这样对于每个状态 $(i,j)$,其真实值都是一个 $1\to j$ 的累加过程,可以前缀和,也就是把第一个转移给前缀和化;对于第二个转移,考虑向 $(i+1,1)$ 这个状态转移 $\sum _jf_{i,j}$ ,然后在每个 $(i+1,j)$ 处减掉 $(i,j-1)$ 。原因是根据第二个转移,对于 $i$ 而言,只有 $\geq j$ 的 $f_{i,p}$ 才能产生贡献,所以相当于提前减掉了这一部分。

转移的时候保证不让 $j=k$ 参与进来,那么就可以保证合法。

扯回正题。发现计算这个东西,如果一开始在 $s$ 中确定了 $f_0$ 的哪个状态合法,就应当对其设初值。具体的,以左半边为例,找出左半边最长的连续段长度为 $l$,那么 $f_{0,l}$ 初值便应该为 $1$ 。整个计算过程相当于从 $s_1$ 向左延伸和从 $s_m$ 向右延伸。最后的合并显然是一个卷积的形式。

3、$s$ 中不含有重复元素。

一个很神仙的点。考虑不含重复元素,可以转化成随便一个长度为 $n$ 的 $k-$好序列,含有多少 $m-$好序列的方案数,除以 $A_{k}^m$。因为显然每个序列都是等价的,所以可以如是转化。这个东西的 $dp$ 过程和 2、 中的差不多,也是「记 $f_{i,j}$ 表示已经考虑了 $i$ 个数,末尾存在一个 $j$ 连续段的序列排列数」——不同点在于这个 $dp$ 记录的是个数,那么再拿一个数组,和 $f$ 一样转移,但是每次都需要加上 $f$ 对应的方案数,表示计算多次。

然后就没有然后了。注意可能存在 $k>n$ 的sb情况,判掉即可。

const int K = 410 ;
const int N = 30010 ;
const int P = 1000000007 ;

ll ans ;
ll X[N] ;
ll Y[N] ;
ll I[N] ;
ll fac[N] ;
ll g[N][K] ;
ll f[N][K] ;
int buc[N] ;
int n, k, m ;
int base[N] ;

void add(ll &x, ll y){
    (x += y) %= P ;
}
void dec(ll &x, ll y){
    (x -= y) %= P ;
    if (x < 0) x += P ;
}
ll expow(ll a, ll b){
    ll res = 1 ;
    while (b){
        if (b & 1)
            (res *= a) %= P ;
        (a *= a) %= P ; b >>= 1 ;
    }
    return res ;
}
bool check(){
    int now = 0, j = 1 ;
    for (int i = 1 ; i <= m ; ++ i){
        if (!buc[base[i]]) now ++ ; buc[base[i]] ++ ;
        while (buc[base[j]] > 1) -- buc[base[j ++]] ;
        if (i - j + 1 == k && now == k) return 1 ;
    }
    return 0 ;
}
bool check2(){
    memset(buc, 0, sizeof(buc)) ;
    for (int i = 1 ; i <= m ; ++ i)
        if (buc[base[i]]) return 1 ; else ++ buc[base[i]] ;
    return 0 ;
}
ll dp1(){
    g[0][0] = 1 ; ll ret = 0 ;
    for (int i = 0 ; i < n ; ++ i){
        for (int p, q, j = 0 ; j < k ; ++ j){
            p = f[i][j] * (k - j) % P ;
            q = g[i][j] * (k - j) % P ;
            if (j + 1 < k){
                add(f[i + 1][j + 1], p) ;
                add(g[i + 1][j + 1], q) ;
                dec(f[i + 1][j + 2], p) ;
                dec(g[i + 1][j + 2], q) ;
            }
            add(f[i + 1][1], f[i][j]) ;
            add(g[i + 1][1], g[i][j]) ;
            dec(f[i + 1][j + 1], f[i][j]) ;
            dec(g[i + 1][j + 1], g[i][j]) ;
        }
        for (int j = 1 ; j < k ; ++ j){
            add(g[i + 1][j], g[i + 1][j - 1]) ;
            add(f[i + 1][j], f[i + 1][j - 1]) ;
        }
        for (int j = m ; j < k ; ++ j)
            add(f[i + 1][j], g[i + 1][j]) ;
    }
    for (int i = 1 ; i < k ; ++ i) add(ret, f[n][i]) ;
    return ret ;
}
void dp2(ll res[N], ll s[N][K], int mk){
    memset(buc, 0, sizeof(buc)) ;
    if (!mk){
        for (int i = 1 ; i <= m ; ++ i)
            if (!buc[base[i]]) buc[base[i]] = 1 ;
            else { s[0][i - 1] = 1 ; break ; }
    }
    else {
        for (int i = m ; i >= 1 ; -- i)
            if (!buc[base[i]]) buc[base[i]] = 1 ;
            else { s[0][m - i] = 1 ; break ; }
    }
    for (int i = 0 ; i < n - m ; ++ i){
        for (int t, j = 1 ; j < k ; ++ j){
            t = s[i][j] * (k - j) % P ;
            if (j + 1 < k){
                add(s[i + 1][j + 1], t) ;
                dec(s[i + 1][j + 2], t) ;
            }
            add(s[i + 1][1], s[i][j]) ;
            dec(s[i + 1][j + 1], s[i][j]) ;
        }
        for (int j = 1 ; j < k ; ++ j){
            add(s[i + 1][j], s[i + 1][j - 1]) ;
            add(res[i + 1], s[i + 1][j]) ;
        }
    }
}
int main(){
    fac[0] = I[0] = 1 ;
    cin >> n >> k >> m ;
    int mx = max(n, max(k, m)) ;
    for (int i = 1 ; i <= m ; ++ i)
        scanf("%d", &base[i]) ;
    for (int i = 1 ; i <= mx + 1 ; ++ i)
        fac[i] = fac[i - 1] * (ll)i % P ;
    I[mx + 1] = expow(fac[mx + 1], P - 2) ;
    for (int i = mx ; i >= 1 ; -- i)
        I[i] = (ll)(i + 1)* I[i + 1] % P ;
    ans = expow(k, n - m) * (ll)(n - m + 1) % P ;
    //cout << ans << endl ;
    if (check()) return cout << ans << endl, 0 ;
    if (check2()){ //there is a same pair of number
        X[0] = Y[0] = 1 ;
        dp2(X, f, 0) ; dp2(Y, g, 1) ;
        for (int i = 0 ; i <= n - m ; ++ i)
            dec(ans, X[i] * Y[n - m - i] % P) ;
    }
    else dec(ans, dp1() * I[k] % P * fac[k - m] % P) ;
    cout << ans << endl ; return 0 ;
}