题解:P13954 [ICPC 2023 Nanjing R] 红黑树
lovely_nst
·
·
题解
P13954 红黑树
暴力 DP
设 f_{u,i} 表示点 u 为根的子树中的叶子节点到 u 的路径上黑点个数均为 i,此时的最小修改数。
分两种情况考虑。
第一种,点 u 的颜色最终为红色,可得转移方程:
f_{u,i}=\left[s_u=1\right]+\sum_{v\in \text{son}(u)} f_{v,i}
第二种,点 u 的颜色最终为黑色,可得转移方程:
f_{u,i}=\left[s_u=0\right]+\sum_{v\in \text{son}(u)} f_{v,i-1}
两种情况取 \min 即可。
void dfs (int u)
{
if (g[u].size () == 0)
{
f[u][0] = a[u];
f[u][1] = !a[u];
return ;
}
f[u][0] = a[u];
for (int v : g[u]) dfs (v) , f[u][0] += f[v][0];
for (int j = 1;j <= n;j ++)
{
int l1 = 0 , l2 = 0;
for (int v : g[u]) l1 += f[v][j - 1] , l2 += f[v][j];
f[u][j] = min (l1 + !a[u] , l2 + a[u]);
}
}
可以使用 STL 优化以获取更多的分数。
优化
把 f_{u,i} 看成一个图像 f_u(x),可以发现图像是一个凸包(因为最底层的 f_u 可以看作一个只有两个点的凸包,而凸包与凸包按位加也会是一个凸包)。设 \sum_{v\in \text{son}(u)} f_{v,x} 的图像为 F(x), \sum_{v\in \text{son}(u)} f_{v,x-1} 的图像为 F'(x)。如下图所示:
例如这是一个 F(x) 的图像,那么将其向右移一个单位长度即可变成 F'(x) 的图像,如下:
设 (\sum_{v\in \text{son}(u)} f_{v,x})(x) 斜率为 0 的一段区间是 \left[L,R\right]。接下来分成两种情况。
s_u=0
此时 F'(x) 的图像会向上一个单位长度,如图:
发现当 x\le R 时 F(x) 会更小;否则,F'(x) 更小。而形成的新图像则是 F(x) 在 R 的位置插入了一条斜率为 1 的线段。
s_u=1
此时 F(x) 的图像会向上一个单位长度,如图:
发现当 x\le L 时 F(x) 会更小;否则,F'(x) 更小。而形成的新图像则是 F(x) 在 L 的位置添加了一条斜率为 -1 的线段。
那么就使用小根堆存储 f_u(x) 对应图像每一位的斜率,在最后插入新线段即可。
## 答案
$f_{u,0}$ 的值其实就是 $u$ 子树内的黑点数,用 $f_{u,0}$ 加上所有负数斜率即为最小值。
## AC Code
```cpp
#include <bits/stdc++.h>
#define int long long
#define fre(s) freopen(s".in","r",stdin);freopen(s".out","w",stdout);
using namespace std;
const int N = 1e5 + 5 , inf = 1e9 + 7;
int n , c[N] , f0[N];
bool a[N];
vector <int> g[N];
priority_queue <int , vector <int> , greater <int> > s[N] , q;
void dfs (int u)
{
f0[u] = a[u] , c[u] = 0;
for (int v : g[u])
{
dfs (v);
f0[u] += f0[v];
if (s[u].empty ()) swap (s[u] , s[v]) , c[u] = c[v];
else
{
while (!q.empty ()) q.pop ();
c[u] = 0;
while (!s[u].empty () && !s[v].empty ())
{
int tmp = s[u].top () + s[v].top ();
q.push (tmp);
if (tmp < 0) c[u] += tmp;
s[u].pop () , s[v].pop ();
}
s[u] = q;
}
}
if (a[u]) s[u].push (-1) , c[u] --;
else s[u].push (1);
}
signed main ()
{
ios::sync_with_stdio(0); cin.tie(0),cout.tie(0);
int T;
cin >> T;
while (T --)
{
cin >> n;
for (int i = 1;i <= n;i ++)
{
char op;
cin >> op;
a[i] = op - 48;
}
for (int i = 2;i <= n;i ++)
{
int p;
cin >> p;
g[p].push_back (i);
}
dfs (1);
for (int i = 1;i <= n;i ++)
{
cout << f0[i] + c[i] << ' ';
g[i].clear ();
while (!s[i].empty ()) s[i].pop ();
}
cout << '\n';
}
return 0;
}
```