P5637 ckw的树
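树上随机游走模型的扩展:每一步从当前点 $u$ 等概率地跳到与 $u$ 距离不超过 $2$ 的某个点(包括 $u$ 自身),到达关键点即停止,求从每个点出发的期望步数(模 $998244353$)。
令 $f_u$ 为从 $u$ 出发的期望步数,$\mathrm{cnt}_u$ 为与 $u$ 距离不超过 $2$ 的点数(含 $u$ 自身),$p$、$g$ 分别为 $u$ 的父亲和祖父(不存在时视 $f_p=f_g=0$),$\mathrm{sib}(u)$ 为 $p$ 的儿子集合(含 $u$;对根取 $\{u\}$)。关键点 $f_u=0$,对非关键点有
$$\mathrm{cnt}_u f_u=\mathrm{cnt}_u+f_p+f_g+\sum_{v\in\mathrm{sib}(u)}f_v+\sum_{v\in\mathrm{son}(u)}f_v+\sum_{v\in\mathrm{son}(u)}\sum_{w\in\mathrm{son}(v)}f_w.$$
不难发现,叶子节点的后两项是不存在的。所以后两项可以从儿子递推,可以考虑把 $f_u$ 表示成
$$f_u=a_uf_p+b_uf_g+c_u+d_u\sum_{v\in\mathrm{sib}(u)}f_v$$
的形式。
难点在于兄弟的贡献如何展开。假定已经算出了 $u$ 的所有儿子 $v$ 的 $a_v,b_v,c_v,d_v$,记 $T_u=\sum_{v\in\mathrm{son}(u)}f_v$,$A=\sum_v a_v$,$B=\sum_v b_v$,$C=\sum_v c_v$,$D=\sum_v d_v$。对 $u$ 的所有儿子求和得 $T_u=Af_u+Bf_p+C+DT_u$,即
$$T_u=\frac{Af_u+Bf_p+C}{1-D}.$$
所以可以把 $T_u$ 代回每个儿子的式子,消去兄弟项:$a_v\gets a_v+\frac{Ad_v}{1-D}$,$b_v\gets b_v+\frac{Bd_v}{1-D}$,$c_v\gets c_v+\frac{Cd_v}{1-D}$,$d_v\gets 0$,此后 $f_v=a_vf_u+b_vf_p+c_v$。
接下来就只用考虑从儿子递推了,这是容易的。记 $sa_v=\sum_{w\in\mathrm{son}(v)}a_w$($sb_v,sc_v$ 同理),则 $\sum_{w\in\mathrm{son}(v)}f_w=sa_vf_v+sb_vf_u+sc_v$,展开后的结果:
$$\Big(\mathrm{cnt}_u-\sum_{v\in\mathrm{son}(u)}\big(a_v(1+sa_v)+sb_v\big)\Big)f_u=\Big(1+\sum_{v\in\mathrm{son}(u)}b_v(1+sa_v)\Big)f_p+f_g+\sum_{v\in\mathrm{sib}(u)}f_v+\mathrm{cnt}_u+\sum_{v\in\mathrm{son}(u)}\big(c_v(1+sa_v)+sc_v\big).$$
容易化简出系数:记左边括号内为 $S$,则 $a_u=\frac{1}{S}\big(1+\sum_v b_v(1+sa_v)\big)$,$b_u=d_u=\frac1S$,$c_u=\frac1S\big(\mathrm{cnt}_u+\sum_v(c_v(1+sa_v)+sc_v)\big)$。
由于根节点的兄弟只有它自己(因此 $\mathrm{cnt}_1$ 也要额外计入根自身),有 $f_1=c_1+d_1f_1$,即 $f_1=\frac{c_1}{1-d_1}$。之后自顶向下按 $f_u=a_uf_p+b_uf_g+c_u$ 即可求出所有答案。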
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <ctype.h>
#include <cmath>
#include <vector>
// Address marker; main() reports its distance to the ED marker (declared just before main)
// as a rough static-memory estimate.
char ST;
// Common contest shorthands (not all of them are used below).
#define ll long long
#define inf 0x3f3f3f3f
//#define int long long
//#define inf 0x3f3f3f3f3f3f3f3f
// debug(...) prints to stderr; gline prints the current source line number.
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define gline debug("now is #%d\n", __LINE__)
#define pii std::pair <int, int>
#define mkp std::make_pair
#define fi first
#define se second
// Read one (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; any '-' among them makes the result negative.
// Returns 0 if the input ends before a digit is found.
int read()
{
	int x = 0, f = 1;
	// Keep getchar()'s int result: storing it in a char turns EOF / high bytes into values for
	// which isdigit() is undefined, and an EOF that never compares equal would loop forever.
	int c = getchar();
	for(; c != EOF && !isdigit(c); c = getchar()) if(c == '-') f = -1;
	for(; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
	return x * f;
}
// Raise x to y when y is not smaller (x becomes max(x, y)).
void ckmax(int &x, int y) { if (y >= x) x = y; }
// Lower x to y when y is not larger (x becomes min(x, y)).
void ckmin(int &x, int y) { if (y <= x) x = y; }
#define mod 998244353
//#define mod 1000000007
// In-place modular addition; both operands are expected to already lie in [0, mod).
void plus_(int &x, int y)
{
	int sum = x + y;
	if (sum >= mod) sum -= mod;
	x = sum;
}
// In-place modular multiplication; the product is formed in 64 bits before reduction.
void mul_(int &x, int y) { x = (long long)x * y % mod; }
// Binary exponentiation: returns a^b modulo mod (b >= 0).
int ksm(int a, int b)
{
	int res = 1;
	while (b)
	{
		if (b & 1) mul_(res, a);
		mul_(a, a);
		b >>= 1;
	}
	return res;
}
#define N 100010
// n vertices, m marked vertices; flag[v] is set when v is marked.
int n, m;
bool flag[N];
// Adjacency lists stored as linked lists over an edge pool: h[v] is v's first edge (-1 = none),
// e[i] is the target of edge i, ne[i] the next edge out of the same vertex, idx the last used slot.
int h[N];
int e[N << 1];
int ne[N << 1];
int idx = -1;
// Push the directed edge x -> y onto the front of x's list.
void add_edge(int x, int y)
{
	const int id = ++idx;
	e[id] = y;
	ne[id] = h[x];
	h[x] = id;
}
// Insert the undirected edge {x, y} as two directed edges.
void add(int x, int y)
{
	add_edge(x, y);
	add_edge(y, x);
}
// Per-node coefficients filled by dfs2. Once d[k] has been eliminated,
// E[k] = a[k]*E[parent] + b[k]*E[grandparent] + c[k] (mod `mod`); before that, d[k] multiplies
// the sum of E over the parent's children (k's siblings, k included).
int a[N], b[N], c[N], d[N];
// suma/sumb/sumc[k]: sums of a/b/c over k's children (after their d has been eliminated), filled by dfs2.
int suma[N], sumb[N], sumc[N];
// du[k]: number of vertices within distance 2 of k, k itself included (dfs1, plus du[1]++ in main).
// son[k]: number of children of k (dfs1).
int du[N], son[N];
// dfs1(k, fa, ff): first pass over the subtree of k; fa is k's parent and ff its grandparent (0 = none).
// Sets son[k] and accumulates into du[] the size of each vertex's distance-2 neighbourhood:
//   du[k] += children + grandchildren + (parent exists) + (grandparent exists),
//   and every child of k additionally receives son[k] (its siblings, itself included).
// The root never receives such a sibling term, which main() compensates with du[1]++.
void dfs1(int k, int fa, int ff)
{
	// Pass 1: finish each child's subtree first so son[child] is final, then count
	// the child itself plus all of its children (k's grandchildren).
	for(int edge = h[k]; edge != -1; edge = ne[edge])
	{
		const int child = e[edge];
		if(child == fa) continue;
		dfs1(child, k, fa);
		++son[k];
		du[k] += 1 + son[child];
	}
	// Pass 2: son[k] is complete now, so each child can see all of k's children as siblings.
	for(int edge = h[k]; edge != -1; edge = ne[edge])
	{
		const int child = e[edge];
		if(child != fa) du[child] += son[k];
	}
	// Parent and grandparent, when present, are also within distance 2.
	if(fa != 0) ++du[k];
	if(ff != 0) ++du[k];
}
// dfs2(k, fa): bottom-up pass that expresses every vertex's expectation through its ancestors.
// On return:
//   * every child v of k satisfies E[v] = a[v]*E[k] + b[v]*E[fa] + c[v] (its d[v] is eliminated);
//   * suma/sumb/sumc[k] are the sums of a/b/c over k's children, so that
//       sum_{v child of k} E[v] = suma[k]*E[k] + sumb[k]*E[fa] + sumc[k];
//   * k itself is left as E[k] = a[k]*E[fa] + b[k]*E[gfa] + c[k] + d[k]*(sum of E over fa's children);
//     that d[k] term is eliminated later by k's parent (or by main() for the root).
// Marked vertices (flag[k]) get all-zero coefficients, i.e. E[k] = 0.
void dfs2(int k, int fa)
{
int A, B, C, D;
A = B = C = D = 0;
// Solve the children first and sum their (still uneliminated) coefficients.
for(int i = h[k]; ~i; i = ne[i])
{
int nx = e[i];
if(nx == fa) continue;
dfs2(nx, k);
plus_(A, a[nx]), plus_(B, b[nx]);
plus_(C, c[nx]), plus_(D, d[nx]);
}
// With T = sum of E over k's children: T = A*E[k] + B*E[fa] + C + D*T,
// so T = (A*E[k] + B*E[fa] + C) / (1 - D). D now becomes 1 / (1 - D).
D = (mod + 1 - D) % mod;
D = ksm(D, mod - 2);
// Substitute T back into each child's equation (removing its sibling-sum term),
// then accumulate the children's final coefficients into suma/sumb/sumc[k].
for(int i = h[k]; ~i; i = ne[i])
{
int nx = e[i];
if(nx == fa) continue;
plus_(a[nx], 1ll * A * D % mod * d[nx] % mod);
plus_(b[nx], 1ll * B * D % mod * d[nx] % mod);
plus_(c[nx], 1ll * C * D % mod * d[nx] % mod);
d[nx] = 0;
plus_(suma[k], a[nx]);
plus_(sumb[k], b[nx]);
plus_(sumc[k], c[nx]);
}
// A marked vertex is absorbing: E[k] = 0. Its children's sums above are still needed by k's parent.
if(flag[k])
{
a[k] = b[k] = c[k] = d[k] = 0;
return;
}
// Equation of k: du[k]*E[k] = du[k] + E[fa] + E[gfa] + (sum of E over fa's children)
//                              + sum_v (E[v] + sum_{w child of v} E[w])   over children v of k.
// Substituting E[v] and the grandchild sums (via suma/sumb/sumc[v]) yields
//   S*E[k] = (1 + sum_v b[v]*(1 + suma[v]))*E[fa] + 1*E[gfa] + 1*(siblings' sum)
//            + du[k] + sum_v (c[v]*(1 + suma[v]) + sumc[v]),
// where S = du[k] - sum_v (a[v]*(1 + suma[v]) + sumb[v]). Dividing by S gives a, b, c, d of k.
int S;
a[k] = b[k] = d[k] = 1;
S = c[k] = du[k];
for(int i = h[k]; ~i; i = ne[i])
{
int nx = e[i];
if(nx == fa) continue;
plus_(a[k], b[nx]);
plus_(a[k], 1ll * b[nx] * suma[nx] % mod);
plus_(S, mod - a[nx]);
plus_(S, mod - sumb[nx]);
plus_(S, mod - 1ll * a[nx] * suma[nx] % mod);
plus_(c[k], c[nx]);
plus_(c[k], sumc[nx]);
plus_(c[k], 1ll * c[nx] * suma[nx] % mod);
}
// Divide every coefficient by S (multiply by its modular inverse).
S = ksm(S, mod - 2);
mul_(a[k], S), mul_(b[k], S), mul_(c[k], S), mul_(d[k], S);
}
// E[v]: final answer for vertex v (mod `mod`). E[0] is never written and stays 0,
// so a missing parent/grandparent contributes nothing.
int E[N];
// dfs3(k, fa, gfa): top-down pass evaluating E[k] = c[k] + a[k]*E[fa] + b[k]*E[gfa]
// from the coefficients left by dfs2 (and main()'s root fix-up), then descending into the children.
void dfs3(int k, int fa, int gfa)
{
	int value = c[k];
	plus_(value, 1ll * a[k] * E[fa] % mod);
	plus_(value, 1ll * b[k] * E[gfa] % mod);
	E[k] = value;
	for(int edge = h[k]; ~edge; edge = ne[edge])
		if(e[edge] != fa) dfs3(e[edge], k, fa);
}
// Address marker paired with ST (declared near the top of the file).
char ED;
// Reads n, m, the n - 1 tree edges and the m marked vertices, then prints E[i] (mod `mod`)
// for every vertex i, one per line.
int main()
{
	// Debug output to stderr: two modular inverses as a sanity check, plus the distance between
	// the ST and ED marker globals as a rough static-memory estimate (only meaningful if the
	// linker places the other globals between the markers).
	debug("1/2 = %d, 1/3 = %d\n", ksm(2, mod - 2), ksm(3, mod - 2));
	{
		// Subtracting pointers to two distinct objects is undefined behaviour, so compare the
		// addresses as integers instead of computing &ST - &ED.
		const unsigned long long st_addr = reinterpret_cast<unsigned long long>(&ST);
		const unsigned long long ed_addr = reinterpret_cast<unsigned long long>(&ED);
		const unsigned long long span = st_addr > ed_addr ? st_addr - ed_addr : ed_addr - st_addr;
		debug("%.3f MB\n", span / 1024.0 / 1024);
	}
	// Start with empty adjacency lists: every head is -1 (all bytes 0xFF) and the edge counter restarts.
	idx = -1;
	memset(h, -1, sizeof(h));
	n = read(), m = read();
	for(int i = 1, x, y; i < n; i++) x = read(), y = read(), add(x, y);
	for(int i = 1; i <= m; i++) flag[read()] = 1;
	dfs1(1, 0, 0);
	// The root has no parent, so dfs1 never adds the sibling term that would count the root itself.
	du[1]++;
	dfs2(1, 0);
	// The root's only "sibling" is itself: E[1] = a*E[0] + b*E[0] + c + d*E[1],
	// so scale its coefficients by 1 / (1 - d[1]) and drop d[1].
	int D = (mod + 1 - d[1]) % mod;
	D = ksm(D, mod - 2);
	plus_(a[1], 1ll * a[1] * D % mod * d[1] % mod);
	plus_(b[1], 1ll * b[1] * D % mod * d[1] % mod);
	plus_(c[1], 1ll * c[1] * D % mod * d[1] % mod);
	d[1] = 0;
	dfs3(1, 0, 0);
	for(int i = 1; i <= n; i++) printf("%d\n", E[i]);
	return 0;
}