题解 P4910 【帕秋莉的手环】
囧仙
2020-04-13 16:06:54
经典的矩阵快速幂优化 dp,作为例题还是值得一做的
- **题意**
对一个 $n$ 个点组成的的环,每个点染色黄或绿,要求没有连续的两个绿色点,求方案数,答案对 $10 ^ 9 + 7$ 取模
- **简单的dp**
环上计数一般都是枚举第一个点的情况,然后断开在序列上用 dp 计数
首先肯定是设 $dp_i$ 表示前 $i$ 个点满足要求的方案数,而第 $i$ 个点能染什么颜色与 $i - 1$ 点有关,所以第 $i$ 个点的颜色也要加入状态
所以,设 $dp_{i,0}$ 表示第 $i$ 个点染黄色,前面全部符合要求的方案数, $dp_{i,1}$ 表示第 $i$ 个点染绿色,前面全部符合要求的方案数
$$dp_{i,0} = dp_{i - 1,0} + dp_{i - 1,1}$$
$$dp_{i,1} = dp_{i - 1,0}$$
然后要枚举第一个点的颜色,如果是黄色的话,那么 $dp_{1,0} = 1$,答案可以加上 $dp_{n,0} + dp_{n,1}$,如果是绿色,那么 $dp_{1,1} = 1$,答案可以加上 $dp_{n,0}$
代码:
```cpp
#include <cstdio>
#define ll long long
#define mod 1000000007
int T,n;
ll f[1000005][2];//f[i][0]表示最后一个金色,前面全部满足的方案,f[i][1]表示最后一个绿色,前面全部满足的方案
ll ans;
void work(){
scanf("%d",&n);
f[1][0] = 1;
f[1][1] = 0;
for(int i = 2;i <= n;i++){
f[i][0] = (f[i - 1][1] + f[i - 1][0]) % mod;
f[i][1] = f[i - 1][0];
}
ans = f[n][0] + f[n][1];
f[1][0] = 0;
f[1][1] = 1;
for(int i = 2;i <= n;i++){
f[i][0] = (f[i - 1][1] + f[i - 1][0]) % mod;
f[i][1] = f[i - 1][0];
}
ans += f[n][0];
printf("%lld\n",ans % mod);
}
int main(){
scanf("%d",&T);
while(T--) work();
return 0;
}
```
- **矩阵快速幂优化dp**
可以发现,$dp_i$ 实际上只由 $dp_{i - 1}$ 得到
所以说,我们可以把每次的转移写成一个乘上一个固定的矩阵
构造初始矩阵:
$$st = \begin{bmatrix}f_{i,0}&f_{i,1}\\0&0\end{bmatrix}$$
根据 dp 的转移,每次的转移矩阵就是:
$$nxt = \begin{bmatrix}1&1\\1&0\end{bmatrix}$$
这样,答案矩阵就是 $st \times nxt ^ {n - 1}$
这样做有一个好处,就是矩阵是满足结合律的,所以可以用快速幂的方式求出 $nxt ^ {n - 1}$
这样,我们就在 $O(T\log n)$ 的时间复杂度内求出了答案
代码:
```cpp
#include <cstdio>
#define ll long long
#define mod 1000000007
struct M{
ll a,b,c,d;
/*
[a,b]
[c,d]
*/
}st,nxt;
M operator * (M a,M b){
M c;
c.a = (a.a * b.a + a.b * b.c) % mod;
c.b = (a.a * b.b + a.b * b.d) % mod;
c.c = (a.c * b.a + a.d * b.c) % mod;
c.d = (a.c * b.b + a.d * b.d) % mod;
return c;
}
M power(M n,ll k){
M ans = {1,0,0,1};
while(k){
if(k % 2 == 1){
ans = n * ans;
}
n = n * n;
k /= 2;
}
return ans;
}
int T;
ll n,ans;
void solve(){
scanf("%lld",&n);
st = {1,0,0,0};
nxt = {1,1,1,0};
M res = st * power(nxt,n - 1);
ans = res.a + res.b;
st = {0,1,0,0};
nxt = {1,1,1,0};
res = st * power(nxt,n - 1);
ans += res.a;
printf("%lld\n",ans % mod);
}
int main(){
scanf("%d",&T);
while(T--) solve();
return 0;
}
```