ABC391E题解

· · 题解

题目翻译

对于一个长度为 3^N 的二进制字符串 B,定义一种操作:将 B 分成每组 3 个字符,取每组的多数值作为新的字符,生成一个长度为 3^{N-1} 的新字符串。

现在给定一个长度为 3^N 的二进制字符串 A,重复上述操作 N 次,最终得到一个长度为 1 的字符串 A',要求找出最少需要修改 A 中的多少个字符(0110),才能使最终的 A' 值翻转。

思路

可以把原题看成一颗三叉树,每个结点的值可以由它的左、中、右三个孩子得到。对于样例 1,我们可以有下图:

类似于一个树形 dp,可以记录在每个结点改变这个结点的值所需要的最小修改次数。
记当前结点为 p,左、中、右儿子分别为 lp, mp, rp,可以分为两种情况:

  1. 只需要改其中一个儿子即可修改当前结点。及三个儿子中有一个的值与其他两个不同。
  2. 需要修改两个儿子。及三个儿子的值相同。

对于第一种情况,我们只需要求那两个和当前结点的值相同的儿子的答案的最小值。
对于第二种情况,我们要求三个儿子的答案分别两两相加的最小值。
转移代码如下:

f[p] = 0x3f3f3f3f;
//只需要改一个的情况
if((nowv == 0 && cnt1 == 1) || (nowv == 1 && cnt0 == 1))
{
  if(lpv == nowv) f[p] = min(f[p], f[lp]);
  if(mpv == nowv) f[p] = min(f[p], f[mp]);
  if(rpv == nowv) f[p] = min(f[p], f[rp]);
}
//需要改两个的情况
else
{
  f[p] = min(f[p], f[lp] + f[mp]);
  f[p] = min(f[p], f[lp] + f[rp]);
  f[p] = min(f[p], f[mp] + f[rp]);
}

其中 nowv 为当前结点的值,cnt0cnt1 分别是子节点中值为 01 的数量。

完整代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

char s[1594323+5];
int n;

int f[(1594323<<2)+5];
int cnt = 1;

/// @param p 当前结点编号
/// @param l 当前结点对应字符串的左端点
/// @param r 右端点
/// @return 未修改时当前结点的值
int dfs(int p, int l, int r)
{
    if(l == r)
    {
        f[p] = 1;
        return s[r] - '0';
    }
    int lp = ++cnt, mp = ++cnt, rp = ++cnt;
    int len = (r - l + 1);
    int lpv = dfs(lp, l,            l+len/3-1);     //左孩子的值
    int mpv = dfs(mp, l+len/3,      l+len/3*2-1);   //中孩子的值
    int rpv = dfs(rp, l+len/3*2,    r);             //右孩子的值
    int cnt0 = 0, cnt1 = 0;
    if(lpv == 0) cnt0++; else cnt1++;
    if(mpv == 0) cnt0++; else cnt1++;
    if(rpv == 0) cnt0++; else cnt1++;
    int nowv = cnt0 > cnt1 ? 0 : 1;                 //当前结点的值
    f[p] = 0x3f3f3f3f;
    //只需要改一个的情况
    if((nowv == 0 && cnt1 == 1) || (nowv == 1 && cnt0 == 1))
    {
        if(lpv == nowv) f[p] = min(f[p], f[lp]);
        if(mpv == nowv) f[p] = min(f[p], f[mp]);
        if(rpv == nowv) f[p] = min(f[p], f[rp]);
    }
    //需要改两个的情况
    else
    {
        f[p] = min(f[p], f[lp] + f[mp]);
        f[p] = min(f[p], f[lp] + f[rp]);
        f[p] = min(f[p], f[mp] + f[rp]);
    }
    return nowv;
}

int main()
{
    scanf("%d%s", &n, s+1);
    dfs(1, 1, strlen(s+1));
    printf("%d", f[1]);
    return 0;
}