题解 P8504 No will to break

· · 题解

传送门

分析:

Subtask 1:
##### Subtask 2: 注意到 $a$ 较小,我们可以存下上述贪心中的所有状态。具体地,开一个 $a$ 位三进制数表示当前窗口状态,每一位上,$0$ 表示不安全,$1$ 表示安全但还没用于聚集,$2$ 表示已经用于聚集。移动窗口相当于把这个三进制数左移一位,在末尾加一个 $0$ 或 $1$,再根据新的三进制数中 $2$ 的数量来将最右侧的若干个 $1$ 变为 $2$。记 $dp_{i,s}$ 表示已经处理完 $[i-a,i-1]$ 及之前的窗口,现在处理 $[i-a+1,i]$,并且三进制数为 $s$ 的方案数,按上文所述转移到 $dp_{i+1,s'}\times x_i$ 或 $dp_{i+1,s'}\times y_i$ 即可。 预处理每一个三进制数的转移,复杂度 $\mathcal O(3^an)$,可能需要滚动数组。时限很松,常数稍微小一点应该就可以 AC 了。 ##### Subtask 3: 注意到每个窗口对应三进制数中 $2$ 的数量应当总是为 $b$,我们只需要遍历三进制下 $2$ 的数量为 $b$ 的数即可。同时,转移与三进制下第一位是什么无关,只与后 $a-1$ 位中 $1,2$ 的位置、个数有关,所以只需要存储 $a-1$ 位。此时状态数为 $\binom{a-1}{b-1}\times2^{a-b}+\binom{a-1}b\times2^{a-1-b}$,最大约为 $3584$,是 $3^a=19683$ 的约五分之一,非常可过。 --- 代码中用四进制来存题解中所说的三进制数。 ```cpp #include <bits/stdc++.h> #define rep(i, l, r) for(int i=l, _=r; i<=_; ++i) using namespace std; typedef long long ll; const int mN=1.5e4+9, mA=9+2, mT=1<<16|9, mS=4000; const int mod=998244353; inline int qpow(int x, int y) { int res=1; for(; y; y>>=1, x=(ll) x*x%mod) if(y&1) res=(ll) res*x%mod; return res; } #define pm(a) ((a)<mod? a: (a)-mod) #define inc(a, b) a=pm(a+b) int n, a, b, x[mN][2], sumx=1; int dp[2][mS][2]; int trans[mS][2]; bool pl[mS][2]; int p[mA], mp[mT], tot; void dfs(int stp) { //预处理出每个三进制数。 if(stp>=a) { int cnt2=0, t=0; rep(i, 1, a-1) cnt2+=p[i]==2, t|=p[i]<<(i-1)*2; if(cnt2<b-1 || cnt2>b) return; //共有 b 个 2,故后 a-1 位有 b-1 或 b 个 2 mp[t]=++tot; if(cnt2==b) { //不需要将 1 改为 2 trans[tot][0]=t>>2; trans[tot][1]=t>>2|1<<(a-2)*2; } else { //需要将 1 改为 2 int k; for(k=a-1; k>=1; --k) if(p[k]==1) break; trans[tot][1]=t>>2|2<<(a-2)*2, pl[tot][1]=1; if(!k) return trans[tot][0]=mT-1, void(); //没有 1 且下一位为 0 t+=1<<(k-1)*2; trans[tot][0]=t>>2, pl[tot][0]=1; } return; } rep(t, 0, 2) p[stp]=t, dfs(stp+1); } int q[mA]; void pre(int stp) { //初始化 [1, a] 的答案 if(stp>a) { int lft=b, sum=1, t=0; for(int i=a; i>=1; --i) { sum=(ll) sum*x[i][q[i]=p[i]]%mod; if(q[i]==1 && lft) q[i]=2, --lft; } if(lft) return; for(int i=2; i<=a; ++i) t|=q[i]<<(i-2)*2; inc(dp[a&1][mp[t]][0], sum), dp[a&1][mp[t]][1]=(dp[a&1][mp[t]][1]+(ll) sum*b)%mod; return; } rep(t, 0, 1) p[stp]=t, pre(stp+1); } int main() { scanf("%d%d%d", &n, &a, &b), dfs(1); rep(i, 1, tot) rep(t, 0, 1) trans[i][t]=mp[trans[i][t]]; rep(i, 1, n) scanf("%d%d", x[i]+1, x[i]), sumx=(ll) sumx*(x[i][0]+x[i][1])%mod; pre(1); rep(i, a+1, n) { memset(dp[i&1], 0, sizeof dp[i&1]); rep(j, 1, tot) if(dp[i&1^1][j][0] || dp[i&1^1][j][1])rep(t, 0, 1) { const int k=trans[j][t]; dp[i&1][k][0]=(dp[i&1][k][0]+(ll) x[i][t]*dp[i&1^1][j][0])%mod; dp[i&1][k][1]=(dp[i&1][k][1]+(ll) (dp[i&1^1][j][1]+ pl[j][t]*dp[i&1^1][j][0])*x[i][t])%mod; } } int ans=0; rep(j, 1, tot) inc(ans, dp[n&1][j][1]); printf("%lld\n", (ll) ans*qpow(sumx, mod-2)%mod); return 0; } ```