题解:CF268D Wall Bars

· · 题解

题意

将题目抽象成这个问题:构造一个由 1,2,3,4 组成的序列,使得存在 1\le x\le 4,任意两个相邻的 x 的距离 \le h,且 x 第一次出现的位置 \le h,最后一次出现的位置 \ge n-h+1。求方案数。

Solution

考虑指定 x 是一个数时怎么做,这是容易的,显然可以设计 dp 状态 f_{i,j} 表示前 i 个位置,x 最后一次出现的位置是 i-j 的方案数。有转移:

f_{i,j} = \begin{cases} \sum_{k=0}^{h-1} f_{i-1,k} & j=0 \\ 3f_{i-1,j-1} & j>0 \end{cases}

上下两种转移分别表示新加的数是不是 x。由于我们只考虑了一个 x,所以最终答案要乘 4

然而这并不完善,因为我们考虑 4x 的时候情况会有重复。考虑容斥。类似地可以求出同时存在 2,3,4x 满足条件时的方案数,根据容斥原理做即可。

4x 满足条件时状态数达到了 O(nh^4),这是难以接受的。注意到加入一个数时四维状态必有一个变成 0,故可以压掉一维。

最终复杂度 O(nh^3),常数较大,注意卡常。

Code

#include<bits/stdc++.h>
#pragma GCC optimize(2)
#define ll long long
using namespace std;
const int N = 1005,M = 1000000009;
int n,h;
ll f[N][35],ans,g[2][35][35],p[2][35][35][35],q[2][35][35][35][35];
inline void amod(ll &a)
{
    if (a > M) a -= M;
}
inline void aamod(ll &a)
{
    while (a > M) a -= M;
}
signed main()
{
    ios_base::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >> n >> h;
    f[0][0] = 1;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < h; j++)
        {
            f[i+1][0] += f[i][j],f[i+1][j+1] += f[i][j]*3;
            amod(f[i+1][0]),aamod(f[i+1][j+1]);
        }
    g[0][0][0] = 1;
    int nowg = 0;
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j <= h; j++)
            for (int k = 0; k <= h; k++)
                g[nowg^1][j][k] = 0;
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
            {
                g[nowg^1][j+1][0] += g[nowg][j][k];
                g[nowg^1][0][k+1] += g[nowg][j][k];
                g[nowg^1][j+1][k+1] += g[nowg][j][k]*2;
                amod(g[nowg^1][j+1][0]);
                amod(g[nowg^1][0][k+1]);
                aamod(g[nowg^1][j+1][k+1]);
            }
        nowg ^= 1;
        int cnt = 0;
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
                cnt += g[nowg][j][k];
    }
    p[0][0][0][0] = 1;
    nowg = 0;
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
                for (int l = 0; l < h; l++)
                    p[nowg^1][j][k][l] = 0;
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
                for (int l = 0; l < h; l++)
                {
                    p[nowg^1][0][k+1][l+1] += p[nowg][j][k][l];
                    p[nowg^1][j+1][0][l+1] += p[nowg][j][k][l];
                    p[nowg^1][j+1][k+1][0] += p[nowg][j][k][l];
                    p[nowg^1][j+1][k+1][l+1] += p[nowg][j][k][l];
                    amod(p[nowg^1][0][k+1][l+1]);
                    amod(p[nowg^1][j+1][0][l+1]);
                    amod(p[nowg^1][j+1][k+1][0]);
                    amod(p[nowg^1][j+1][k+1][l+1]);
                }
        nowg ^= 1;
    }
    q[0][0][0][0][0] = 1;
    nowg = 0;
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
                for (int l = 0; l < h; l++)
                    for (int u = 0; u < (!l||!j||!k?h:1); u++)
                        q[nowg^1][j][k][l][u] = 0,aamod(q[nowg][j][k][l][u]);
        for (int j = 0; j < h; j++)
            for (int k = 0; k < h; k++)
                for (int l = 0; l < h; l++)
                {
                    for (int u = 0; u < (!l||!j||!k?h:1); u++)
                    {
                        if (!q[nowg][j][k][l][u]) continue;
                        q[nowg^1][0][k+1][l+1][u+1] += q[nowg][j][k][l][u];
                        q[nowg^1][j+1][0][l+1][u+1] += q[nowg][j][k][l][u];
                        q[nowg^1][j+1][k+1][0][u+1] += q[nowg][j][k][l][u];
                        q[nowg^1][j+1][k+1][l+1][0] += q[nowg][j][k][l][u];
                    }
                }
        nowg ^= 1;
    }
    for (int i = 0; i < h; i++)
    {
        ans += f[n][i];
        amod(ans);
    }
    ans = ans*4%M;
    ll num = 0;
    for (int i = 0; i < h; i++)
        for (int j = 0; j < h; j++)
        {
            num += g[nowg][i][j];
            amod(num);
        }
    ans = (ans-num*6%M+M)%M;
    num = 0;
    for (int j = 0; j < h; j++)
        for (int k = 0; k < h; k++)
            for (int l = 0; l < h; l++)
            {
                num += p[nowg][j][k][l];
                amod(num);
            }
    ans = (ans+num*4)%M;
    num = 0;
    for (int j = 0; j < h; j++)
        for (int k = 0; k < h; k++)
            for (int l = 0; l < h; l++)
                for (int u = 0; u < (!l||!j||!k?h:1); u++)
                {
                    num += q[nowg][j][k][l][u];
                    aamod(num);
                }
    ans = (ans-num+M)%M;
    cout << ans;
    return 0;
}