题解 P5296 【[北京省选集训2019]生成树计数】

· · 题解

设某棵生成树的各条边权为 w_1, w_2, \dots, w_m,则它的权值(边权之和)的 k 次方展开后相当于

(w_1+w_2+\cdots +w_m)^k = k![z^k]\prod_{i=1}^m \mathrm{e}^{w_iz}

注意矩阵树定理可以推广到带权的情形:Laplacian 矩阵的主余子式等于 \sum_T \prod_{(u, v)\in T} w(u,v),其中 T 取遍所有生成树。因此只需要把每条边的边权换成一个多项式,构造多项式矩阵 w(u,v) = \sum_{i=0}^k \frac{(w_{u,v}z)^i}{i!}(即 \mathrm{e}^{w_{u,v}z} 截断到 z^k),运算时仅保留至第 k 次,然后用求生成树数量的方法算行列式,最后取 z^k 项的系数乘以 k! 就行了。复杂度 \Theta(n^3 k^2)

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cctype>

#include <algorithm>
#include <random>
#include <bitset>
#include <queue>
#include <functional>
#include <set>
#include <map>
#include <vector>
#include <chrono>
#include <iostream>
#include <limits>
#include <numeric>

// Debug helper: forwards its printf-style arguments to stderr.
// `FMT...` is the GNU named-variadic-macro form, so this needs a compiler
// that accepts that extension.
#define LOG(FMT...) fprintf(stderr, FMT)

// Pulls all of std into the global namespace for the rest of the file.
using namespace std;

// Short aliases for the long long types; `ll` is used to widen products
// before they are reduced modulo P.
typedef long long ll;
typedef unsigned long long ull;

// mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// N: capacity of every fixed-size array below; vertex indices (1..n) and
//    polynomial coefficient indices (0..k) must both stay below it.
// P: modulus used for all arithmetic in this file.
const int N = 32, P = 998244353;

// n: vertex count read from input (main decrements it before elimination).
// k: exponent read from input; every polynomial keeps coefficients 0..k.
int n, k;
// ifac[t]: modular inverse of t! once main has filled it in.
// res:     polynomial accumulating the product of the elimination pivots.
int ifac[N], res[N];
// a[i][j]: weight matrix exactly as read from input, 1-indexed.
int a[N][N];
// g[i][j]: polynomial (coefficients 0..k) stored at row i, column j of the
//          matrix that main builds from the weights and then eliminates.
int g[N][N][N];

// Folds a value from [0, 2P) back into [0, P) with a single conditional
// subtraction. Anything already below P is returned unchanged.
int norm(int x) {
  if (x < P)
    return x;
  return x - P;
}

// Writes coefficients a[0..k] to `fp` on one line: each of the first k values
// is followed by a space, the last one by a newline.
void println(int* a, FILE* fp) {
  int idx = 0;
  while (idx < k) {
    fprintf(fp, "%d ", a[idx]);
    ++idx;
  }
  fprintf(fp, "%d\n", a[k]);
}

// Extended Euclid: stores in (x, y) a pair of Bezout coefficients such that
// a * x + b * y == gcd(a, b).
//
// `x` and `y` must refer to distinct objects.
void exGcd(int a, int b, int& x, int& y) {
  if (b == 0) {
    // gcd(a, 0) = a = a * 1 + 0 * 0.
    x = 1;
    y = 0;
    return;
  }
  // Solve b * subX + (a % b) * subY == g for the reduced pair first.
  int subX, subY;
  exGcd(b, a % b, subX, subY);
  // Substitute a % b = a - (a / b) * b and regroup by a and b.
  x = subY;
  y = subX - a / b * subY;
}

// Modular inverse of `a` modulo P, obtained from the Bezout coefficient of
// `a` produced by exGcd(a, P). That coefficient may be negative, so P is
// added and the sum is folded back below P with norm.
int inv(int a) {
  int coefA = 0, coefP = 0;
  exGcd(a, P, coefA, coefP);
  return norm(coefA + P);
}

// Truncated polynomial product: out = a * b, keeping only coefficients 0..k,
// each reduced modulo P.
//
// The product is accumulated in a private scratch buffer and copied to `out`
// afterwards, so `out` may alias `a` and/or `b`. All N slots of `out` are
// written (slots above k become 0).
//
// Returns `out`, matching pinv. The scratch buffer itself is never exposed
// because the next call overwrites it.
int* mul(int* a, int* b, int* out) {
  static int tmp[N];
  memset(tmp, 0, sizeof(tmp));
  for (int i = 0; i <= k; ++i)
    for (int j = 0; j <= k - i; ++j)  // i + j <= k: drop terms above z^k
      tmp[i + j] = (tmp[i + j] + a[i] * (ll)b[j]) % P;
  memcpy(out, tmp, sizeof(tmp));
  return out;
}

