P5637 ckw的树

题解

树上随机游走模型的扩展。

设 E_k 为从 k 节点出发还需要走的期望步数(若 k 被标记,则 E_k=0)。若 k 没有被标记,则有如下定义式。其中 cnt_k 为与 k 距离不超过 2 的节点个数(包括 k 自身),并约定 k\in bro(k),即 k 也算作自己的兄弟:

E_k=\dfrac{1}{cnt_k}(E_{fa_k}+E_{gfa_k}+\sum_{t\in bro(k)}E_t+\sum_{nx\in son(k)}E_{nx}+\sum_{nx\in gson(k)}E_{nx})+1

不难发现,叶子节点的后两项是不存在的。所以后两项可以从儿子递推,可以考虑把 E_k 写成如下形式:

E_k=a_kE_{fa_k}+b_kE_{gfa_k}+c_k+d_k\sum_{t\in bro(k)}E_t

其中 d_k 是兄弟项的系数。注意子树中含 E_k 的项需要移到等式左边,所以 d_k 一般不等于 \dfrac{1}{cnt_k}(具体见下文)。

难点在于兄弟的贡献如何展开。假定已经算出了每个兄弟 t 的系数 a_t,b_t,c_t,d_t,把 k 的所有兄弟节点放在一起考虑。此时 bro 集合、fa_k、gfa_k 都是固定的(兄弟节点的父亲和祖父相同)。令 \sum_{t\in bro(k)}E_t=S,则

S=\sum_{t\in bro(k)}\left(a_tE_{fa_k}+b_tE_{gfa_k}+c_t+d_tS\right)

解得

S=\dfrac{E_{fa_k}\sum_t a_t+E_{gfa_k}\sum_t b_t+\sum_t c_t}{1-\sum_t d_t}

所以可以把 S 化进 a,b,c 中。

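接下来只需考虑从儿子递推,这是容易的。儿子 s 满足 E_s=a_sE_k+b_sE_{fa_k}+c_s(兄弟项已经化进系数),孙子 nx 满足 E_{nx}=a_{nx}E_s+b_{nx}E_k+c_{nx}。把它们代入定义式,展开后的结果如下(其中 S 为 k 的兄弟之和):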

\begin{aligned}
cnt_k\cdot E_k &=E_k\left(\sum_{s\in son(k)}a_s+\sum_{s\in son(k)}\sum_{nx\in son(s)}b_{nx}+\sum_{s\in son(k)}a_s\sum_{nx\in son(s)}a_{nx}\right)\\
&+E_{fa_k}\left(1+\sum_{s\in son(k)}b_s+\sum_{s\in son(k)}b_s\sum_{nx\in son(s)}a_{nx}\right)\\
&+E_{gfa_k}\\
&+S\\
&+\left(cnt_k+\sum_{s\in son(k)}c_s+\sum_{s\in son(k)}\sum_{nx\in son(s)}c_{nx}+\sum_{s\in son(k)}c_s\sum_{nx\in son(s)}a_{nx}\right)
\end{aligned}

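容易化简出系数。记上式右边 E_k 的系数为 T_k,则移项后 S 的系数为 d_k=\dfrac{1}{cnt_k-T_k}。a_k、b_k、c_k 分别等于 E_{fa_k}、E_{gfa_k} 和常数项的系数乘以 d_k;E_{gfa_k} 的系数为 1,所以 b_k=d_k。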

根节点的兄弟集合只有它自己,同样用上面的方法把 S 化进它的系数。由于根节点的 E_{fa}=E_{gfa}=0,可以直接推出 E_{root}=c_{root}。最后从根向下 dfs 一遍,按 E_k=a_kE_{fa_k}+b_kE_{gfa_k}+c_k 还原出每个节点的 E 即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <ctype.h>
#include <cmath>
#include <vector>

// Start marker for the rough static-memory estimate that main() prints to stderr
// (main subtracts its address from ED's). NOTE(review): this only approximates the
// size of the globals if the compiler lays them out in declaration order — not guaranteed.
char ST;

// Competitive-programming shorthand macros.
// Only `debug` is used in the visible code; ll, inf, gline, pii, mkp, fi and se are not referenced.
#define ll long long
#define inf 0x3f3f3f3f
//#define int long long
//#define inf 0x3f3f3f3f3f3f3f3f
// debug(...) prints to stderr, so it does not pollute the judged stdout output.
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define gline debug("now is #%d\n", __LINE__)
#define pii std::pair <int, int>
#define mkp std::make_pair
#define fi first
#define se second

// Reads the next (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of looping forever).
int read()
{
    int x = 0, f = 1;
    // Keep getchar()'s result in an int: a char cannot hold EOF distinctly, and isdigit()
    // has undefined behaviour for negative values other than EOF (e.g. bytes >= 0x80).
    int c = getchar();
    for(; c != EOF && !isdigit(c); c = getchar()) if(c == '-') f = -1;
    for(; c != EOF && isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return x * f;
}

// In-place max / min: x becomes the larger (resp. smaller) of x and y.
void ckmax(int &x, int y) { if (y > x) x = y; }
void ckmin(int &x, int y) { if (y < x) x = y; }

// Modulus for all arithmetic (alternative kept commented out).
#define mod 998244353
//#define mod 1000000007

// x <- (x + y) mod `mod`; assumes x is already reduced and 0 <= y <= mod.
void plus_(int &x, int y)
{
    x += y;
    if (x >= mod) x -= mod;
}

// x <- x * y mod `mod`, with the product formed in 64 bits.
void mul_(int &x, int y)
{
    x = static_cast<int>(1ll * x * y % mod);
}

// Binary exponentiation: returns a^b mod `mod` for b >= 0.
// With b = mod - 2 this yields the modular inverse of a (Fermat), as used throughout the file.
int ksm(int a, int b)
{
    int result = 1;
    while (b)
    {
        if (b & 1) mul_(result, a);
        mul_(a, a);
        b >>= 1;
    }
    return result;
}

// Array capacity (max vertex id plus slack); vertices are indexed by their input ids.
#define N 100010
int n, m;        // number of vertices, number of marked vertices
bool flag[N];    // flag[v] is set iff vertex v is marked

// Adjacency lists stored as linked lists in flat arrays ("forward star"):
// h[v] = index of the last edge added out of v, ne[i] = next edge in v's list,
// e[i] = target of edge i. h must be filled with -1 (empty list) before use.
int h[N];
int e[N << 1];
int ne[N << 1];
int idx = -1;

// Appends the directed edge x -> y to x's list.
void add_edge(int x, int y)
{
    ++idx;
    e[idx] = y;
    ne[idx] = h[x];
    h[x] = idx;
}

// Adds the undirected edge {x, y} as two directed edges.
void add(int x, int y)
{
    add_edge(x, y);
    add_edge(y, x);
}

// Coefficients filled by dfs2: E_v = a[v]*E_parent + b[v]*E_grandparent + c[v] + d[v]*(sum of E over v's siblings, v included).
int a[N], b[N], c[N], d[N];
// suma/sumb/sumc[v]: sums of a/b/c over v's children once their sibling terms have been eliminated.
int suma[N], sumb[N], sumc[N];

// du[v]: number of vertices within distance 2 of v, v itself included (dfs1, plus main's root fix-up).
// son[v]: number of children of v.
int du[N], son[N];

// First pass. For every vertex k compute son[k] (number of children) and du[k],
// the number of vertices at distance <= 2 from k, where k counts as one of its own
// siblings. The root has no parent to add that self count; main() adds it afterwards.
// fa / ff are k's parent and grandparent (0 means "does not exist").
void dfs1(int k, int fa, int ff)
{
    // Children and grandchildren: each child v contributes itself plus its own children.
    for(int i = h[k]; i != -1; i = ne[i])
    {
        const int v = e[i];
        if(v != fa)
        {
            dfs1(v, k, fa);
            ++son[k];
            du[k] += 1 + son[v];
        }
    }
    // Siblings: each child of k reaches all son[k] children of k (itself included).
    for(int i = h[k]; i != -1; i = ne[i])
        if(e[i] != fa) du[e[i]] += son[k];
    // Parent and grandparent, when present.
    if(fa) ++du[k];
    if(ff) ++du[k];
}

// Second pass (post-order). For every vertex k compute coefficients such that
//     E_k = a[k]*E_{fa_k} + b[k]*E_{gfa_k} + c[k] + d[k]*S_k,
// where S_k is the sum of E over k's siblings (k included). Marked vertices get
// all-zero coefficients (E = 0).
// While processing k, the sibling term of each child is eliminated (d[child] is reset to 0),
// and suma/sumb/sumc[k] receive the sums of the children's resulting a/b/c.
void dfs2(int k, int fa)
{
    // A, B, C, D accumulate the children's a, b, c, d; the children are all siblings of each other.
    int A, B, C, D;
    A = B = C = D = 0; 
    for(int i = h[k]; ~i; i = ne[i])
    {
        int nx = e[i];
        if(nx == fa) continue;
        dfs2(nx, k);
        plus_(A, a[nx]), plus_(B, b[nx]);
        plus_(C, c[nx]), plus_(D, d[nx]);
    }
    // Sibling sum of the children: S = A*E_k + B*E_fa + C + (sum d)*S, so
    // S = (A*E_k + B*E_fa + C) / (1 - sum d). D becomes 1 / (1 - sum d).
    D = (mod + 1 - D) % mod;
    D = ksm(D, mod - 2);
    for(int i = h[k]; ~i; i = ne[i])
    {
        int nx = e[i];
        if(nx == fa) continue;
        // Substitute that S into the child's equation: its d[nx]*S term is folded into a, b, c.
        plus_(a[nx], 1ll * A * D % mod * d[nx] % mod);
        plus_(b[nx], 1ll * B * D % mod * d[nx] % mod);
        plus_(c[nx], 1ll * C * D % mod * d[nx] % mod);
        d[nx] = 0;
        // Sums over k's children of their final coefficients; k's parent uses them for its grandchild terms.
        plus_(suma[k], a[nx]);
        plus_(sumb[k], b[nx]);
        plus_(sumc[k], c[nx]);
    }
    // Marked vertex: E_k = 0. suma/sumb/sumc[k] above are still filled in, because k's children are
    // grandchildren of k's parent and still contribute to the parent's equation.
    if(flag[k])
    {
        a[k] = b[k] = c[k] = d[k] = 0;
        return;
    }
    // Here S is NOT the sibling sum: it collects cnt_k minus every coefficient of E_k coming from the
    // children / grandchildren, i.e. the factor left on E_k after moving those terms to the left side.
    int S;
    // Direct neighbours: coefficient 1 on E_fa (a), on E_gfa (b) and on the sibling sum (d);
    // the "+1 step" times cnt_k = du[k] is the initial constant (c).
    a[k] = b[k] = d[k] = 1;
    S = c[k] = du[k];
    for(int i = h[k]; ~i; i = ne[i])
    {
        int nx = e[i];
        if(nx == fa) continue;
        // Child nx:          E_nx = a[nx]*E_k + b[nx]*E_fa + c[nx]
        // Grandchild g < nx: E_g  = a[g]*E_nx + b[g]*E_k + c[g]   (summed via suma/sumb/sumc[nx])
        plus_(a[k], b[nx]);                            // child term on E_fa
        plus_(a[k], 1ll * b[nx] * suma[nx] % mod);     // grandchild term on E_fa (through E_nx)
        plus_(S, mod - a[nx]);                         // child term on E_k
        plus_(S, mod - sumb[nx]);                      // grandchild term on E_k (direct)
        plus_(S, mod - 1ll * a[nx] * suma[nx] % mod);  // grandchild term on E_k (through E_nx)
        plus_(c[k], c[nx]);                            // child constant
        plus_(c[k], sumc[nx]);                         // grandchild constant (direct)
        plus_(c[k], 1ll * c[nx] * suma[nx] % mod);     // grandchild constant (through E_nx)
    }
    // Divide the whole equation by the remaining coefficient of E_k.
    S = ksm(S, mod - 2);
    mul_(a[k], S), mul_(b[k], S), mul_(c[k], S), mul_(d[k], S);
}

// E[v]: expected number of remaining steps from v (mod `mod`). Index 0 is never
// written, so E[0] == 0 stands in for a missing parent/grandparent
// (vertex ids are assumed to be 1-based).
int E[N];

// Third pass (pre-order): recover the answers top-down. After dfs2 (and main's
// root fix-up) removed every sibling term, E_k depends only on the already-known
// values of its parent and grandparent.
void dfs3(int k, int fa, int gfa)
{
    long long value = c[k];
    value += 1ll * a[k] * E[fa] % mod;
    value += 1ll * b[k] * E[gfa] % mod;
    E[k] = static_cast<int>(value % mod);
    for(int i = h[k]; ~i; i = ne[i])
        if(e[i] != fa) dfs3(e[i], k, fa);
}

// End marker for the rough static-memory estimate printed in main().
char ED;
// Reads n, m, the n-1 tree edges and the m marked vertices, runs the three passes,
// and prints E[1..n] (expected remaining steps, mod `mod`), one per line.
int main()
{
    // Debug output on stderr only.
    debug("1/2 = %d, 1/3 = %d\n", ksm(2, mod - 2), ksm(3, mod - 2));
    // NOTE(review): subtracting pointers to two unrelated objects is undefined behaviour; the
    // printed size is only meaningful if the compiler lays the globals out between ST and ED.
    debug("%.3f MB\n", abs(&ST - &ED) / 1024.0 / 1024);
    // Reset the edge counter to -1 and mark every adjacency list empty (-1).
    memset(h, idx = -1, sizeof(h));
    n = read(), m = read();
    for(int i = 1, x, y; i < n; i++) x = read(), y = read(), add(x, y);
    for(int i = 1; i <= m; i++) flag[read()] = 1;
    dfs1(1, 0, 0);
    // The root has no parent, so dfs1 never counted it as its own sibling; add it here.
    du[1]++;
    dfs2(1, 0);
    // The root's sibling set is {root}. Eliminate its sibling term the same way dfs2
    // does for every child: fold d[1] * E_root in using 1 / (1 - d[1]).
    int D = (mod + 1 - d[1]) % mod;
    D = ksm(D, mod - 2);
    plus_(a[1], 1ll * a[1] * D % mod * d[1] % mod);
    plus_(b[1], 1ll * b[1] * D % mod * d[1] % mod);
    plus_(c[1], 1ll * c[1] * D % mod * d[1] % mod);
    d[1] = 0;
    // Vertex 0 (E[0] == 0) plays the root's missing parent and grandparent.
    dfs3(1, 0, 0);
    for(int i = 1; i <= n; i++) printf("%d\n", E[i]);
    return 0;
}