题解 P5320 【[BJOI2019]勘破神机】

· · 题解

先看 m=2。众所周知的是 2\times n 填骨牌的方案是斐波那契数列 f_{n+1},其中 f_0=0,f_1=1, f_n=f_{n-1}+f_{n-2},而要求的东西就是 \sum_{n=l}^r{f_n\choose k} (为了方便我们假设方案是 f_n 而不是 f_{n+1},反正只需要把 l,r 都加上 1 就可以了)

然而斐波那契数列的通项公式为

f_n=\frac1{\sqrt5}\left(\frac{1+\sqrt5}{2}\right)^n-\frac1{\sqrt5}\left(\frac{1-\sqrt5}{2}\right)^n

A=\frac1{\sqrt5},B=-\frac1{\sqrt5},x=\frac{1+\sqrt5}{2},y=\frac{1-\sqrt5}{2},那么有 f_n=Ax^n+Bx^n

再加上下降幂与阶乘幂之间的转换

x^{\underline k}=\sum_{i=0}^k(-1)^{k-i}s(k,i)x^i

(其中 s(k,i) 为第一类斯特林数)

于是

\begin{aligned}&\sum_{n=l}^r{f_n\choose k}\\=&\frac1{k!}\sum_{n=l}^r\sum_{i=0}^k(-1)^{k-i}s(k,i)f_n^i\\=&\frac1{k!}\sum_{n=l}^r\sum_{i=0}^k(-1)^{k-i}s(k,i)(Ax^n+By^n)^i\\=&\frac1{k!}\sum_{n=l}^r\sum_{i=0}^k(-1)^{k-i}s(k,i)\sum_{j=0}^i{i\choose j}A^jB^{i-j}(x^jy^{i-j})^n\\=&\frac1{k!}\sum_{i=0}^k(-1)^{k-i}s(k,i)\sum_{j=0}^i{i\choose j}A^jB^{i-j}\sum_{n=l}^r(x^jy^{i-j})^n\\\end{aligned}

只需要枚举 i,j,然后后面那个 \sum 就是一个等比数列。

不过还有一个问题就是, \sqrt5 这个东西在 \bmod998244353 意义下不存在。

我们对剩余系进行扩域,把每个数表示成 a+b\sqrt5 的形式,就可以做了。

对于 m=3,可以发现 n 是奇数的时候方案数是 0,而偶数的时候(令 g_n 表示 2n 的答案)可以搞出递推式 g_n=4g_{n-1}-g_{n-2},从而解出通项公式 g_n=\frac{3+\sqrt3}6(2+\sqrt3)^n+\frac{3-\sqrt3}6(2-\sqrt3)^n,仍然用上述方法计算即可。

(不过有一个本题中没什么用处的就是,实际上模素数意义下当 \sqrt{x} 不存在的时候 \sqrt{-x} 一定是存在的,因此可以将 a+b\sqrt x 写成 a+b'i,其中 b'=b\sqrt{-x}i^2=1 (模意义复数!))

附代码:

#include <algorithm>
#include <cstdio>
#include <cstring>

typedef long long LL;
const int mod = 998244353;

LL pow_mod(LL a, LL b) {
  LL ans = 1;
  for (a %= mod; b; b >>= 1, a = a * a % mod)
    if (b & 1) ans = ans * a % mod;
  return ans;
}

template <int a>
struct Sq { // x + y*sqrt(a) (mod 998244353)
typedef Sq<a> sq;

LL x, y;
Sq(LL x = 0, LL y = 0) : x(x % mod), y(y % mod) {}

inline sq conj() const { return sq(x, -y); }
// assume a is small (a * mod * mod < 2^63)
friend inline sq operator+(const sq &l, const sq &r) { return sq(l.x + r.x, l.y + r.y); }
friend inline sq operator-(const sq &l, const sq &r) { return sq(l.x - r.x, l.y - r.y); }
friend inline sq operator*(const sq &l, const sq &r) { return sq(l.x * r.x + a * l.y * r.y, l.x * r.y + l.y * r.x); }
friend inline sq operator/(const sq &l, const sq &r) { return sq(l.x * r.x - a * l.y * r.y, l.y * r.x - l.x * r.y) * pow_mod(r.x * r.x - a * r.y * r.y, mod - 2); }
inline sq& operator+=(const sq &t) { x = (x + t.x) % mod; y = (y + t.y) % mod; return *this; }
inline sq& operator-=(const sq &t) { x = (x - t.x) % mod; y = (y - t.y) % mod; return *this; }
inline sq& operator*=(const sq &t) { return *this = *this * t; }
sq operator^(LL t) const {
  sq ans = 1, x = *this;
  for (; t; t >>= 1, x *= x) if (t & 1) ans *= x;
  return ans;
}
inline sq sum_pow(LL t) { return (1 + this->pow(t + 1)) / (1 + *this); } // (x^(t+1)-1) / (x-1)
inline void put() { printf("%lld+%llds%d", (x + mod) % mod, (y + mod) % mod, a); }
};

const int K = 550;
LL C[K][K], ss[K][K], fac[K];

template <int a>
Sq<a> Solve(Sq<a> A, Sq<a> p, Sq<a> B, Sq<a> q, int k, LL l, LL r) {
  ++r;
  Sq<a> ans = 0, qr = q^r, ql = q^l, pr = p^r, pl = p^l;
  Sq<a> ppl = 1, ppr = 1, pp = 1, pA = 1;
  for (int u = 0; u <= k; ++u) {
    Sq<a> pql = 1, pqr = 1, pq = 1, pB = 1;
    for (int v = 0; u + v <= k; ++v) {
      Sq<a> quq = pp * pq - 1, qvq = ppr * pqr - ppl * pql;
      Sq<a> qaq = !quq.x && !quq.y ? (r - l) % mod : qvq / quq;
      Sq<a> qwq = ss[k][u + v] * C[u + v][u] % mod * pA * pB * qaq;
      ans += qwq;
      pB *= B; pql *= ql; pqr *= qr; pq *= q;
    }
    pA *= A; ppl *= pl; ppr *= pr; pp *= p;
  }
  return ans * pow_mod(fac[k], mod - 2);
}

void Init() {
  for (int i = 0; i < K; ++i)
    for (int j = C[i][0] = 1; j <= i; ++j)
      C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
  fac[0] = 1;
  for (int i = 1; i < K; ++i) fac[i] = i * fac[i - 1] % mod;
  ss[0][0] = 1;
  for (int i = 0; i + 1 < K; ++i) {
    ss[i + 1][0] = - i * ss[i][0];
    for (int j = 1; j <= i + 1; ++j)
      ss[i + 1][j] = (ss[i][j - 1] - i * ss[i][j]) % mod;
  }
}

int main() {
  Init();
  int T, m;
  scanf("%d%d", &T, &m);
  while (T--) {
    LL l, r; int k;
    scanf("%lld%lld%d", &l, &r, &k);
    if (m == 2) {
      Sq<5> p = Sq<5>(1, 1) / 2, t = Sq<5>(0, 1) / 5;
      printf("%lld\n", (Solve(t, p, t.conj(), p.conj(), k, l + 1, r + 1).x * pow_mod(r - l + 1, mod - 2) % mod + mod) % mod);
    } else {
      Sq<3> p = Sq<3>(2, 1), t = Sq<3>(3, 1) / 6;
      printf("%lld\n", (Solve(t, p, t.conj(), p.conj(), k, (l + 1) / 2, r / 2).x * pow_mod(r - l + 1, mod - 2) % mod + mod) % mod);
    }
  }
}