题解:AT_abc279_g [ABC279G] At Most 2 Colors

· · 题解

Solution:

对于本题,我们将染色视为填数。具体地,对于每个位置,染第 i 种色视为填数 i

观察题目,我们设 dp[i][j] 表示前 i 个数中,末尾相同的数恰好j 个的合法方案数。

换言之,区间 [i-j+1,i] 中的数全部相同,而若 j \ne k-1,还需满足 a[i-j] \ne a[i]

j \le \min(i,k-1)。若 j> k-1 ,视为 j=k-1j 上界不设为 k 的原因是 dp[i][k] 不会参与后续转移,k-1 就够用了。j>k 同理。而且这样可以简化后续转移。

最终答案为 ans=\sum_{j=1}^{k-1} dp[n][j]

a[i] 表示第 i 个数填 a[i]

容易得到非边界情况的转移,就是在序列末尾增添一个与前 j-1 个数相同的数,方案数不变:

dp[i][j]=dp[i-1][j-1]

边界情况:

dp[i][k-1]=dp[i-1][k-2]+dp[i-1][k-1]

我们最大只记录到 k-1,如果原先末尾已经有 k-1 个相同的数,再加一个相同的数,末尾就有 k 个相同的数,但我们视为 k-1 个。

\begin{equation*} \begin{aligned} dp[i][1] &= \sum_{j=1}^{\min(i-1,k-2)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)] \times (c-1) \\ &= \sum_{j=1}^{\min(i-1,k-2)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)]+dp[i-1][\min(i,k-1)] \times (c-2) \\ &= \sum_{j=1}^{\min(i,k-1)}{dp[i-1][j]}+dp[i-1][\min(i,k-1)] \times (c-2) \\ \end{aligned} \end{equation*}

现在末尾只有一个相同的数,有两种情况:当 j < \min(i,k-1) 时,我们只能填 a[i-j];当 j = \min(i,k-1) 时,这个数可以填除 a[i] 的所有数,也就是共 c-1 种方案。

发现每次转移只会有 dp[i][1]dp[i][k-1] 会发生实质性修改。由 dp[i][j]=dp[i-1][j-1],考虑网格图表示。

n=7 , k=6 为例。注意红边转移时需要乘 c-1

可以转化为

这样我们按上图转移,每次至多只会有 2dp 值会被修改,时间复杂度为 O(n),可以通过。

s=\sum_{j=1}^{\min(i,k-1)}{dp[i-1][j]}。我们需要快速求出 s

由第二张图可以发现,每次 s 的变化量为 dp[i][1]。所以对于每个 i,我们只需要让 s = s+dp[i][1] 即可。

注意,若省略第一维,$dp$ 数组空间要开两倍。 综上,这道题就做完了。 ## Code: ```cpp #include<bits/stdc++.h> #define int long long using namespace std; inline int read() { int x=0,c=getchar(),f=0; for(;c>'9'||c<'0';f=c=='-',c=getchar()); for(;c>='0'&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48); return f?-x:x; } inline void write(int x) { if(x<0) x=-x,putchar('-'); if(x>9) write(x/10); putchar(x%10+'0'); } const int N=1<<20; int n,k,c; int dp[N<<1]; const int mod=998244353; int ksm(int x,int p) { int ans=1; for(int i=1;i<=p;i++) ans*=x,ans%=mod; return ans; } signed main() { n=read(); k=read(); c=read(); if(k==2||c==2) // 特判,防止不必要的锅 { cout<<ksm(c,n)<<"\n"; return 0; } // 左右边界,区间长度为 k-1 int l=n,r=n+k-2; dp[l]=c; int sum=c; for(int i=2;i<=n;i++) { l--; r--; dp[l]=(dp[min(r+1,n)]*(c-2)+sum)%mod; // 更新 dp[i][1] dp[r]=(dp[r]+dp[r+1])%mod; // 更新 dp[i][k-1] sum+=dp[l]; // sum 的变化量为 dp[i][1] sum%=mod; } int ans=0; for(int i=l;i<=r;i++) ans+=dp[i]; cout<<ans%mod<<"\n"; return 0; } /* 1 1 1 2 2 2 3 3 3 1 2 2 1 3 3 2 1 1 2 3 3 3 1 1 3 2 2 1 1 2 1 1 3 1 2 1 1 3 1 2 1 2 2 2 1 2 2 3 2 3 2 3 1 3 3 2 3 3 3 1 3 3 2 */ ```