【题解】P3176 [HAOI2015]数字串拆分
SUNCHAOYI
·
·
题解
先处理函数 f_i,有 f_i = \sum \limits _{j = i - m}^{i - 1} f_j,这个递推式显然可以通过矩阵乘法进行优化。设 F_i 表示通过递推函数 f_i 得到的矩阵,则有以下矩阵的递推(以 m = 5 为例):
F_i =
\begin{bmatrix}
f_i\\
f_{i - 1}\\
f_{i - 2}\\
f_{i - 3}\\
f_{i -4}
\end {bmatrix}
=
\begin{bmatrix}
1 & 1 & 1 & 1 & 1\\
1 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 0\\
0 & 0 & 1 & 0 & 0\\
0 & 0 & 0 & 1 & 0
\end {bmatrix}^k
\cdot
\begin{bmatrix}
f_{i - 1}\\
f_{i - 2}\\
f_{i - 3}\\
f_{i - 4}\\
f_{i - 5}
\end{bmatrix}
现在的复杂度为计算矩阵 $A_{i,j}$,不难发现 $A_{i,j} = F_{s_i^{10^{n - i}}} \times A_{i + 1,j}$,稍作变换得 $A_{i,j} = F_{({10^{n - i}})^{s_i}} \times A_{i + 1,j}$。这样只需要预处理出 $F_{10^i}$ 的矩阵即可。
两个小细节:
- 在写矩阵时用到了 `+` 与 `*` 的重载,这时候原来的优先级已经不复存在,因此需要通过括号来处理顺序。
- 矩阵的初始化要小心,需要避免未定义行为。
代码如下:
```cpp
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 505;
const int MOD = 998244353;
int m,n;ll ans;
char s[MAX];
struct Mat
{
ll a[6][6];
Mat() {memset(a,0,sizeof(a));}//全部 memset 肯定没问题
Mat operator * (const Mat &y)
{
Mat z;
for (int k = 1;k <= m;++k)
for (int i = 1;i <= m;++i)
for (int j = 1;j <= m;++j)
z.a[i][j] += a[i][k] * y.a[k][j] % MOD,z.a[i][j] %= MOD;
return z;
}
Mat operator + (const Mat &y)
{
Mat z;
for (int i = 1;i <= m;++i)
for (int j = 1;j <= m;++j) z.a[i][j] = a[i][j] + y.a[i][j],z.a[i][j] %= MOD;
return z;
}
Mat qpow (Mat x,ll y)
{
Mat res;
for (int i = 1;i <= m;++i) res.a[i][i] = 1;
while (y)
{
if (y & 1) res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
} p[MAX],f[MAX][MAX],g[MAX];
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
scanf ("%s%d",s + 1,&m);
n = strlen (s + 1);
for (int i = 1;i <= m;++i) p[0].a[1][i] = 1;
for (int i = 2;i <= m;++i) p[0].a[i][i - 1] = 1;
for (int i = 1;i <= n;++i) p[i] = p[i].qpow (p[i - 1],10);//预处理 10^i 的矩阵
for (int j = 1;j <= n;++j)//扩展 需要注意一下 i = j 的特殊情况
for (int i = j;i;--i)//f[i][j] = f[i + 1][j] * f(10^(j - i))^(s_i)
f[i][j] = (i == j) ? f[i][j].qpow (p[0],s[i] - '0') : f[i + 1][j] * f[i][j].qpow (p[j - i],s[i] - '0');
for (int i = 1;i <= m;++i) g[0].a[i][i] = 1;
for (int i = 1;i <= n;++i)//g[i] 表示前 i 位的所有情况的答案
for (int j = 0;j < i;++j) g[i] = g[i] + (g[j] * f[j + 1][i]);//注意优先级
printf ("%lld\n",g[n].a[1][1]);//最后的答案显然是左上角的那个值
return 0;
}
```