Solution-CF1060F

· · 题解

好像是一个不太一样的 \mathcal{O}(n^4) 做法,只不过比较麻烦。

考虑对单点求答案,假设它是根。先对原问题做一步转化,发现问题有两个维度:删边的顺序、每次删边保留的点的编号。后者只需要保证如果两个点其中有一个是根,根一定保留就行。所以问题可以看做:计算所有删边顺序构成的排列的权值之和,排列的权值为 2^{-c}c 为删边的过程中两端有一个点是根的边的数量,下文将这些边称为“关键边”。

在原树上钦定一些边为关建边,设钦定 i 条边的权值之和为 g_i,则根据二项式反演,最后的答案为:

\sum_{i=0}^{n-1}2^{-i}\sum_{j\geq i}\binom{j}{i}(-1)^{j-i}g_j=\sum_{j=0}^{n-1} g_j(-1)^j\sum_{i\leq j}\binom{j}{i}(-1)^i 2^{-i}=\sum_{j=0}^{n-1}g_j\left(-\frac{1}{2}\right)^j

考虑怎么求出上述式子。先进行进一步的观察,如果已经确定钦定了哪些边,排列需要满足什么条件才能合法。对于一条关键边,它需要满足它被缩掉之前,它的所有祖先边都已经被缩掉。转述一下就是,一条边在排列中的位置需要比它子树中所有边在排列中的位置靠前。

据此进行 DP,对于每个子树,维护它子树内的边的排列,再进行归并。f_{u,i} 表示 u 子树内边构成的排列中,最靠前的关建边位置为 i 的权值总和。转移需要考虑两个问题:

单次时间复杂度 \mathcal{O}(n^3),总时间复杂度 \mathcal{O}(n^4)

关于浮点数的问题,注意到 50 以内最大的组合数也只在 10^{14} 级别,可以放心运算。

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
#define FR first
#define SE second
using namespace std;
inline int read() {
    int x = 0; bool op = 0;
    char c = getchar();
    while(!isdigit(c))op |= (c == '-'), c = getchar();
    while(isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return op ? -x : x;
}
const int N = 55;
db C[N][N];
void init() {
    C[0][0] = 1;
    for(int i = 1; i < N; i++) {
        C[i][0] = 1;
        for(int j = 1; j <= i; j++) {
            C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
        }
    }
    return ;
}
int n;
int sz[N];
db f[N][N], g[N][N];
vector<int> G[N];
void dfs(int u, int fa) {
    for(int i = 0; i <= sz[u]; i++)f[u][i] = 0;
    sz[u] = 0; f[u][0] = 1;
    for(int v : G[u]) {
        if(v == fa)continue;
        dfs(v, u);
        for(int i = 0; i <= sz[u]; i++) {
            g[u][i] = f[u][i]; f[u][i] = 0;
        }
        f[u][0] += g[u][0] * f[v][0]; // 都没有钦定
        for(int i = 1; i <= sz[v]; i++) {
            for(int j = 0; j <= sz[u]; j++) {
                int x = sz[v] - i, y = sz[u] - j;
                db w = g[u][0] * f[v][i];
                w /= C[sz[u] + sz[v]][sz[u]];
                w *= C[i - 1 + j][j] * C[x + y][x];
                f[u][i + j] += w;
            }
        } // v 有钦定
        for(int i = 1; i <= sz[u]; i++) {
            for(int j = 0; j <= sz[v]; j++) {
                int x = sz[u] - i, y = sz[v] - j;
                db w = g[u][i] * f[v][0];
                w /= C[sz[u] + sz[v]][sz[u]];
                w *= C[i - 1 + j][j] * C[x + y][x];
                f[u][i + j] += w;
                // printf("w:%d %d %d %Lf %Lf\n", v, i, j, w, f[u][i + j]);
            }
        } // u 有钦定
        for(int i = 1; i <= sz[u]; i++) {
            for(int j = 1; j <= sz[v]; j++) {
                for(int k = 0; k < j; k++) {
                    int x = sz[u] - i, y = sz[v] - k;
                    db w = g[u][i] * f[v][j];
                    w /= C[sz[u] + sz[v]][sz[u]];
                    w *= C[i - 1 + k][k] * C[x + y][x];
                    f[u][i + k] += w;
                }
            }
        } // 都有钦定,u 的在前面
        for(int i = 1; i <= sz[u]; i++) {
            for(int j = 1; j <= sz[v]; j++) {
                for(int k = 0; k < i; k++) {
                    int x = sz[u] - k, y = sz[v] - j;
                    db w = g[u][i] * f[v][j];
                    w /= C[sz[u] + sz[v]][sz[u]];
                    w *= C[j - 1 + k][k] * C[x + y][x];
                    f[u][j + k] += w;
                }
            }
        } // 都有钦定,v 的在前面
        sz[u] += sz[v];
    }
    if(fa) {
        for(int i = 0; i <= sz[u]; i++) {
            g[u][i] = f[u][i]; f[u][i] = 0;
        }
        f[u][0] += g[u][0]; // 不钦定,之前没有
        for(int i = 1; i <= sz[u] + 1; i++) {
            f[u][i] -= g[u][0] / (sz[u] + 1) / 2;
        } // 钦定,之前没有
        for(int i = 1; i <= sz[u]; i++) {
            f[u][i + 1] += g[u][i] * i / (sz[u] + 1); // 不钦定,之前有
            for(int j = 1; j <= i; j++) {
                f[u][j] -= g[u][i] / (sz[u] + 1) / 2; // 钦定,之前有
            }
        }
    }
    sz[u]++;
    return ;
}
int main() { 
    init();
    n = read();
    for(int i = 1; i < n; i++) {
        int u = read(), v = read();
        G[u].pb(v); G[v].pb(u);
    }
    for(int i = 1; i <= n; i++) {
        dfs(i, 0);
        db res = 0;
        for(int j = 0; j < n; j++)res += f[i][j];
        printf("%.10Lf\n", res);
    }
    return 0;
}