题解 P3867 【[TJOI2009]排列计数】

· · 题解

做了一道OI省选题,题目很短,题意求1到N的全排列,有多少个排列满足所有相邻两项之差都不超过K(N≤50,K≤4).

这道题让我想起以前做过的一题(URAL1260)。那题要求K=2,排列开头必须是1.当时,我从N=1,开始枚举,并且用了递推的方法(从N-1的序列推到N的序列。如果N-1序列不是符合要求的,那么无论把N插入N-1序列的任意位置,N序列都是不符合要求的。那也就是说,符合要求的N序列,是从符合要求的N-1序列推出的。),解决了那道题,递推公式并不复杂。这道题显然情况更复杂。我还是先尝试K=2的情况。

对于K=2,要从N-1序列生成N序列,其实只要考虑N-1序列中N-1和N-2的位置,当它们在序列的两端或者它们是相连的,就可以插入N,得到符合要求的N序列,再记录N和N-1的位置关系的种数(注意这时候N-2已经没用了,要把它删除掉,只记录N和N-1的位置关系)。当N≤2时,所有序列都符合,N>2时,可以记录各种情况的种数,最后把所有情况累加起来就是答案。

对于K=3和K=4的情形,其实也是一样的。分成2^(K+1)*(K!)种情况。代码很复杂,写了好久,供大家参考:

#include <stdio.h>
#include <algorithm>
using namespace std;
int m = 1000000007;
int b2[6] = {1, 2, 4, 8, 16, 32};
int f[5] = {1, 1, 2, 6, 24};
int c[32][5] = {0}, p[24][4];
int s[5][51][24][32] = {0};
int solve(int n, int k)
{
    if(n <= k) return f[n];
    int i, ic, t, ip, j, d = 4 - k, tm = b2[k + 1] - 1, l, l0, ipp, icc;
    int tp[4], tc[5], ttp[5], ttc[6], ans = 0;
    for(i = 0; i < f[k]; i++){
        s[k][k][i][tm] = 1;
    }
    for(i = k + 1; i <= n; i++){
        for(ip = 0; ip < f[k]; ip++){
            for(j = 0; j < k; j++) tp[j] = p[ip][j + d] - d;
            for(ic = 0; ic <= tm; ic++){
                if(!s[k][i - 1][ip][ic]) continue; 
                for(j = 0; j <= k; j++) tc[j] = c[ic][j];
                for(j = 0; j <= k; j++){
                    if(!tc[j]) continue;
                    for(l = 0; l < j; l++) ttc[l] = tc[l], ttp[l] = tp[l];
                    for(l = k; l >= j; l--) ttc[l + 1] = tc[l];
                    for(l = k - 1; l >= j; l--) ttp[l + 1] = tp[l];
                    ttc[j] = 1, ttp[j] = k;
                    for(l0 = 0; ttp[l0]; l0++);
                    for(l = l0; l < k; l++) ttp[l] = ttp[l + 1];
                    for(l = l0; l <= k; l++) ttc[l] = ttc[l + 1];
                    ttc[l0] = 0;
                    for(l = 0; l < k; l++) --ttp[l];
                    for(l = icc = 0; l <= k; l++) icc += b2[l] * ttc[l];
                    int h[4] = {0}, ch;
                    for(l = 0, t = k - 1, ipp = 0; t >= 0; l++, t--){
                        for(ch = 0, l0 = 0; l0 < ttp[l]; l0++) ch += (h[l0] == 0);
                        h[ttp[l]] = 1;
                        ipp += ch * f[t];
                    }
                    s[k][i][ipp][icc] += s[k][i - 1][ip][ic];
                    if(s[k][i][ipp][icc] >= m) s[k][i][ipp][icc] -= m;
                }
            }
        }
    }
    for(ip = ans = 0; ip < f[k]; ip++){
        for(ic = 0; ic <= tm; ic++){
            ans += s[k][n][ip][ic];
            if(ans >= m) ans -= m;
        } 
    }
    return ans;
}
int main() 
{
    int n, k, i, j, t;
    int tp[4] = {0, 1, 2, 3};
    for(i = 1; i < 32; i++){
        t = i, j = 0;
        while(t){
            c[i][j] = t & 1;
            t >>= 1;
            j++;
        }
    }
    for(i = 0; i < 24; i++){
        for(j = 0; j < 4; j++) p[i][j] = tp[j];
        next_permutation(tp, tp + 4);
    }
    scanf("%d%d", &n, &k);
    printf("%d\n", solve(n, k));
    return 0;
}