题解:P12002 吃猫粮的玉桂狗

· · 题解

又拿这个套路炒了一波冷饭。看赛时榜好像大家还是不太会,所以希望这个题能对这个套路有一定普及作用。

猜你想搜:

考虑这个每种猫粮的数量多于一半有什么用:在满足位置限制(所有的 (a_i, b_i) 对)时,任意不合法方案都只有一种猫粮超出数量限制。因为如果有两种猫粮 x,y 超出了限制,假设分别放了 w_1, w_2 个,那么有 w_1 > c_x \geq \lfloor\frac{n}{2}\rfloor w_2>c_y \geq \lfloor\frac{n}{2}\rfloor,则 w_1 + w_2 > n。而所有猫粮加起来只会选 n 个,所以这种情况不会存在。

可以考虑枚举那个超限制的猫粮 x,那么其它猫粮都可以随便选,假设种类 x 的猫粮超限制的方案数是 A(x),在不考虑猫粮数量限制时的方案数是 S,则可以容斥一下合法方案数就是 S - \sum_{x= 1}^m A(x)。考虑计算 A(x)

注意到我们关于位置的限制(所有的 (a_i, b_i) 对)其实等价于:以 1 为根的树,不存在任何父亲-孩子对使得父亲的猫粮是 a_i,孩子的猫粮是 b_i。考虑枚举超限制的猫粮种类 x,进行树形 DP:f_{u,i,j} 表示以 u 为根的子树,u 的猫粮种类是 i,且种类为 x 的猫粮已经放了 j 个的方案数。可以进行树上背包转移:

枚举 u 的孩子 v,设 u 当前已经转移的孩子放了 i 个猫粮 xv 这个孩子选了 j 个猫粮 xu 的猫粮种类是 hv 的猫粮种类是 k,当前已转移的孩子的状态放在 f' 里,有:

f(u,h,i+j) \leftarrow \sum_{k = 1}^m (f(v,k,j) \times f'(u, h,i) \times[\mathrm{ok}(h,k)])

这里当 (h,k) 这个限制不存在时,[\mathrm{ok}(h,k)]=1,否则为 0

如果不理解树上背包转移可以去看代码。

如果把第二维拿掉,那么这个 dp 是在做标准树上背包。所以上式是在树上背包的基础上枚举了 hk 两维。标准树上背包的复杂度是 O(n^2),因此上面这个 dp 的复杂度是 O(n^2m^2)。算上我们外面枚举的不合法的猫粮种类 x,总的复杂度是 O(n^2m^3)

那么 A(t) = \sum_{i= 1}^m \sum_{j = c_{t} + 1}^nf(1,i,j)

S(总方案数)的计算也是简单的,直接把记录猫粮 x 的数量一维删掉即可。

g(u,h) \leftarrow g'(u,h) \times \sum_{k = 1}^m g(v,k)\times[\mathrm{ok}(h,k)]

减一下就做完了。

注意树上 DP 的过程里不要对最后一维有无效的枚举,否则复杂度是不对的。

#include <bits/stdc++.h>

const int p = 353'442'899;

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m, t;
  std::cin >> n >> m >> t;
  std::vector<int> c(m + 1);
  for (int i = 1; i <= m; ++i) std::cin >> c[i];
  std::vector e(n + 1, std::vector<int>()), lim(m + 1, std::vector<int>(m + 1, 0));
  for (int i = 1, u, v; i < n; ++i) {
    std::cin >> u >> v;
    e[u].push_back(v);
    e[v].push_back(u);
  }
  for (int u, v; t; --t) {
    std::cin >> u >> v;
    lim[u][v] = 1;
  }
  std::vector g(n + 1, std::vector(m + 1, 0));
  auto Dfs = [&](auto &&dfs, int u, int pre) -> void {
    for (int col = 1; col <= n; ++col) g[u][col] = 1;
    for (auto v : e[u]) if (v != pre) {
      dfs(dfs, v, u);
      for (int curCol = 1; curCol <= m; ++curCol) {
        int sum = 0;
        for (int childCol = 1; childCol <= m; ++childCol) if (!lim[curCol][childCol]) {
          (sum += g[v][childCol]) %= p;
        }
        g[u][curCol] = 1ll * g[u][curCol] * sum % p;
      }
    }
  };
  Dfs(Dfs, 1, 0);
  int ans = 0;
  for (int col = 1; col <= m; ++col) {
    (ans += g[1][col]) %= p;
  }
  for (int tarCol = 1; tarCol <= m; ++tarCol) {
    std::vector f(n + 1, std::vector(m + 1, std::vector(n + 1, 0)));
    std::vector<int> sz(n + 1);
    auto dfs = [&](auto &&dfs, int u, int pre) -> void {
      sz[u] = 1;
      for (int curCol = 1; curCol <= m; ++curCol) {
        f[u][curCol][curCol == tarCol] = 1;
      }
      for (auto v : e[u]) if (v != pre) {
        dfs(dfs, v, u);
        for (int curCol = 1; curCol <= m; ++curCol) {
          std::vector curf(sz[u] + sz[v] + 1, 0);
          for (int childCol = 1; childCol <= m; ++childCol) if (!lim[curCol][childCol]) {
            for (int lsh = 0; lsh <= sz[u]; ++lsh) {
              for (int rsh = 0; rsh <= sz[v]; ++rsh) {
                curf[lsh + rsh] += 1ll * f[v][childCol][rsh] * f[u][curCol][lsh] % p;
                curf[lsh +rsh] %= p;
              }
            }
          }
          for (int i = 0; i <= sz[u] + sz[v]; ++i) f[u][curCol][i] = curf[i];
        }
        sz[u] += sz[v];
      }
    };
    dfs(dfs, 1, 0);
    for (int i = c[tarCol] + 1; i <= n; ++i) {
      for (int col = 1; col <= m; ++col) {
        ans -= f[1][col][i];
        ans = (ans + p) % p;
      }
    }
  }
  std::cout << ans << std::endl;
}

不知道下次把这个精品小套路掏出来扔到比赛里会是什么时间。