题解:AT_abc279_g [ABC279G] At Most 2 Colors
Wy_x
·
·
题解
Solution:
对于本题,我们将染色视为填数。具体地,对于每个位置,染第 i 种色视为填数 i。
观察题目,我们设 dp[i][j] 表示前 i 个数中,末尾相同的数恰好有 j 个的合法方案数。
换言之,区间 [i-j+1,i] 中的数全部相同,而若 j \ne k-1,还需满足 a[i-j] \ne a[i]。
令 j \le \min(i,k-1)。若 j> k-1 ,视为 j=k-1。j 上界不设为 k 的原因是 dp[i][k] 不会参与后续转移,k-1 就够用了。j>k 同理。而且这样可以简化后续转移。
最终答案为 ans=\sum_{j=1}^{k-1} dp[n][j]。
设 a[i] 表示第 i 个数填 a[i]。
容易得到非边界情况的转移,就是在序列末尾增添一个与前 j-1 个数相同的数,方案数不变:
dp[i][j]=dp[i-1][j-1]
边界情况:
dp[i][k-1]=dp[i-1][k-2]+dp[i-1][k-1]
我们最大只记录到 k-1,如果原先末尾已经有 k-1 个相同的数,再加一个相同的数,末尾就有 k 个相同的数,但我们视为 k-1 个。
\begin{equation*}
\begin{aligned}
dp[i][1] &= \sum_{j=1}^{\min(i-1,k-2)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)] \times (c-1) \\
&= \sum_{j=1}^{\min(i-1,k-2)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)]+dp[i-1][\min(i,k-1)] \times (c-2) \\
&= \sum_{j=1}^{\min(i,k-1)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)] \times (c-2) \\
\end{aligned}
\end{equation*}
现在末尾只有一个相同的数,有两种情况:当 j < \min(i,k-1) 时,我们只能填 a[i-j];当 j = \min(i,k-1) 时,这个数可以填除 a[i] 的所有数,也就是共 c-1 种方案。
发现每次转移只会有 dp[i][1] 和 dp[i][k-1] 会发生实质性修改。由 dp[i][j]=dp[i-1][j-1],考虑网格图表示。
以 n=7 , k=6 为例。注意红边转移时需要乘 c-1。
可以转化为
这样我们按上图转移,每次至多只会有 2 个 dp 值会被修改,时间复杂度为 O(n),可以通过。
设 s=\sum_{j=1}^{\min(i,k-1)}{dp[i-1][j]}。我们需要快速求出 s。
由第二张图可以发现,每次 s 的变化量为 dp[i][1]。所以对于每个 i,我们只需要让 s = s+dp[i][1] 即可。
注意,若省略第一维,$dp$ 数组空间要开两倍。
综上,这道题就做完了。
## Code:
```cpp
#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read()
{
int x=0,c=getchar(),f=0;
for(;c>'9'||c<'0';f=c=='-',c=getchar());
for(;c>='0'&&c<='9';c=getchar())
x=(x<<1)+(x<<3)+(c^48);
return f?-x:x;
}
inline void write(int x)
{
if(x<0) x=-x,putchar('-');
if(x>9) write(x/10);
putchar(x%10+'0');
}
const int N=1<<20;
int n,k,c;
int dp[N<<1];
const int mod=998244353;
int ksm(int x,int p)
{
int ans=1;
for(int i=1;i<=p;i++) ans*=x,ans%=mod;
return ans;
}
signed main()
{
n=read();
k=read();
c=read();
if(k==2||c==2) // 特判,防止不必要的锅
{
cout<<ksm(c,n)<<"\n";
return 0;
}
// 左右边界,区间长度为 k-1
int l=n,r=n+k-2;
dp[l]=c;
int sum=c;
for(int i=2;i<=n;i++)
{
l--;
r--;
dp[l]=(dp[min(r+1,n)]*(c-2)+sum)%mod; // 更新 dp[i][1]
dp[r]=(dp[r]+dp[r+1])%mod; // 更新 dp[i][k-1]
sum+=dp[l]; // sum 的变化量为 dp[i][1]
sum%=mod;
}
int ans=0;
for(int i=l;i<=r;i++) ans+=dp[i];
cout<<ans%mod<<"\n";
return 0;
}
/*
1 1 1
2 2 2
3 3 3
1 2 2
1 3 3
2 1 1
2 3 3
3 1 1
3 2 2
1 1 2
1 1 3
1 2 1
1 3 1
2 1 2
2 2 1
2 2 3
2 3 2
3 1 3
3 2 3
3 3 1
3 3 2
*/
```