题解 P5293 【[HNOI2019]白兔之舞】

· · 题解

upd : 把所有 C 的符号换成了 binom

蒟蒻语

这是蒟蒻的第 500 道 AC, 第 30 道黑题, 写篇题解祭一下 qwq

开始全 T 以为是自己打挂了, 结果开了 O2 就 AC 了qwq

蒟蒻解

假设白兔跳了 i 次。

先考虑这 i 次跳跃经过的 i 个点的 x 坐标有多少种选法(白兔所在的每个点的坐标记为 (x, y))。

显然是从 L 个中选择 i 个, 就是 \binom{L}{i}

接下来要求的是白兔恰好跳 i 次时 (x, y) 的方案数, 记为 f_i。

x, y 两维互不影响, 所以现在只需要算出跳 i 次时这 i 个 y 坐标的方案数。

那么考虑 dp[i][j] 表示跳了 i 次之后, y 坐标停在 j 的方案数。

dp[i][j] = \sum\limits_{k=1}^{n} dp[i-1][k] * w[k][j]

显然用矩阵优化

现在定义 W 是转移矩阵, 可以直接用题目中的 w

那么 $f_i = \binom{L}{i} \cdot dp[i][y]$。设 $G$ 为初始状态对应的行向量, 则 $dp[i][y]$ 就是 $G W^i$ 的第 $y$ 项。

利用单位根反演 $[k \mid m] = \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{jm}$:

$$
\begin{aligned}
ans_t &= \sum\limits_{i=0}^{L} [i \bmod k = t] f_i \\
&= \sum\limits_{i=0}^{L} [k \mid (i - t)] f_i \\
&= \sum\limits_{i=0}^{L} f_i \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{j(i-t)} \\
&= \frac{1}{k} \sum\limits_{i=0}^{L} f_i \sum\limits_{j=0}^{k-1} \omega_k^{-jt} \omega_k^{ij} \\
&= \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{-jt} \sum\limits_{i=0}^{L} \omega_k^{ij} f_i \\
&= \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{-jt} \sum\limits_{i=0}^{L} \omega_k^{ij} \binom{L}{i} \left( G W^i \right) \text{的第 } y \text{ 项} \\
&= \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{-jt} \left( G \sum\limits_{i=0}^{L} \binom{L}{i} (\omega_k^{j} W)^i \right) \text{的第 } y \text{ 项} \\
&= \frac{1}{k} \sum\limits_{j=0}^{k-1} \omega_k^{-jt} \left( G (\omega_k^{j} W + I)^L \right) \text{的第 } y \text{ 项}
\end{aligned}
$$

最后一步用的是二项式定理($\omega_k^{j} W$ 与 $I$ 可交换)。

$h_i = G(\omega_k^{i} W + I)^L$ 的第 $y$ 项

ans_t = \frac{1}{k} \sum\limits_{i = 0}^{k-1} \omega_k^{-it} h_i

因为 it = \binom{i+t}{2} - \binom{i}{2} - \binom{t}{2}

$$
\begin{aligned}
ans_t &= \frac{1}{k} \sum\limits_{i=0}^{k-1} \omega_k^{-\binom{i+t}{2} + \binom{i}{2} + \binom{t}{2}} h_i \\
&= \frac{1}{k} \sum\limits_{i=0}^{k-1} \omega_k^{-\binom{i+t}{2}} \omega_k^{\binom{t}{2}} \omega_k^{\binom{i}{2}} h_i \\
&= \frac{\omega_k^{\binom{t}{2}}}{k} \sum\limits_{i=0}^{k-1} \omega_k^{-\binom{i+t}{2}} \omega_k^{\binom{i}{2}} h_i
\end{aligned}
$$

F(i) = \frac{\omega_k^{\binom{i}{2}}}{k}, P(i) = \omega_k^{\binom{i}{2}} h_i, S(i) = \omega_k^{-\binom{i}{2}}

ans_t = F(t) \sum\limits_{i = 0}^{k - 1}P(i)S(i + t)

接下来把 $S$ 和 $ans$ 都 reverse 一下($S$ 和 $ans$ 的长度都取 $2k$), 把式子化成标准的卷积形式。下面的 $S$、$ans$ 都指翻转之后的数组:

$$ans_{2k - t - 1} = F(t) \sum\limits_{i=0}^{k-1} P(i) S(2k - 1 - i - t)$$

把 $t$ 换成 $2k - t - 1$:

$$ans_t = F(2k - t - 1) \sum\limits_{i=0}^{k-1} P(i) S(t - i)$$

我们需要的答案在翻转之后对应 $t \in [k, 2k - 1]$, 即 $t > k - 1$; 而 $P(i)$ 在 $i \ge k$ 时为 0, 所以求和上界可以直接写成 $t$:

ans_t = F(2k - t - 1) \sum\limits_{i=0}^{t} P(i) S(t - i)

直接用 MTT 做卷积就好了。

蒟蒻码

