题解:P14135 【MX-X22-T6】「TPOI-4F」Miserable EXperience

· · 题解

出题人做法。和现有的两篇思路不太一样。

考虑把所有 1 操作提到最前面,那么操作完要求每层都是一样的。因此一开始直接用 1 操作把每一层削成对应的最小值一定是不劣的。

对原树差分,那么子树 -1 变成单点 -1,层 -1 变为这一层 -1 下一层 +1。只有 2 能处理负数,那就先用 2 把负数一层一层往根推,到根都推不掉就无解了。

现在相当于一个序列问题:设树的最大深度为 d,深度为 i 的点个数为 c_i,权值均为 a_i,每次可以执行以下操作之一:

要求将序列所有数都变为 0 的最小代价。

容易发现每个位置都是独立的,只用对于每个位置 i 考虑将 a_i 变为 0 的代价。进一步地,对于 a_i 中的每一个单位 1,他们之间也是独立的,所以只用考虑将 i 上的 1 变成 0 的代价。每次要么将这个 1 直接消除,要么继续下推。容易设计 dp_i 表示该问题的答案,那么有转移:

dp_i=\min(c_i,dp_{i+1}+1)

即可做到单次询问 \mathcal{O}(n)

接下来要求所有子树的答案。发现我们只用维护深度相关的信息,具体地,对于每个深度,我们要维护每一层的:

考虑长剖,合并时信息难以维护,那就全部暴力重构即可。时间复杂度 \mathcal{O}(n)

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 1e6 + 10;

vector<int> g[MAXN]; int d[MAXN], s[MAXN], sz[MAXN]; ll a[MAXN];

void init(int u) {
    sz[u] = 1;
    for (int v : g[u]) { init(v), sz[u] += sz[v]; if (d[v] > d[s[u]]) s[u] = v; }
    d[u] = d[s[u]] + 1;
}

int tint[MAXN << 2], *ti = tint; ll tll[MAXN], *tl = tll;

int *c[MAXN], *mv[MAXN], *dp[MAXN], *p[MAXN];

ll *num[MAXN], ans[MAXN];

inline 
void alloc(int u) {
    c[u] = ti, ti += d[u], mv[u] = ti, ti += d[u], dp[u] = ti, ti += d[u], p[u] = ti, ti += d[u];
    num[u] = tl, tl += d[u];
}

void dfs(int u) {
    c[u][0] = dp[u][0] = 1, ans[u] = mv[u][0] = a[u];
    if (s[u]) {
        c[s[u]] = c[u] + 1, mv[s[u]] = mv[u] + 1, p[s[u]] = p[u] + 1;
        dp[s[u]] = dp[u] + 1, num[s[u]] = num[u] + 1, dfs(s[u]);
        if (ans[s[u]] == -1) ans[u] = -1;
        else {
            ans[u] = ans[s[u]] - (mv[u][1] + p[u][1]);
            a[s[u]] -= a[u], mv[u][1] = a[s[u]];
            if (mv[u][1] + p[u][1] < 0) {
                ans[u] -= mv[u][1] + p[u][1];
                p[u][0] += mv[u][1] + p[u][1], p[u][1] = -mv[u][1];
            } else ans[u] += mv[u][1] + p[u][1];
            if (mv[u][0] + p[u][0] < 0) ans[u] = -1;
            else ans[u] += mv[u][0] + p[u][0];
        }
    }
    for (int v : g[u]) {
        if (v == s[u]) continue; alloc(v), dfs(v);
        if (ans[u] == -1 || ans[v] == -1) { ans[u] = -1; continue; }
        a[v] -= a[u], mv[v][0] = a[v];
        for (int i = 1; i <= d[v]; i++) {
            ans[u] -= dp[u][i] * (mv[u][i] + p[u][i]) + num[u][i];
            num[u][i] += num[v][i - 1];
            if (mv[u][i] < mv[v][i - 1]) num[u][i] += (ll)c[v][i - 1] * (mv[v][i - 1] - mv[u][i]);
            else num[u][i] += (ll)c[u][i] * (mv[u][i] - mv[v][i - 1]), mv[u][i] = mv[v][i - 1];
            ans[u] += num[u][i], c[u][i] += c[v][i - 1];
        }
        ans[u] -= (ll)dp[u][0] * (mv[u][0] + p[u][0]);
        for (int i = d[v]; i; i--) {
            if (mv[u][i] + p[u][i] < 0) {
                ans[u] -= mv[u][i] + p[u][i];
                p[u][i - 1] += mv[u][i] + p[u][i], p[u][i] = -mv[u][i];
            }
        }
        if (mv[u][0] + p[u][0] < 0) { ans[u] = -1; continue; }
        for (int i = d[v]; i; i--) dp[u][i] = min((i + 1 < d[u] ? dp[u][i + 1] : 0) + 1, c[u][i]);
        for (int i = 0; i <= d[v]; i++) ans[u] += (ll)dp[u][i] * (mv[u][i] + p[u][i]);
    }
}

int n, fa[MAXN]; ll sum; char buf[1 << 24];

int main() {
    scanf("%d%s", &n, buf);
    for (int i = 2, j = 0; i <= n; i++) {
        for (; buf[i + j - 2] == '0'; j++);
        fa[i] = j, g[j].emplace_back(i);
    }
    scanf("%s", buf);
    for (int i = 1, j = 0; i <= n; i++) {
        a[i] |= buf[j++] - 33 << 24;
        a[i] |= buf[j++] - 33 << 18;
        a[i] |= buf[j++] - 33 << 12;
        a[i] |= buf[j++] - 33 << 6;
        a[i] |= buf[j++] - 33;
    }
    init(1), alloc(1), dfs(1);
    for (int i = 1; i <= n; i++) sum ^= ans[i] + 1;
    printf("%lld", sum);
}