P14160 [ICPC 2022 Nanjing R] 工厂重现

· · 题解

考虑 DP。设 f_{u, i} 表示 u 子树内有 i 个工厂时的答案,k 为工厂总数。加入一个边权为 d 的儿子 v 时,有转移:

f_{u, i + j} = \max(f_{u, i + j}, f_{u, i} + f_{v, j} + d \times j \times (k - j))

显然 f_{u, i} 关于 i 是上凸的(差分单调不增)。用平衡树维护 f_u 的差分数组即可,归并差分数组可以用启发式合并或平衡树有交合并。给 f_{v, j} 加上 d \times j \times (k - j) 的操作就是给差分数组加上一个等差数列,打 tag 维护即可。

时间复杂度 O(n \log^2 n)

:::info[代码]

// Problem: P14160 [ICPC 2022 Nanjing R] 工厂重现
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P14160
// Memory Limit: 512 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
// Shorthand macros used throughout this file.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
// Fills the whole object `a` with byte value `x`; relies on sizeof(a), so `a`
// must be an actual array/object rather than a pointer.
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
// Numeric type aliases.
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
// Note: despite the name, pii is a pair of two long longs.
typedef pair<ll, ll> pii;

// Capacity of every fixed-size array in this file (tree vertices and treap nodes).
const int maxn = 100100;

// n and m are the two values read first by solve(); n - 1 edges follow, and
// m is both the number of keys summed for the answer and a term in the tag
// applied to each child in dfs(u, fa).
ll n, m;
// Adjacency list: G[u] holds (neighbour, edge weight) pairs; solve() stores
// every edge in both directions.
vector<pii> G[maxn];

// A linear function i -> k * i + b, used as a lazy "add arithmetic
// progression" tag on treap subtrees. Default-constructs to the zero line.
struct line {
    ll k;  // slope: amount added per step of in-order rank
    ll b;  // intercept: amount added at rank 0

    line(ll _k = 0, ll _b = 0) {
        k = _k;
        b = _b;
    }
};

// tag[x] is the pending (not yet pushed down) line for treap node x.
line tag[maxn];

// Component-wise sum of two lines: adding the result is the same as adding
// both operands one after the other.
inline line operator + (const line &a, const line &b) {
    line sum = a;
    sum.k += b.k;
    sum.b += b.b;
    return sum;
}

// Treap node storage, indexed by node id (id 0 is treated as the empty tree):
//   nt    - number of nodes allocated so far (ids are handed out as ++nt)
//   ls/rs - left / right child ids
//   sz    - number of nodes in the subtree
//   p     - random priority assigned in newnode()
int nt, ls[maxn], rs[maxn], sz[maxn], p[maxn];
// val[x] is the key stored in node x.
ll val[maxn];
// Priority generator, seeded from the steady clock at program start.
mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());

// Allocates the next treap node, gives it key k, size 1 and a fresh random
// priority, and returns its id. Child links and the tag are left as they are
// in the (zero-initialised) global arrays.
inline int newnode(int k) {
    const int id = ++nt;
    val[id] = k;
    p[id] = rnd();
    sz[id] = 1;
    return id;
}

// Recomputes the subtree size of x from its two children (an absent child is
// id 0; this relies on sz[0] staying 0).
inline void pushup(int x) {
    const int left_size = sz[ls[x]];
    const int right_size = sz[rs[x]];
    sz[x] = 1 + left_size + right_size;
}

inline void pushtag(int x, line y) {
    if (!x) {
        return;
    }
    val[x] += y.k * sz[ls[x]] + y.b;
    tag[x] = tag[x] + y;
}

inline void pushdown(int x) {
    if (tag[x].k == 0 && tag[x].b == 0) {
        return;
    }
    pushtag(ls[x], tag[x]);
    pushtag(rs[x], line(tag[x].k, tag[x].b + (sz[ls[x]] + 1) * tag[x].k));
    tag[x] = line();
}

// Splits tree u by key: nodes with val >= k go to x, nodes with val < k go
// to y. An empty u yields x = y = 0.
// Assumes the in-order key sequence of u is non-increasing, which is the
// order this routine itself preserves.
void split(int u, ll k, int &x, int &y) {
    if (u == 0) {
        x = 0;
        y = 0;
        return;
    }
    pushdown(u);  // children must see pending tags before keys are compared
    if (val[u] < k) {
        // u and its right subtree belong to y; keep splitting the left side.
        y = u;
        split(ls[u], k, x, ls[u]);
    } else {
        // u and its left subtree belong to x; keep splitting the right side.
        x = u;
        split(rs[u], k, rs[u], y);
    }
    pushup(u);
}

// Concatenates two trees, with every node of x placed before every node of y
// in in-order order, and returns the new root. The node with the smaller
// priority value becomes the root. An empty argument returns the other one.
int merge(int x, int y) {
    if (x == 0) {
        return y;
    }
    if (y == 0) {
        return x;
    }
    pushdown(x);
    pushdown(y);
    if (p[x] >= p[y]) {
        // y stays on top; x is merged into y's left spine.
        ls[y] = merge(x, ls[y]);
        pushup(y);
        return y;
    }
    // x stays on top; y is merged into x's right spine.
    rs[x] = merge(rs[x], y);
    pushup(x);
    return x;
}

// Merges two trees whose key ranges may interleave and returns the new root.
// The node with the smaller priority value becomes the root; the other tree
// is split around the root's key and the two halves are merged recursively
// into the root's left and right subtrees.
int mrg(int x, int y) {
    if (x == 0) {
        return y;
    }
    if (y == 0) {
        return x;
    }
    pushdown(x);
    pushdown(y);
    int root = x;
    int other = y;
    if (p[x] > p[y]) {
        root = y;
        other = x;
    }
    int not_below = 0;  // part of `other` with keys >= val[root]
    int below = 0;      // part of `other` with keys <  val[root]
    split(other, val[root], not_below, below);
    ls[root] = mrg(ls[root], not_below);
    rs[root] = mrg(rs[root], below);
    pushup(root);
    return root;
}

// f[1..tot] receives the keys of a treap in in-order order; tot counts how
// many have been written so far.
ll f[maxn], tot;

// In-order traversal of treap u that appends every key to f, pushing pending
// tags down first so the stored keys are final.
void dfs(int u) {
    if (u != 0) {
        pushdown(u);
        dfs(ls[u]);
        tot += 1;
        f[tot] = val[u];
        dfs(rs[u]);
    }
}

int rt[maxn];

void dfs(int u, int fa) {
    rt[u] = newnode(0);
    for (pii p : G[u]) {
        ll v = p.fst, d = p.scd;
        if (v == fa) {
            continue;
        }
        dfs(v, u);
        pushtag(rt[v], line(-2 * d, (m - 1) * d));
        rt[u] = mrg(rt[u], rt[v]);
    }
}

void solve() {
    scanf("%lld%lld", &n, &m);
    for (int i = 1, u, v, d; i < n; ++i) {
        scanf("%d%d%d", &u, &v, &d);
        G[u].pb(v, d);
        G[v].pb(u, d);
    }
    dfs(1, -1);
    dfs(rt[1]);
    ll ans = 0;
    for (int i = 1; i <= m; ++i) {
        ans += f[i];
    }
    printf("%lld\n", ans);
}

// Entry point: runs solve() once. The input has a single test case; to handle
// several, read the count into T before the loop.
int main() {
    for (int T = 1; T > 0; --T) {
        solve();
    }
    return 0;
}

:::