#include<bits/stdc++.h>
#define db long double
#define N 800010
#define ll long long
using namespace std;
// Values read from the input in main(): n is the side length used by the
// Matrix operators, k the number of residue classes that init() prints, L the
// exponent of the matrix power, x / y the row / column of the matrix entry
// that init() reads, and mod the modulus (assumed prime: init() inverts k with
// a Fermat-style power mod - 2).
// pp[] is the bit-reversal permutation table filled by MTT() and read by FFT().
int n, k, L, x, y, mod, pp[N];
// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct CP {
    db x, y;
    // Both parts default to zero, so CP() is 0 and CP(r) is the real number r.
    CP (db xx = 0, db yy = 0) : x(xx), y(yy) {}
};
// fac[1..cnt] holds the distinct prime factors collected by getfac().
// omg is set by find() and omk by init().
int fac[123], cnt, omg, omk;
// Appends every distinct prime factor of x to fac[] (1-indexed; cnt is the
// running count and is not reset here).
//
// Trial division only has to run while i * i <= x, where x is the part not yet
// divided out. The comparison is done in integers so no floating-point sqrt
// rounding can affect the loop boundary.
void getfac(int x) {
    for(long long i = 2; i * i <= x; i++)
        if(x % i == 0) {
            fac[++cnt] = (int)i;
            while(x % i == 0) x /= i;
        }
    // Whatever is left above 1 is a single prime factor larger than every
    // divisor tried; the loop alone never records it (e.g. 17 for 998244352).
    if(x > 1) fac[++cnt] = x;
}
// Modular exponentiation by repeated squaring: returns x^y reduced modulo pmod.
// A zero base short-circuits to 0, including for y == 0.
int qpow(int x, int y, int pmod) {
    if(x == 0) return 0;
    long long acc = 1;
    long long base = x;
    while(y) {
        // Fold the current power of the base in whenever the low exponent bit is set.
        if(y & 1) acc = acc * base % pmod;
        base = base * base % pmod;
        y >>= 1;
    }
    return (int)acc;
}
// Primitive-root test used by find(): returns false as soon as
// x^((mod - 1) / p) is 1 modulo mod for some prime p stored in fac[1..cnt],
// and true when no such p exists.
bool cheak(int x) {
    int idx = 1;
    while(idx <= cnt) {
        const int reduced = (mod - 1) / fac[idx];
        if(qpow(x, reduced, mod) == 1) return false;
        ++idx;
    }
    return true;
}
// Factors mod - 1, then stores in omg the smallest value in [2, mod) that
// passes cheak(). omg is left untouched when no candidate passes.
void find() {
    getfac(mod - 1);
    int cand = 2;
    while(cand < mod && !cheak(cand)) ++cand;
    if(cand < mod) omg = cand;
}
// pi at full long double precision. acos(-1) with an integer argument is
// evaluated in double, which would cap every FFT twiddle factor derived from
// pi at double accuracy even though the transform itself runs in long double.
db pi = acosl(-1);
// Complex addition, component by component.
CP operator + (CP aa, CP bb) {
    aa.x += bb.x;
    aa.y += bb.y;
    return aa;
}
// Complex subtraction, component by component.
CP operator - (CP aa, CP bb) {
    aa.x -= bb.x;
    aa.y -= bb.y;
    return aa;
}
// Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
CP operator * (CP aa, CP bb) {
    const db re = aa.x * bb.x - aa.y * bb.y;
    const db im = aa.x * bb.y + aa.y * bb.x;
    return CP(re, im);
}
// Divides both parts of a complex number by the real scalar bb.
CP operator / (CP aa, db bb) {
    aa.x /= bb;
    aa.y /= bb;
    return aa;
}
// Square matrix of residues modulo mod. The Matrix operators in this file
// index rows and columns from 1 to n, so the 4x4 storage fits n up to 3.
struct Matrix {
    int a[4][4];
} S; // S is filled from the input in main() and used as the transition matrix in init().
// Matrix product modulo mod over the 1..n index range.
Matrix operator * (Matrix aa, Matrix bb) {
    Matrix prod;
    for(int r = 1; r <= n; r++)
        for(int c = 1; c <= n; c++) {
            int acc = 0;
            for(int t = 1; t <= n; t++) {
                // Reduce each term and each partial sum so acc stays below 2 * mod.
                acc += 1ll * aa.a[r][t] * bb.a[t][c] % mod;
                acc %= mod;
            }
            prod.a[r][c] = acc;
        }
    return prod;
}
// Multiplies every entry in the 1..n index range by the scalar bb, modulo mod.
Matrix operator * (Matrix aa, int bb) {
    // aa is a by-value copy, so it can be scaled in place and returned.
    for(int r = 1; r <= n; r++)
        for(int c = 1; c <= n; c++)
            aa.a[r][c] = 1ll * aa.a[r][c] * bb % mod;
    return aa;
}
// Returns aa + I modulo mod: adds 1 to each diagonal entry in the 1..n range.
Matrix ps(Matrix aa) {
    Matrix shifted;
    for(int r = 1; r <= n; r++)
        for(int c = 1; c <= n; c++) {
            const int bump = (r == c) ? 1 : 0;
            shifted.a[r][c] = (aa.a[r][c] + bump) % mod;
        }
    return shifted;
}
// Matrix power aa^bb modulo mod by repeated squaring, starting from the identity.
Matrix operator ^ (Matrix aa, int bb) {
    Matrix acc;
    for(int r = 1; r <= n; r++)
        for(int c = 1; c <= n; c++)
            acc.a[r][c] = (r == c);
    while(bb) {
        if(bb & 1) acc = acc * aa;
        aa = aa * aa;
        bb >>= 1;
    }
    return acc;
}
// In-place iterative radix-2 FFT over f[0..len).
// pp[] must already hold the bit-reversal permutation for this len (MTT() fills it).
// flag = 1 uses the twiddle e^{+2*pi*i/span}, flag = -1 its conjugate; with
// flag = -1 the real parts (and only the real parts) are divided by len at the end.
void FFT(CP *f, int len, int flag) {
    for(int i = 0; i < len; i++)
        if(i > pp[i]) swap(f[pp[i]], f[i]);
    for(int span = 2; span <= len; span <<= 1) {
        const int half = span >> 1;
        const CP step(cos(2 * pi / span), flag * sin(2 * pi / span));
        for(int base = 0; base < len; base += span) {
            CP w(1, 0);
            for(int p = base; p < base + half; p++) {
                // Butterfly: combine the two halves with the running twiddle w.
                const CP even = f[p];
                const CP odd = f[p + half] * w;
                f[p] = even + odd;
                f[p + half] = even - odd;
                w = w * step;
            }
        }
    }
    if(flag == -1)
        for(int i = 0; i < len; i++) f[i].x /= len;
}
CP fa[N], fb[N], ga[N], gb[N], ssB[N], sB[N], B[N];
void MTT(int *a, int *b, int *c, int len) {
    int flen;
    for(flen = 1; flen <= len * 2; flen <<= 1);
    for(int i = 0; i < flen; i++) pp[i] = ((pp[i >> 1] >> 1) | ((i & 1) * (flen >> 1)));
    for(int i = 0; i < flen; i++) fa[i].x = (a[i] >> 15), fb[i].x = (a[i] & 32767);
    for(int i = 0; i < flen; i++) ga[i].x = (b[i] >> 15), gb[i].x = (b[i] & 32767);
    FFT(fa, flen, 1), FFT(fb, flen, 1), FFT(ga, flen, 1), FFT(gb, flen, 1);
    for(int i = 0; i < flen; i++) {
        ssB[i] = fa[i] * ga[i];
        sB[i] = (fa[i] * gb[i]) + (fb[i] * ga[i]);
        B[i] = fb[i] * gb[i];
    }
    FFT(ssB, flen, -1), FFT(sB, flen, -1), FFT(B, flen, -1);
    for(int i = 0; i < flen; i++) {
        ll dssB = (ll)(ssB[i].x + 0.49) % mod;
        ll dsB = (ll)(sB[i].x + 0.49) % mod;
        ll dB = (ll)(B[i].x + 0.49) % mod;
        c[i] = (0ll + (dssB << 30) % mod + (dsB << 15) % mod + dB) % mod;
    }
}
// Returns x * (x - 1) / 2 reduced modulo mod - 1; init() uses the result as an
// exponent for powers taken modulo mod.
int C(int x) {
    const long long pairs = 1ll * x * (x - 1) / 2;
    return (int)(pairs % (mod - 1));
}
int h[N], F[N], P[N], s[N], ans[N];
void init() {
    omk = qpow(omg, (mod - 1) / k, mod);
    int nyk = qpow(k, mod - 2, mod);
    for(int i = 0; i < k; i++) {
        h[i] = (ps((S * qpow(omk, i, mod))) ^ L).a[x][y];
        P[i] = 1ll * qpow(omk, C(i), mod) * h[i] % mod;
    }
    k <<= 1;
    for(int i = 0; i < k; i++) {
        F[i] = 1ll * qpow(omk, C(i), mod) * nyk % mod;
        s[i] = qpow(omk, mod - 1 - C(i), mod);
    }
    reverse(s, s + k);
    MTT(P, s, ans, k);
    // for(int t = 0; t < k; t++) 
        // for(int i = 0; i <= t; i++) 
            // (ans[t] += 1ll * P[i] * s[t - i] % mod) %= mod;
    reverse(ans, ans + k);
    for(int i = 0; i < (k >> 1); i++) ans[i] = 1ll * ans[i] * F[i] % mod, printf("%d\n", ans[i]);
}
// Reads n, k, L, x, y, mod, finds a primitive root of mod, reads the n x n
// matrix into S (1-indexed) and runs init(), which prints the answers.
int main() {
    scanf("%d%d%d%d%d%d", &n, &k, &L, &x, &y, &mod);
    find();
    for(int r = 1; r <= n; r++)
        for(int c = 1; c <= n; c++)
            scanf("%d", &S.a[r][c]);
    init();
    return 0;
}