题解:P14256 平局(draw)

· · 题解

依旧验题人题解,这个题我做了将近 8h,虽然不知道有效思考时间是多少。

首先对于一个固定的局面,有简单的 O(n^3) 区间 dp 做法,但是恐怕只能拿到 8 分。并且这不仅是一个区间 dp,还是一个带 \max 的东西,不太能数。

记石头布剪刀分别为 0,1,2,在模 3 意义下讨论,则 i 能击败 i-1

先想办法以能转成数数的方法来求固定序列的答案。

考虑转成一个从左往右扫的 dp。考虑记 f_{i,S} 为考虑了前 i 个数完成了所有能完成的合并后剩下的局面为 S 的最大值,观察 S 的性质。显然如果存在形如 i,i,则一定会消成 i

一个可能没那么显然的结论是如果存在 i,i-1,i,则一定会消成 i。理由是如果右边第 i-1 要和中间的 i-1 合并,则靠右的 i 要被删去,一定会出现一个形如 i,i-1 的局面,并产生了一次合并。但是显然可以先把这三个东西缩成 i,右边再施以同样的操作合成 i-1,这样可以和上面达到一样的效果,还能给右边留出更大的操作空间。而这个操作显然对 ii+1 的合并次数都是不劣的。

所以 S 必然是一个先递增后递减的单峰序列,可以用两边的长度(均不包括峰顶) x,y 已经开头元素来刻画它。把整个序列扫完后的答案显然为 \lfloor\dfrac{x}3\rfloor+\lfloor\dfrac{y}3\rfloor+[x\equiv 2\pmod 3\land y\equiv 2\pmod 3]

注意到上面这个做法没有任何关于取 \max 的部分,所以可以直接设 f_{i,x,y,0/1/2} 表示考虑前 i 个数,两段长为 x,y,序列开头为 0/1/2 的方案数。如果某个时刻产生了一个 +1,直接把它乘上后 n-i 个数的方案数然后加到答案上就好了。转移就根据上面的做法分类即可。复杂度为 O(n^3)

注意到上升段长度不降,那么只要记录 x\bmod 3 即可,\dfrac{x}3 的贡献可以和 +1 以同样方式计算,复杂度变为 O(n^2)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=1e9+7;
int n,a[3005],f[3][4][3005],g[3][4][3005],suf[3005];
string s;
int main(){
    ios::sync_with_stdio(0),cin.tie(0);
    cin>>n>>s;
    for(int i=1;i<=n;i++) a[i]=s[i-1]-'0',suf[i]=__builtin_popcount(a[i]);
    suf[n+1]=1;
    for(int i=n;i;i--) suf[i]=1ll*suf[i]*suf[i+1]%mod;
    int res=0;
    f[0][0][0]=1;
    for(int i=1;i<=n;i++){
//      cout<<i-1<<'\n';
        memset(g,0,sizeof(g));
        for(int s=0;s<3;s++)
            for(int t=!!s;t<4;t++)
                for(int x=0;x<i&&(t||!x);x++){
                    int tmp=f[s][t][x],e2=((s+t-x)%3+2)%3;
//                  if(tmp) cout<<s<<' '<<t<<' '<<x<<' '<<e2<<'\n';
                    for(int o=0;o<3;o++)
                        if(a[i]>>o&1){
                            if(!t) g[o][1][0]=(g[o][1][0]+tmp)%mod;
                            else if((e2+2)%3==o) g[s][t][x+1]=(g[s][t][x+1]+tmp)%mod;
                            else if(e2==o) res=(res+1ll*tmp*suf[i+1])%mod,g[s][t][x]=(g[s][t][x]+tmp)%mod;
                            else if(x) res=(res+1ll*tmp*suf[i+1])%mod,g[s][t][x-1]=(g[s][t][x-1]+tmp)%mod;
                            else{
                                int nt=t+1;
                                if(nt>3) res=(res+1ll*tmp*suf[i+1])%mod,nt-=3;
                                g[s][nt][0]=(g[s][nt][0]+tmp)%mod;  
                            }
                        }
                }
        memcpy(f,g,sizeof(f));
    }
//  cout<<n<<'\n';
    for(int i=0;i<3;i++)
        for(int j=1;j<4;j++)
            for(int k=0;k<n;k++){
                int x=j-1,y=k;
//              if(f[i][j][k]) cout<<i<<' '<<j<<' '<<k<<'\n';
                res=(res+1ll*f[i][j][k]*(y/3+(x%3==2&&y%3==2)))%mod;
            }
    cout<<res<<'\n';
    return 0;
}
/*
7
7777777
*/