P14256 题解

· · 题解

验题人题解。

首先考虑如果确定了一个手势序列怎么求最大的平局次数。

考虑转化石头剪刀布问题,我们模拟一下相邻手势之间的胜负关系,发现可以设 > 表示左侧的手势会被右侧的击败< 反之;对于两个相同的手势,发现直接消去其中一个即可带来一次贡献,因此可以先将所有连续相同的手势消到只剩一个,这样就能将一个手势序列转化为一个只含 >< 的序列。

现在考虑处理这个新的序列。根据这个 pattern 手摸一下,容易发现对于相邻两个字符只有四种操作:消去一对 <> 可以转化为一次贡献(容易发现,我们仅有这一种贡献形式);对于一对 >< 可以消去其中任意一个(后面会证明该操作不优);两个 > 可以转化为一个 <,而两个 < 也能转化为一个 >

现在考虑如何利用这些性质解决这个最大括号匹配问题状物。下面给出处理方法及简单的证明:

  1. 首先如果出现了 <>,显然可以贪心地直接消去并获得一次贡献,因为即使不让它们匹配也无法通过它们造成大于一次的贡献。于是现在可以假定已经消去了所有的 <> 并计算了对应的贡献。

  2. 如果出现了 >>> 或者 <<<,显然根据上面的性质可以转化为 <> 并获得一次贡献,而现在要考虑的是如何计算这样的串的贡献;

  3. 根据刚刚的结论,我们知道如果相邻两个字符不同,一定是 >< 而非 <>(否则可以直接消去)。而由于不可能出现两次以上的 ><(否则也会出现 <> 而可以接着消去),所以处理后的串应当是形如 >>><<< 的串,即 u>v< 拼接在一起,且 u,v 非负。将当前串划分为这两个子串,发现如果要让它们共同匹配的话要让最右侧的两个 > 与最左侧的两个 < 分别转化后才能匹配,此时每四个字符才能产生一次更新,在 u\geq 3,v\geq 3 的时候是不优的:因为可以分别在两个子串中应用 2 中的处理方法,这样每三个字符就能产生一次贡献。当一直操作直到 u\leq 2,v\leq 2 时再尝试刚刚的操作至多一次(当且仅当 u=v=2 时可以再完成一次匹配),就能保证结果最优。

  4. 如何证明消去 >< 中的一个是不优的?注意到如果出现了相邻的 >< 且无法消去其中任意一个,那么 > 前面一定没有 <,且 < 后面一定没有 >;这样就形成了 3 中所说的串,发现删去其中一个是一定不优的。

现在考虑整合一下这些操作,形成一个比较好 dp 的 pattern。首先需要记录当前有多少个 < 待匹配,这样当出现 > 时可以直接匹配并产生贡献。如果当前串最左侧有一些 >,可以当长度为 3 时直接消去产生贡献(我们在 3 中已经证明了这么做的最优性),因此待消去的 > 至多有两个,可以在状态中记录它。

现在考虑将这个过程放在 dp 上面跑,形成一个类似于自动机的东西。首先要枚举当前位和上一位的手势从而获得一个 >< 串的新字符,因此要在 dp 状态中记录当前钦定的手势。接着考虑按照刚刚整合的操作,记录待匹配的 < 和最左侧的 >。对于转移,当出现 > 时能匹配就匹配,否则计入最左侧的 >;当出现 < 时直接加入待匹配的计数即可。

对于最终的 dp 状态,我们发现每个状态对应着一个消剩下的形如 3 中的串,按照上面所说的操作 O(1) 处理一下额外的贡献即可。

这样综合起来,可以设 f_{i,j,p,q} 为当前在第 i 位,有 j 个左括号待匹配,p\in\{0,1,2\} 代表当前位的手势为 pq\in\{0,1,2\} 代表有多少个右括号待匹配的方案数,同时设 g 为相同状态所对应的贡献之和。这样由于贡献系数总为 1,因此产生新的平局时直接让 g 加上对应的 f 即方案数,就能正确计算总贡献。

暴力枚举转移即可,复杂度 O(n^2)注意需要滚动数组优化空间。 cz 把空间开到 1G 了,可以不用滚动数组了。

const int N = 3e3 + 20;
const int mod = 1e9 + 7;
void add(int &x,int y){x += y;if(x>=mod) x -= mod;}
bool h[3][10];
int n;
int s[N];
int ans;
int f[2][N][3][3],g[2][N][3][3];
signed main(){
    h[0][1] = h[0][3] = h[0][5] = h[0][7] = 1;
    h[1][2] = h[1][3] = h[1][6] = h[1][7] = 1;
    h[2][4] = h[2][5] = h[2][6] = h[2][7] = 1;
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin>>n;
    for(int i=1;i<=n;++i){
        char c;cin>>c;
        s[i] = c - '0';
    }
    for(int i=0;i<=2;++i){
        if(h[i][s[1]]) f[1][0][i][0] = 1;
    }
    for(int i=2;i<=n;++i){
        memset(f[i&1],0,sizeof f[i&1]);
        memset(g[i&1],0,sizeof g[i&1]);
        for(int j=0;j<N;++j){ // last position count
            for(int u=0;u<=2;++u){
                for(int y=0;y<=2;++y){
                    for(int x=0;x<=2;++x){
                        if(!h[x][s[i]]) continue;
                        if(x==y){
                            add(f[i&1][j][x][u],f[1^i&1][j][y][u]);
                            add(g[i&1][j][x][u],g[1^i&1][j][y][u]);
                            add(g[i&1][j][x][u],f[1^i&1][j][y][u]);
                        }
                        else if((x+1)%3!=y){
                            if(j==0){
                                if(u<=1){
                                    add(f[i&1][j][x][u+1],f[1^i&1][j][y][u]);
                                    add(g[i&1][j][x][u+1],g[1^i&1][j][y][u]);                                   
                                }
                                else{
                                    add(f[i&1][j][x][0],f[1^i&1][j][y][2]);
                                    add(g[i&1][j][x][0],g[1^i&1][j][y][2]);
                                    add(g[i&1][j][x][0],f[1^i&1][j][y][2]);                                         
                                }
                            }
                            else{
                                add(f[i&1][j-1][x][u],f[1^i&1][j][y][u]);
                                add(g[i&1][j-1][x][u],g[1^i&1][j][y][u]);
                                add(g[i&1][j-1][x][u],f[1^i&1][j][y][u]);
                            }
                        }
                        else{
                            add(f[i&1][j+1][x][u],f[1^i&1][j][y][u]);
                            add(g[i&1][j+1][x][u],g[1^i&1][j][y][u]);
                        }
                    }
                }
            }
        }
    }
    int ans = 0;
    for(int i=0;i<N;++i){
        for(int j=0;j<=2;++j){
            for(int k=0;k<=2;++k){
                add(ans,g[n%2][i][j][k]);
                if(i>=3)add(ans,1ll*i/3*f[n%2][i][j][k]%mod);   
            }
            if(i%3==2) add(ans,f[n%2][i][j][2]);    
        }
    }
    cout<<ans;
    return 0;
}