题解 P3867 【[TJOI2009]排列计数】
做了一道OI省选题,题目很短,题意求1到N的全排列,有多少个排列满足所有相邻两项之差都不超过K(N≤50,K≤4).
这道题让我想起以前做过的一题(URAL1260)。那题要求K=2,排列开头必须是1.当时,我从N=1,开始枚举,并且用了递推的方法(从N-1的序列推到N的序列。如果N-1序列不是符合要求的,那么无论把N插入N-1序列的任意位置,N序列都是不符合要求的。那也就是说,符合要求的N序列,是从符合要求的N-1序列推出的。),解决了那道题,递推公式并不复杂。这道题显然情况更复杂。我还是先尝试K=2的情况。
对于K=2,要从N-1序列生成N序列,其实只要考虑N-1序列中N-1和N-2的位置,当它们在序列的两端或者它们是相连的,就可以插入N,得到符合要求的N序列,再记录N和N-1的位置关系的种数(注意这时候N-2已经没用了,要把它删除掉,只记录N和N-1的位置关系)。当N≤2时,所有序列都符合,N>2时,可以记录各种情况的种数,最后把所有情况累加起来就是答案。
对于K=3和K=4的情形,其实也是一样的。分成2^(K+1)*(K!)种情况。代码很复杂,写了好久,供大家参考:
#include <stdio.h>
#include <algorithm>
using namespace std;
int m = 1000000007;
int b2[6] = {1, 2, 4, 8, 16, 32};
int f[5] = {1, 1, 2, 6, 24};
int c[32][5] = {0}, p[24][4];
int s[5][51][24][32] = {0};
int solve(int n, int k)
{
if(n <= k) return f[n];
int i, ic, t, ip, j, d = 4 - k, tm = b2[k + 1] - 1, l, l0, ipp, icc;
int tp[4], tc[5], ttp[5], ttc[6], ans = 0;
for(i = 0; i < f[k]; i++){
s[k][k][i][tm] = 1;
}
for(i = k + 1; i <= n; i++){
for(ip = 0; ip < f[k]; ip++){
for(j = 0; j < k; j++) tp[j] = p[ip][j + d] - d;
for(ic = 0; ic <= tm; ic++){
if(!s[k][i - 1][ip][ic]) continue;
for(j = 0; j <= k; j++) tc[j] = c[ic][j];
for(j = 0; j <= k; j++){
if(!tc[j]) continue;
for(l = 0; l < j; l++) ttc[l] = tc[l], ttp[l] = tp[l];
for(l = k; l >= j; l--) ttc[l + 1] = tc[l];
for(l = k - 1; l >= j; l--) ttp[l + 1] = tp[l];
ttc[j] = 1, ttp[j] = k;
for(l0 = 0; ttp[l0]; l0++);
for(l = l0; l < k; l++) ttp[l] = ttp[l + 1];
for(l = l0; l <= k; l++) ttc[l] = ttc[l + 1];
ttc[l0] = 0;
for(l = 0; l < k; l++) --ttp[l];
for(l = icc = 0; l <= k; l++) icc += b2[l] * ttc[l];
int h[4] = {0}, ch;
for(l = 0, t = k - 1, ipp = 0; t >= 0; l++, t--){
for(ch = 0, l0 = 0; l0 < ttp[l]; l0++) ch += (h[l0] == 0);
h[ttp[l]] = 1;
ipp += ch * f[t];
}
s[k][i][ipp][icc] += s[k][i - 1][ip][ic];
if(s[k][i][ipp][icc] >= m) s[k][i][ipp][icc] -= m;
}
}
}
}
for(ip = ans = 0; ip < f[k]; ip++){
for(ic = 0; ic <= tm; ic++){
ans += s[k][n][ip][ic];
if(ans >= m) ans -= m;
}
}
return ans;
}
int main()
{
int n, k, i, j, t;
int tp[4] = {0, 1, 2, 3};
for(i = 1; i < 32; i++){
t = i, j = 0;
while(t){
c[i][j] = t & 1;
t >>= 1;
j++;
}
}
for(i = 0; i < 24; i++){
for(j = 0; j < 4; j++) p[i][j] = tp[j];
next_permutation(tp, tp + 4);
}
scanf("%d%d", &n, &k);
printf("%d\n", solve(n, k));
return 0;
}