题解 AT4169 【[ARC100D] Colorful Sequences】

· · 题解

Solution

首先答案就是a序列的总出现次数减去其在不 colorful 序列中的总出现次数,而前面那个的答案就是(n−m+1)×k^{n−m},即枚举a出现的位置,考虑每个位置对答案的贡献。

考虑计算后面那个,分成三种情况考虑

  1. 否则,若a_{1...m}中的数互不相同,但由于m<k所以不是 colorful 序列。我们设f_{i,j}表示考虑前i个位置,当前极长没有重复颜色的后缀长度为j的方案数,初值为f_{0,0}=1,最终要求的即为\sum_{j=0}^{k-1}f_{n,j},转移如下

    \begin{aligned} f_{i,j} \leftarrow f_{i-1,j^\prime} \cdot\begin{cases}0 & j^\prime + 1 < j \\k – j^\prime & j^\prime + 1 = j \\1 & j^\prime + 1 > j\end{cases} \end{aligned}

    用后缀和优化就可以做到\mathcal O(nk)的复杂度。 同理我们可以设一个g_{i,j}表示f_{i,j}a_{1...m}的个数,转移类似,当j\geq m时让g_{i,j}+=f_{i,j}即可,算出最终的g后除以k^{\underline{m}}即为答案。

  2. 否则,这意味着a_{1...m}中有相同的数且不是 colorful 序列。先求出极长的没有重复数的前后缀长度,设为l,r,同样采用上面的dp,只是将初值改成f_{0,0...l}=g_{0,0..r}=1,然后将fg都第二种情况转移即可。(注意这里的g不是上面的g,它和f的意义是一样的,也代表方案数),最终枚举a_{1...m}的开头位置,那么答案就是

    \left(\sum_{j=0}^{k-1} f_{i-1,j}\right)\left(\sum_{j=0}^{k-1} g_{n-(i-1)-m,j}\right)

时间复杂度\mathcal O(nk)

Code

int n,k,m,ans,tot;
int a[MAXN],f[MAXN][MAXK],g[MAXN][MAXK],vis[MAXK];

bool check(int l,int r){
    memset(vis,0,sizeof(vis));
    for(int i = l;i <= r;i++){
        if(vis[a[i]]) return false;
        vis[a[i]] = 1;
    }
    return true;
}

bool colorful(){
    for(int i = 1;i + k - 1 <= m;i++)
        if(check(i,i + k - 1)) return true;
    return false;
}

void Solve(int (*f)[MAXK]){
    for(int i = 1;i <= n;i++){
        for(int j = 1;j < k;j++)
            f[i][j] = add(1ll * sub(f[i - 1][j - 1],f[i - 1][j]) * (k - j + 1) % MOD,f[i - 1][j]);
        for(int j = k - 1;j >= 0;j--)
            addmod(f[i][j],f[i][j + 1]); 
    }
}

int main(){
    scanf("%d%d%d",&n,&k,&m);
    tot = 1ll * power(k,n - m) * (n - m + 1) % MOD;
    for(int i = 1;i <= m;i++) scanf("%d",&a[i]);
    if(colorful()) {printf("%d\n",tot); return 0;}
    int l = m, r = 1;
    memset(vis,0,sizeof(vis));
    for(int i = 1;i <= m;i++){
        if(vis[a[i]]) {l = i - 1; break;}
        vis[a[i]] = 1;
    }
    memset(vis,0,sizeof(vis));
    for(int i = m;i >= 1;i--){
        if(vis[a[i]]) {r = m - i; break;}
        vis[a[i]] = 1;
    }
    if(l == m){
        f[0][0] = 1;
        for(int i = 1;i <= n;i++){
            for(int j = 1;j < k;j++){
                f[i][j] = add(1ll * sub(f[i - 1][j - 1],f[i - 1][j]) * (k - j + 1) % MOD,f[i - 1][j]);
                g[i][j] = add(1ll * sub(g[i - 1][j - 1],g[i - 1][j]) * (k - j + 1) % MOD,g[i - 1][j]);
                if(j >= m) addmod(g[i][j],f[i][j]);
            }
            for(int j = k - 1;j >= 0;j--){
                addmod(f[i][j],f[i][j + 1]);
                addmod(g[i][j],g[i][j + 1]);
            }
        }
        ans = g[n][0];
        for(int i = k;i > k - m;i--) ans = 1ll * power(i,MOD - 2) * ans % MOD;
    }else{
        for(int i = 0;i <= l;i++) f[0][i] = 1;
        for(int i = 0;i <= r;i++) g[0][i] = 1;
        Solve(f); Solve(g);
        for(int i = 0;i <= n - m;i++)
            addmod(ans,1ll * f[i][0] * g[n - m - i][0] % MOD);
    }
    printf("%d\n",sub(tot,ans));
    return 0;
}