题解:P10808 [LMXOI Round 2] Annihilation
T0mle
·
·
题解
真是紫题吗 /yun,写吐了。
我们记 $S(u,d)$ 表示 $u$ 子树内编号为 $d$ 的倍数的点集。
化式子
$$
\begin{aligned}
ans_u&=\sum_{d=1}^n b_d\sum_{\varnothing\ne S'\subseteq S(u,d)}\left[\gcd\{S'\}=d\land k\mid\max_{x\in S'}\{a_x\}\right]
\\&=\sum_{T=1}^n\left(\sum_{d\mid T}b_d\,\mu\left(\frac Td\right)\right)\sum_{\varnothing\ne S'\subseteq S(u,T)}\left[k\mid\max_{x\in S'}\{a_x\}\right]
\end{aligned}
$$
对于每个 $T$ 枚举所有编号为 $T$ 倍数的点集 $S_{all}$,显然对于所有 $u$ 有 $S(u,T)\subset S_{all}$。
我们可以对 $S_{all}$ 的所有点建虚树,接下来需要解决 $k\mid\max\{a\}$ 的限制。
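假如我们能够得到 $u$ 子树内所有点的点权,我们可以对点权建权值线段树,在每个节点上维护区间内的点权个数 $cnt$,以及在区间内选出非空点集使最大值为 $k$ 的倍数的方案数 $dat$。pushup 是容易的:最大值要么落在左儿子(此时右儿子不选),要么落在右儿子(此时左儿子任选),即 $dat=dat_l+2^{cnt_l}\cdot dat_r$,$cnt=cnt_l+cnt_r$。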
然后在虚树上线段树合并即可。统计答案树上差分即可。
复杂度?每种 $T$ 的 $S_{all}$ 大小之和是 $O(n\ln n)$,加上线段树合并是 $O(n\ln n\log n)$。
史:
```cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-node arrays; the code below indexes them from 1.
const int N = 100005;
// const int V = ;
// Modulus applied to every count and answer in this file.
const int mod = 998244353;
// Short aliases for common integer, pair and container types.
typedef unsigned us;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair <int, int> pii;
typedef vector <int> vi;
typedef vector <pii> vpi;
typedef vector <ll> vl;
// pq is a max-heap, pqg a min-heap.
template <class T> using pq = priority_queue <T>;
template <class T> using pqg = priority_queue <T, vector <T>, greater <T> >;
// Loop helpers: rep/per visit both bounds (ascending/descending),
// repr/perr stop just before reaching b.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define repr(i, a, b) for (int i = (a); i < (b); ++i)
#define per(i, a, b) for (int i = (a); i >= (b); --i)
#define perr(i, a, b) for (int i = (a); i > (b); --i)
// Abbreviations for pair members and frequently used member/algorithm names.
// These are object-like macros, so the tokens fi/se/lb/ub/pb cannot be used
// as ordinary identifiers anywhere below.
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define pb push_back
/// Lowers `a` to `b` when `b` is strictly smaller ("check-min").
template <class T1, class T2> inline void ckmn(T1 &a, T2 b) {
    if (a > b) {
        a = b;
    }
}
/// Raises `a` to `b` when `b` is strictly larger ("check-max").
template <class T1, class T2> inline void ckmx(T1 &a, T2 b) {
    if (a < b) {
        a = b;
    }
}
namespace IO {
// Optional buffered replacement for getchar(), left disabled:
// char buf[1 << 23], *p1 = buf, *p2 = buf;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)

/**
 * Reads the next non-negative decimal integer from standard input into `a`.
 *
 * Every character that is not an ASCII digit is skipped first; a leading
 * minus sign is therefore ignored, not interpreted.
 * If the input ends before any digit is found, `a` is set to 0 and the
 * function returns (previously it would spin forever, because EOF converted
 * to unsigned never compares as a digit).
 *
 * @param a  destination of the parsed value
 * @param c  scratch character; callers are expected to rely on the default
 */
template <class T> void rd(T &a, unsigned c = 0) {
    const unsigned kEof = static_cast<unsigned>(EOF);
    a = 0;
    // Skip separators, but give up once the stream is exhausted.
    do {
        c = getchar();
    } while (c != kEof && (c < 48 || c > 57));
    // Accumulate consecutive digits ('0' == 48, '9' == 57).
    for (; c >= 48 && c <= 57; c = getchar()) a = a * 10 + (c ^ 48);
}

/**
 * Writes the non-negative integer `x` to standard output in decimal,
 * without any trailing separator. Negative values are not handled.
 */
template <class T> void wrt(T x) {
    if (x > 9) wrt(x / 10);   // emit the higher digits first
    putchar((x % 10) ^ 48);   // digit -> ASCII
}
} using IO::rd; using IO::wrt;
// n: number of tree nodes; k: the divisor tested against node values;
// u, v: scratch variables used while reading edges in solve().
int n, k, u, v;
// a: per-node values inserted into the segment trees.
// b: per-index weights as read; solve() later swaps in their Mobius-transformed form.
int a[N], b[N];
// T: adjacency lists of the input tree; vT: child lists of the current virtual tree.
vi T[N], vT[N];
// dep: depth (root has depth 1, node 0 acts as a depth-0 sentinel);
// anc[x][j]: 2^j-th ancestor of x; dfn: DFS preorder index; ind: preorder counter.
int dep[N], anc[N][20], dfn[N], ind;
// Answer accumulator; holds tree differences until dfs3() sums them over subtrees.
int ans[N];
// mk[x] != 0 marks x as a key node of the virtual tree currently being processed.
int mk[N];
// A: scratch list of key nodes (also used once to accumulate the transformed weights).
// B: scratch list for make_vtree(), which stores every key node plus one LCA per
// adjacent pair, i.e. up to 2*len - 1 entries. len can be as large as n, so B needs
// twice the capacity of A; with only N entries it was overrun for large n.
int A[N], B[N << 1];
// Linear-sieve state: p = list of primes (1-based), pc = prime count,
// np = composite flags, mu = Mobius function stored modulo mod.
int p[N], pc, np[N], mu[N];
// pw2[i] = 2^i modulo mod.
int pw2[N];
// Node of the dynamically allocated value segment tree.
struct Seg {
// ls / rs: pool indices of the left and right child (0 = child absent);
// cnt: how many values have been inserted into this node's value range;
// dat: number (mod `mod`) of non-empty selections of those values whose
//      maximum is a multiple of k, as maintained by ins(), merge() and pushup().
int ls, rs, cnt, dat;
// Accessor macros: ls(p), rs(p), cnt(p), dat(p) expand to the fields of tr[p].
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define cnt(p) tr[p].cnt
#define dat(p) tr[p].dat
} tr[N << 5]; // node pool; Nw() hands out indices starting at 1, so index 0 stays all-zero
// Midpoint of the current segment; relies on the enclosing function naming
// its range bounds L and R.
#define M (L + R >> 1)
// Number of segment-tree nodes handed out so far (reset by the caller).
int stot;
/// Allocates the next node from the pool, clears it and returns its index.
inline int Nw() {
    const int id = ++stot;
    cnt(id) = 0;
    dat(id) = 0;
    ls(id) = 0;
    rs(id) = 0;
    return id;
}
void sieve() {
mu[1] = 1;
rep(i, 2, 1e5) {
if (!np[i]) p[++pc] = i, mu[i] = mod - 1;
for (int j = 1; i * p[j] <= 1e5; j++) {
np[i * p[j]] = 1;
if (i % p[j]) mu[i * p[j]] = mu[i] ? mod - mu[i] : 0;
else break;
}
}
}
/**
 * Preorder traversal of the input tree rooted at the first call's `u`.
 * Records depth, direct parent (anc[.][0]) and preorder index for every node.
 *
 * @param u   node being visited
 * @param fa  parent of u (0 for the root)
 */
void dfs1(int u, int fa) {
    anc[u][0] = fa;
    dep[u] = dep[fa] + 1;
    dfn[u] = ++ind;
    for (const int child : T[u]) {
        if (child == fa) continue;
        dfs1(child, u);
    }
}
/**
 * Lowest common ancestor of u and v by binary lifting over anc[][0..18].
 * Requires dep[] and anc[][] to be filled; node 0 serves as a depth-0 sentinel.
 */
int lca(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);
    // Raise the deeper node u to the depth of v.
    for (int j = 18; j >= 0; --j) {
        const int up = anc[u][j];
        if (dep[up] >= dep[v]) u = up;
    }
    if (u == v) return u;
    // Lift both nodes together while their ancestors still differ.
    for (int j = 18; j >= 0; --j) {
        if (anc[u][j] == anc[v][j]) continue;
        u = anc[u][j];
        v = anc[v][j];
    }
    return anc[u][0];
}
/// Orders nodes by their DFS preorder index.
bool cmp(int x, int y) {
    const int px = dfn[x];
    const int py = dfn[y];
    return px < py;
}
int make_vtree(int len) {
sort(A + 1, A + len + 1, cmp);
int len2 = 0;
repr(i, 1, len) {
B[++len2] = A[i];
B[++len2] = lca(A[i], A[i + 1]);
}
B[++len2] = A[len];
sort(B + 1, B + len2 + 1, cmp);
len2 = unique(B + 1, B + len2 + 1) - B - 1;
rep(i, 1, len2) vT[B[i]].clear();
repr(i, 1, len2) {
int lc = lca(B[i], B[i + 1]);
// cout << B[i] << " " << B[i + 1] << " " << lc << endl;
vT[lc].pb(B[i + 1]);
}
return B[1];
}
inline void pushup(int p) {
dat(p) = (dat(ls(p)) + 1ll * pw2[cnt(ls(p))] * dat(rs(p))) % mod;
cnt(p) = cnt(ls(p)) + cnt(rs(p));
}
/**
 * Inserts one occurrence of value x into the segment tree rooted at p,
 * creating missing children on the way down.
 *
 * @param p     existing (non-zero) node covering [L, R]
 * @param x     value to insert; expected to lie in [L, R]
 * @param L, R  value range covered by p
 */
void ins(int p, int x, int L, int R) {
    if (L != R) {
        const int mid = (L + R) >> 1;
        if (x > mid) {
            if (rs(p) == 0) rs(p) = Nw();
            ins(rs(p), x, mid + 1, R);
        } else {
            if (ls(p) == 0) ls(p) = Nw();
            ins(ls(p), x, L, mid);
        }
        pushup(p);
        return;
    }
    // Leaf: all cnt equal values form 2^cnt - 1 non-empty selections, which
    // count only when the value itself is a multiple of k.
    cnt(p)++;
    dat(p) = (x % k == 0) ? (pw2[cnt(p)] - 1 + mod) % mod : 0;
}
/**
 * Merges segment tree v into segment tree u (both covering [L, R]) and
 * returns the root of the result. Nodes of u are reused and modified;
 * subtrees of v are adopted wherever u has none.
 */
int merge(int u, int v, int L, int R) {
    if (u == 0) return v;
    if (v == 0) return u;
    if (L != R) {
        const int mid = (L + R) >> 1;
        ls(u) = merge(ls(u), ls(v), L, mid);
        rs(u) = merge(rs(u), rs(v), mid + 1, R);
        pushup(u);
        return u;
    }
    // Leaf for value L: add the multiplicities and recount the selections.
    cnt(u) += cnt(v);
    dat(u) = (L % k == 0) ? (pw2[cnt(u)] - 1 + mod) % mod : 0;
    return u;
}
// void out(int p, int L, int R) {
// if (!p) return;
// cout << p << " " << L << " " << R << " " << cnt(p) << " " << dat(p) << endl;
// out(ls(p), L, M), out(rs(p), M + 1, R);
// }
int dfs2(int u, int fa, int ca) {
int rt = Nw();
if (mk[u]) ins(rt, a[u], 1, n);
for (auto v : vT[u]) {
int x = dfs2(v, u, ca);
rt = merge(rt, x, 1, n);
}
int e = 1ll * dat(rt) * ca % mod;
// cout << u << " " << fa << " " << ca << " " << e << " " << mk[u] << endl;
ans[u] = (ans[u] + e) % mod;
ans[fa] = (ans[fa] - e + mod) % mod;
// puts("tree:");
// out(rt, 1, n);
return rt;
}
/**
 * Turns the differences stored in ans[] into final values by replacing
 * ans[u] with the sum of ans over the subtree of u (modulo mod).
 */
void dfs3(int u, int fa) {
    for (const int child : T[u]) {
        if (child != fa) {
            dfs3(child, u);
            ans[u] = (ans[u] + ans[child]) % mod;
        }
    }
}
void solve() {
rd(n), rd(k);
repr(i, 1, n) {
rd(u), rd(v);
T[u].pb(v), T[v].pb(u);
}
rep(i, 1, n) rd(a[i]);
rep(i, 1, n) rd(b[i]);
sieve();
rep(i, 1, n) {
for (int j = 1; i * j <= n; j++) {
A[i * j] = (A[i * j] + 1ll * mu[i] * b[j]) % mod;
}
}
swap(b, A);
dfs1(1, 0);
rep(j, 1, 18) {
rep(i, 1, n) {
anc[i][j] = anc[anc[i][j - 1]][j - 1];
}
}
rep(i, pw2[0] = 1, n) pw2[i] = 2 * pw2[i - 1] % mod;
rep(i, 1, n) {
// cout << i << endl;
int tot = 0;
for (int j = i; j <= n; j += i) A[++tot] = j, mk[j] = 1;
int rt = make_vtree(tot);
stot = 0;
dfs2(rt, 0, b[i]);
for (int j = i; j <= n; j += i) mk[j] = 0;
}
dfs3(1, 0);
rep(i, 1, n) wrt(ans[i]), putchar(32);
}
/// Entry point: runs solve() once (flip the constant to read a test count).
int main() {
    int cases = 1;
    const bool multipleCases = false;
    if (multipleCases) rd(cases);
    while (cases--) {
        solve();
    }
}
```