int* pinv(int* a, int* out) {
  static int b[N], tmp[N];
  memset(tmp, 0, sizeof(tmp));
  tmp[0] = inv(a[0]);
  for (int i = 1; i <= k; ++i)
    b[i] = a[i] * (ll)tmp[0] % P;
  for (int i = 1; i <= k; ++i)
    for (int j = 1; j <= i; ++j)
      tmp[i] = (tmp[i] + (P - b[j]) * (ll)tmp[i - j]) % P;
  memcpy(out, tmp, sizeof(tmp));
  return out;
}

// Reads n, k and an n x n weight matrix from stdin, builds a matrix whose
// entries are polynomials truncated after z^k, eliminates it to obtain the
// determinant of its leading (n-1) x (n-1) block, and prints k! times the
// z^k coefficient of that determinant, modulo P.
//
// Returns 1 (with a message on stderr) when the input is missing or would
// index outside the fixed-size arrays; returns 0 otherwise.
int main() {
  // Every array has N slots: vertices are indexed 1..n and coefficients 0..k,
  // so both must stay below N, and k must not be negative (res[k] is read).
  if (scanf("%d%d", &n, &k) != 2 || n >= N || k < 0 || k >= N) {
    fprintf(stderr, "invalid input: need n < %d and 0 <= k < %d\n", N, N);
    return 1;
  }

  // Pass 1: ifac[x] temporarily holds the modular inverse of x, using
  // inv(x) = -(P / x) * inv(P % x) (mod P); the trailing + P makes it positive.
  ifac[1] = 1;
  for (int x = 2; x <= k; ++x)
    ifac[x] = -(P / x) * (ll)ifac[P % x] % P + P;
  // Pass 2: prefix products turn the inverses into inverse factorials.
  ifac[0] = 1;
  for (int x = 1; x <= k; ++x)
    ifac[x] = ifac[x - 1] * (ll)ifac[x] % P;

  for (int i = 1; i <= n; ++i)
    for (int j = 1; j <= n; ++j)
      if (scanf("%d", &a[i][j]) != 1) {
        fprintf(stderr, "invalid input: weight matrix is incomplete\n");
        return 1;
      }

  // Build the matrix from the lower triangle a[i][j], j < i. Each pair
  // contributes the polynomial sum_t (a[i][j] * z)^t / t!, t = 0..k:
  // negated into the two off-diagonal entries and added to both diagonals.
  for (int i = 1; i <= n; ++i)
    for (int j = 1; j < i; ++j) {
      int pw = 1;  // a[i][j]^t
      for (int t = 0; t <= k; ++t) {
        int v = pw * (ll)ifac[t] % P;
        g[i][j][t] = g[j][i][t] = norm(P - v);
        g[i][i][t] = norm(g[i][i][t] + v);
        g[j][j][t] = norm(g[j][j][t] + v);
        pw = pw * (ll)a[i][j] % P;
      }
    }

  // Determinant of the leading (n-1) x (n-1) block by elimination over
  // truncated polynomials; res accumulates the product of the pivots.
  // NOTE(review): there is no pivot search, so every pivot's constant term
  // must be invertible mod P — presumably guaranteed by the input being a
  // complete graph; confirm against the problem statement.
  res[0] = 1;
  --n;
  for (int i = 1; i <= n; ++i) {
    static int cur[N];
    mul(g[i][i], res, res);
    // Scale row i by the inverse of its pivot, then pin the pivot to exactly 1.
    pinv(g[i][i], cur);
    for (int j = n; j >= i; --j)
      mul(g[i][j], cur, g[i][j]);
    memset(g[i][i], 0, sizeof(g[i][i]));
    g[i][i][0] = 1;
    // Subtract g[j][i] * (row i) from every later row j. Columns run from n
    // down to i so that the multiplier g[j][i] is itself overwritten last.
    // `cur` is reused here as scratch space for each product.
    for (int j = n; j > i; --j) {
      for (int t = n; t >= i; --t) {
        mul(g[i][t], g[j][i], cur);
        for (int l = 0; l <= k; ++l)
          g[j][t][l] = norm(g[j][t][l] + P - cur[l]);
      }
    }
  }

  // The z^k coefficient carries a factor 1 / k!; multiply it back out.
  int ans = res[k];
  for (int i = 1; i <= k; ++i)
    ans = ans * (ll)i % P;
  printf("%d\n", ans);
  return 0;
}