[USACO24JAN] Island Vacation P 题解

· · 题解

题解

建出原图的广义圆方树,令所有方点的编号为 N+1,N+2,\dots。钦定圆点 1 为根,且钦定每个方点的儿子是按照环上的顺序排列的。设圆方树上一个方点的权值 \mathrm{val}_u 为其儿子个数对 2\min 的值。

原图上的随机游走等价于,设当前在圆方树上的点 u

也就是说,对于一个圆点,我们可能走到它的一些儿子,然后绕一圈回来,最后可能走到它的兄弟,或是停在它上面。

开始 DP:设 g_u 表示到达圆点 u 后走向其左右兄弟的概率(这里只有可能走向左右兄弟中的一个,但两种情况的概率是相等的),h_u 表示走到方点 u 后从底下绕一圈回到 u 的概率。

对于每个点 ug_u 可以用一次背包计算,h_u 就是直接把所有儿子的 g_u 乘起来。

算出 g,h 后,设 f_{u,0/1} 表示:

由于已经算出了 g,h,所以在兄弟之间的转移是容易的;对于从圆点 u 走到 u 的某个儿子的情况,在 u 处做一次可撤销背包即可。

具体地,从 u 到儿子 v 的转移是:设

F(x)=\prod_{v'\in \mathrm{son}_u\land v'\neq v} (1+h_{v'}x),

那么

f_{v,0}=\sum_{i\ge 0} [x^i]F(x)\cdot i!\cdot (1-p_u)^{i+1}\cdot \mathrm{val}_v\cdot (s-2i)^{-1}\cdot \frac{2^i s!!}{(s-2i)!!}.

s 的定义在上面有。)

而停在 u 的概率就是 1 减去走向兄弟的概率,再减去走向某个儿子且不回来的概率,在算完 f 以后这些都是好算的。

时间复杂度 O(N^2)

代码

#include <bits/stdc++.h>
using namespace std;
#define For(Ti, Ta, Tb) for (auto Ti = (Ta); Ti <= (Tb); ++Ti)
#define Dec(Ti, Ta, Tb) for (auto Ti = (Ta); Ti >= (Tb); --Ti)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define range(Tx) begin(Tx), end(Tx)
using ll = long long;
using mint = modint1000000007;
const int N = 1e4 + 5;
int T, n, m, dfn[N], low[N], dfx, stk[N], top, tot;
mint p[N];
vector<int> gr[N], e[N * 2];
void link(int u, int v) {
    e[u].push_back(v);
    e[v].push_back(u);
}
void tarjan(int u) {
    low[u] = dfn[u] = ++dfx;
    stk[++top] = u;
    for (int v : gr[u]) {
        if (!dfn[v]) {
            tarjan(v);
            low[u] = min(low[u], low[v]);
            if (low[v] == dfn[u]) {
                link(++tot, u);
                for (int x = 0; x != v;) {
                    link(tot, x = stk[top--]);
                }
            }
        } else {
            low[u] = min(low[u], dfn[v]);
        }
    }
}
int fa[N * 2], val[N * 2], pre[N], nxt[N], sum[N];
mint f[N * 2][2], g[N], h[N * 2], ans[N];
void get_dp(int u, int &tot, mint *dp) {
    fill(dp, dp + n + 1, 0);
    dp[0] = 1;
    for (int v : e[u]) {
        if (v != fa[u] && val[v] > 1) {
            Dec(i, tot, 0) { dp[i + 1] += dp[i] * h[v]; }
            ++tot;
        }
    }
}
void dfs(int u) {
    val[u] = (u <= n ? val[fa[u]] : min(2, int(e[u].size()) - 1));
    sum[u] = (val[u] > 1);
    for (int v : e[u]) {
        if (v != fa[u]) {
            fa[v] = u;
            dfs(v);
            sum[u] += val[v];
        }
    }
    if (u > n) {
        int x = u;
        for (int v : e[u]) {
            if (v != fa[u]) {
                pre[v] = x;
                x = v;
            }
        }
        reverse(range(e[u]));
        x = u;
        for (int v : e[u]) {
            if (v != fa[u]) {
                nxt[v] = x;
                x = v;
            }
        }
        reverse(range(e[u]));
    }
    if (val[u] <= 1)
        return;
    if (u <= n) {
        static mint dp[N];
        int tot = 0;
        get_dp(u, tot, dp);
        mint pw = 1;
        For(i, 0, tot) {
            pw *= 1 - p[u];
            g[u] += dp[i] * C.fac(i) * pw * C.inv(sum[u] - i * 2);
            pw *= 2 * C.inv(sum[u] - i * 2);
        }
    } else {
        h[u] = 1;
        for (int v : e[u]) {
            if (v != fa[u]) {
                h[u] *= g[v];
            }
        }
    }
}
void dfs2(int u) {
    if (u <= n) {
        static mint dp[N];
        int tot = 0;
        get_dp(u, tot, dp);
        mint all = 1 - g[u];
        for (int v : e[u]) {
            if (v == fa[u])
                continue;
            if (val[v] > 1) {
                For(i, 1, tot) { dp[i] -= dp[i - 1] * h[v]; }
                --tot;
            }
            mint pr = 0, pw = 1;
            For(i, 0, tot) {
                pw *= 1 - p[u];
                pr += dp[i] * C.fac(i) * pw * val[v] * C.inv(sum[u] - i * 2);
                pw *= 2 * C.inv(sum[u] - i * 2);
            }
            all -= pr * (1 - h[v]);
            f[v][0] += pr * (f[u][0] + f[u][1]);
            if (val[v] > 1) {
                ++tot;
                Dec(i, tot, 1) { dp[i] += dp[i - 1] * h[v]; }
            }
        }
        ans[u] = (f[u][0] + f[u][1]) * all;
    } else {
        mint cur = f[u][0] / 2;
        for (int v : e[u]) {
            if (v != fa[u]) {
                f[v][0] = cur;
                cur *= g[v];
            }
        }
        reverse(range(e[u]));
        cur = f[u][0] / 2;
        for (int v : e[u]) {
            if (v != fa[u]) {
                f[v][1] = cur;
                cur *= g[v];
            }
        }
        reverse(range(e[u]));
    }
    for (int v : e[u]) {
        if (v != fa[u]) {
            dfs2(v);
        }
    }
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> T;
    while (T--) {
        cin >> n >> m;
        For(i, 1, n) { cin >> p[i]; }
        For(i, 1, m) {
            int u, v;
            cin >> u >> v;
            gr[u].push_back(v);
            gr[v].push_back(u);
        }
        tot = n;
        tarjan(1);
        f[1][0] = 1;
        dfs(1);
        dfs2(1);
        For(i, 1, n) { cout << ans[i] << " \n"[i == n]; }
        For(i, 1, n) {
            dfn[i] = low[i] = 0;
            g[i] = ans[i] = 0;
            gr[i].clear();
        }
        For(i, 1, tot) {
            f[i][0] = f[i][1] = h[i] = fa[i] = 0;
            e[i].clear();
        }
        dfx = top = 0;
    }
    return 0;
}

有板子的:https://loj.ac/s/2014541。