题解:P1411 树

· · 题解

一个简单的树形 dp。

Solution

考虑设 f_{i, j} 表示 i 的子树内,包含 i 的连通块大小为 j 的,除此连通块的其他连通块的大小的乘积的最大值。

转移就是一个树上背包,没啥好说的。

注意的是答案的大小,考虑答案最大应该是一条链的时候,每 3 个一段,那么最大也就是 < 3^{234}。直接写压位高精度即可。我写的是每 6 位压一位,所以最后常数也就 20

:::success[AC Code]

#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define mp(Tx, Ty) make_pair(Tx, Ty)
#define For(Ti, Ta, Tb) for(auto Ti = (Ta); Ti <= (Tb); Ti++)
#define Dec(Ti, Ta, Tb) for(auto Ti = (Ta); Ti >= (Tb); Ti--)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define range(Tx) begin(Tx),end(Tx)
const int N = 705, K = 20;
const unsigned long long B = 1e6;
int n;
int h[N], e[N * 2], ne[N * 2], idx;
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
struct node {
    unsigned long long a[K + 5];
    node operator * (const node &t)const {
        node ans;
        memset(ans.a, 0, sizeof(ans.a));
        For(i, 1, K) if (a[i]) For(j, 1, K) if (t.a[j]) ans.a[i + j - 1] += a[i] * t.a[j];
        unsigned long long now = 0;
        For(i, 1, K) {
            unsigned long long noww = ans.a[i] + now;
            ans.a[i] = noww % B;
            now = noww / B;
        }
        assert(!now);
        return ans;
    }
} f[N][N], g[N];
node max(node a, node b) {
    Dec(i, K, 1) {
        if (a.a[i] > b.a[i]) return a;
        if (a.a[i] < b.a[i]) return b;
    }
    return a;
}
int sz[N];
void dfs(int x, int fa) {
    sz[x] = 1;
    f[x][1].a[1] = 1;
    for (int i = h[x]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        dfs(j, x);
        For(i, 0, sz[x] + sz[j]) memset(g[i].a, 0, sizeof(g[i].a));
        For(k, 0, sz[j]) {
            node b;
            memset(b.a, 0, sizeof(b));
            b.a[1] = k;
            For(i, 0, sz[x]) {
                g[i + k] = max(g[i + k], f[x][i] * f[j][k]);
                g[i] = max(g[i], f[x][i] * b * f[j][k]);
            }
        }
        sz[x] += sz[j];
        For(i, 0, sz[x]) f[x][i] = g[i];
    }
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    memset(h, -1, sizeof(h));
    For(i, 0, N - 1) For(j, 0, N - 1) memset(f[i][j].a, 0, sizeof(f[i][j].a));
    cin >> n;
    For(i, 1, n - 1) {
        int a, b;
        cin >> a >> b;
        add(a, b), add(b, a);
    }
    dfs(1, 0);
    node ans;
    memset(ans.a, 0, sizeof(ans.a));
    For(i, 0, n) {
        node b;
        memset(b.a, 0, sizeof(b.a));
        b.a[1] = i;
        ans = max(ans, f[1][i] * b);
    }
    bool f = 1;
    Dec(i, K, 1) {
        if (f && ans.a[i] == 0) continue;
        if (!f && ans.a[i] < 10) cout << "00000";
        else if (!f && ans.a[i] < 100) cout << "0000";
        else if (!f && ans.a[i] < 1000) cout << "000";
        else if (!f && ans.a[i] < 10000) cout << "00";
        else if (!f && ans.a[i] < 100000) cout << "0";
        f = 0;
        cout << ans.a[i];
    }
    if (f) cout << 0;
    return 0;
}

:::