P3734 [HAOI2017]方案数题解

· · 题解

直接考虑不经过 n 个点的方案难度较大,所以考虑进行容斥。

根据题意,一个点可到达的其他点是把原来点坐标二进制表示下某些 0 变为 1 得到的。记 x 二进制表示下 1 的个数为 count(x)dp[i][j][k] 为从 (0,0,0) 到满足 count(x)=i,count(y)=j,count(z)=k 的点 P(x,y,z) 的方案数。由于每一步只能改变一维坐标,所以可以得到转移方程

dp[i][j][k]=\sum_{x=0}^{i-1}dp[x][j][k]\times \dbinom{i}{x}+\sum_{y=0}^{j-1}dp[i][y][k]\times \dbinom{j}{y}+\sum_{z=0}^{k-1}dp[i][j][z]\times \dbinom{k}{z}

再考虑障碍点。

记第 i 个障碍为 P_i(x_i,y_i,z_i)f[i] 为不经过其他障碍到达第 i 个障碍的方案数。首先将障碍按字典序排序,则可以保证无后效性。从 dp[count(x_i)][count(y_i)][count(z_i)] 得到 f[i],则依次枚举从 (0,0,0)P_i 的过程中经过的第一个其他障碍 P_j(j<i) 若从 P_jP_i 的方案数为 to(j, i) ,则可得 f[i] 转移方程

f[i]=dp[count(x_i)][count(y_i)][count(z_i)]-\sum_{j=1}^{i-1}f[j]\times to(j,i)

最后只剩下求 to(j,i) 了。事实上,dp[i][j][k] 仅与下标变化有关,即从起点到任一可到达,且 x,y,z 坐标比起点的坐标在二进制表示下分别多i,j,k1 的点的方案数。所以若从 P_i 可以到达 P_i(即 x_j\land x_i=x_jy_j\land y_i=y_j,且z_j\land z_i=z_j),则有

to(j, i)=dp[count(x_i)-count(x_j)][count(y_i)-count(y_j)][count(z_i)-count(z_j)]

否则 to(j,i)=0

为了方便得到答案,可以将终点(n,m,r)当做第 o+1 个障碍,则答案为 f[o+1]

#include <bits/stdc++.h>
typedef long long LL;
int read()
{
    int s = 0, f = 1;
    char ch = getchar();
    while (!(ch >= '0' && ch <= '9'))
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        s = (s << 3) + (s << 1) + (ch ^ 48);
        ch = getchar();
    }
    return s * f;
}
LL readll()
{
    LL s = 0, f = 1;
    char ch = getchar();
    while (!(ch >= '0' && ch <= '9'))
    {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        s = (s << 3) + (s << 1) + (ch ^ 48);
        ch = getchar();
    }
    return s * f;
}
struct Block
{
    LL x, y, z;
    int xx, yy, zz;
} ;
LL n, m, r;
int o;
Block a[10005];
const int MOD = 998244353;
LL dp[70][70][70] = {{{0}}}, f[10005];
LL c[70][70] = {{0}};
LL comb(int n, int m)
{
    if (c[n][m])
        return c[n][m];
    if (n == m || m == 0)
        return c[n][m] = 1;
    return c[n][m] = (comb(n - 1, m - 1) + comb(n - 1, m)) % MOD;
}
int count(LL x)
{
    int cnt = 0;
    for (; x; x -= x & -x) 
        cnt++;
    return cnt;
}
inline bool chk(int i, int j)
{
    return ((a[i].x & a[j].x) == a[i].x)
        && ((a[i].y & a[j].y) == a[i].y)
        && ((a[i].z & a[j].z) == a[i].z);
}
int main()
{
    n = readll(), m = readll(), r = readll(), o = read();
    for (int i = 1; i <= o; i++)
        a[i].x = readll(), a[i].y = readll(), a[i].z = readll();
    dp[0][0][0] = 1;
    for (int i = 0; i <= 64; i++)
        for (int j = 0; j <= 64; j++)
            for (int k = 0; k <= 64; k++)
            {
                for (int ii = 0; ii < i; ii++)
                    dp[i][j][k] = (dp[i][j][k] + dp[ii][j][k] * comb(i, i - ii) % MOD) % MOD;
                for (int jj = 0; jj < j; jj++)
                    dp[i][j][k] = (dp[i][j][k] + dp[i][jj][k] * comb(j, j - jj) % MOD) % MOD;
                for (int kk = 0; kk < k; kk++)
                    dp[i][j][k] = (dp[i][j][k] + dp[i][j][kk] * comb(k, k - kk) % MOD) % MOD;
            }
    std::sort(a + 1, a + o + 1, [](Block p, Block q){
        return (p.x != q.x) ? (p.x < q.x) : (p.y != q.y ? p.y < q.y : p.z < q.z);
    }); 
    a[++o].x = n, a[o].y = m, a[o].z = r;
    for (int i = 1; i <= o; i++)
        a[i].xx = count(a[i].x), a[i].yy = count(a[i].y), a[i].zz = count(a[i].z);
    for (int i = 1; i <= o; i++)
    {
        f[i] = dp[a[i].xx][a[i].yy][a[i].zz];
        for (int j = 1; j < i; j++)
            if (chk(j, i))
                f[i] = ((f[i] - dp[a[i].xx - a[j].xx][a[i].yy - a[j].yy][a[i].zz - a[j].zz] * f[j] % MOD) % MOD + MOD) % MOD;
    }
    printf("%lld\n", f[o]);
    return 0;
} 

这题真的能评黑吗