题解 P5343 【【XR-1】分块】

· · 题解

题目链接:

题目

题目大意:

给定两个数组 A,B,现在要给 n 分组,每块的长度必须在两个数组的交集里(即 \sum_{i=1}a_i=n\quad(a_i\in A\cap B)),求方案数。

正文:

很明显的一个 DP,设 f_i 表示长度为 i 的分组的方案数,那么这么转移:

f_i=\sum_{j=1}^{m}f_{i-a_j} 时间复杂度:$O(nm)$,$n\leq10^{18}$,这玩意过不去。 考虑用矩阵乘法加速递推。 按照矩乘加速递推的尿性,我们可以用一个 $m\times 1$ 的矩阵来表示状态: $$\begin{bmatrix} f_{m}\\ f_{m-1}\\ \vdots\\ f_1 \end{bmatrix}$$ 要想办法转移到 $\begin{bmatrix} f_{m+1}\\ f_{m}\\ \vdots\\ f_3\\ f_2 \end{bmatrix}$,必须再构造一个 $m\times m$ 的转移矩阵。 如果要转移到 $f_{m+1}$,根据转移方程 $f_i=\sum_{j=1}^{m}f_{i-a_j}$ 和矩阵乘法的运算可以推断转移矩阵的第一行(转移到 $f_{m+1}$)肯定和 $A\cap B$ 相关。 拿样例为例。 样例中 $A\cap B = \{1,2\}$,转移矩阵第一行相应的位置就是 $1$,比如样例的转移矩阵的第一行就是 $\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}$。因为这样,$f_{m+1}$ 就能被转移得到: $$\begin{bmatrix} f_{m}\\ f_{m-1}\\ \vdots\\ f_1 \end{bmatrix} \times\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}=\begin{bmatrix}f_{m+1-1}\times 1+f_{m+1-2}\times 1+f_{m+1-2}\times0+\cdots+f_{1+1-1}\times0\end{bmatrix}=\begin{bmatrix}f_{m+1}\end{bmatrix}$$ 接下来 $m-1$ 行就更好办了,由于目标矩阵第 $2$ 到第 $m$ 行就是初始矩阵的第 $1$ 到第 $m-1$ 行,那转移矩阵就是: $$\begin{bmatrix}1&1&\cdots&0&0\\ 1&0&\cdots&0&0\\ 0&1&\cdots&0&0\\ \vdots&\vdots&\ddots&\vdots&\vdots\\ 0&0&\cdots&1&0\end{bmatrix}$$ 再套个矩阵快速幂就A了。 # 代码: ```cpp int f[N]; ll n; bool PR[N], NF[N]; struct matrix { ll mat[N][N]; int n, m; matrix(){memset(mat, 0, sizeof mat);} inline ll* operator [] (int b) { return mat[b];} }F, stp; inline matrix operator*(matrix &a, matrix &b) { matrix c; c.n = a.n, c.m = b.m; for (int i = 1; i <= a.n; i++) for (int j = 1; j <= b.m; j++) for (int k = 1; k <= a.m; k++) c[i][j] = (c[i][j] + (a[i][k] * b[k][j]) % mod) % mod; return c; } matrix qpow(matrix stp, ll b) { matrix ans; ans.n = ans.m = stp.n; for (int i = 1; i <= ans.n; i++) ans[i][i] = 1; for (; b; b >>= 1) { if(b & 1) ans = ans * stp; stp = stp * stp; } return ans; } int m; int main() { scanf ("%lld", &n); int pr;scanf ("%d", &pr); for (int i = 1, x; i <= pr; i++) scanf ("%d", &x), PR[x] = 1; int nf;scanf ("%d", &nf); for (int i = 1, x; i <= nf; i++) scanf ("%d", &x), NF[x] = (1 & PR[x]), m = (NF[x] == 1? max(x, m): m); stp.n = stp.m = F.n = m, F.m = 1; f[0] = 1; for (int i = 1; i < m; i++) { for (int j = 1; j <= i; j++) if(NF[j]) f[i] = (f[i] + f[i - j]) % mod; F[m - i][1] = f[i]; } F[m][1] = 1; for (int i = 1; i <= m; i++) stp[1][i] = NF[i], stp[i + 1][i] = 1; stp[m + 1][m] = 0; F = qpow(stp, n - m + 1ll) * F; printf("%lld", F[1][1]); return 0; } ```