住在拔作岛上的贫测机该如何是好
联考 NOIP T4，由于莫反细节写错导致……
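先把 LCA 限制去掉，我们考虑求出 $f(u)=\sum_{\{x,y\}\subseteq \mathrm{subtree}(u),\,x\ne y}[\gcd(a_x,a_y)\ne 1]\,a_xa_y$，即 $u$ 子树内所有点对的答案；最后令 $\mathrm{ans}(u)=f(u)-\sum_{v\in\mathrm{son}(u)}f(v)$，就得到 LCA 恰为 $u$ 的点对之和。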
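然后用莫反把 $[\gcd(a_x,a_y)=1]$ 拆成 $\sum_{d\mid\gcd(a_x,a_y)}\mu(d)$：$f(u)$ 等于子树内所有点对的 $a_xa_y$ 之和减去互质点对之和，而互质点对之和可以在 dsu on tree 时维护 $s_d=\sum_{d\mid a_x}a_x$ 增量计算；每个值的无平方因子约数预处理后存进 vector。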
卡常小技巧：预处理约数时只处理出现过的值，且只保留 $\mu(d)\ne 0$ 的约数；先统计每个值的约数个数，再一次性 `resize` 后填入，避免 vector 反复扩容。
#include <bits/stdc++.h>
// Type shorthands. Aliases (instead of #define) are scoped, type-checked and
// keep functional casts such as uint(x) / ull(x) well-formed (the macro form
// expands to the ill-formed `unsigned int(x)`).
using ull = unsigned long long;
using uint = unsigned int;
using LL = long long;
using namespace std;

constexpr int N = 200000 + 10;  // size of vertex-indexed arrays
constexpr int M = 500000 + 10;  // size of value-indexed arrays
constexpr uint MOD = 998244353;

int n;                  // number of vertices
int m = 500000;         // largest value handled by Sieve() and the divisor lists
int A[N];               // A[u]: value written on vertex u
vector<int> G[N];       // adjacency lists of the tree
int prime[M];           // primes found by Sieve(), 1-indexed
int ptot;               // number of primes found
uint mu[M];             // Mobius function mod MOD (MOD - 1 encodes -1)
bool is[M];             // is[x]: x was marked composite by Sieve()
bool vis[M];            // vis[x]: value x occurs in A
vector<int> D[M];       // D[x]: divisors d of x with mu[d] != 0 (only for vis[x])
int cnt[M];             // scratch counters used while building D
void Sieve() {
mu[1] = 1;
for (int i = 2; i <= m; i ++) {
if (!is[i]) { mu[i] = MOD - 1, prime[++ ptot] = i; }
for (int j = 1; j <= ptot && i * prime[j] <= m; j ++) {
is[i * prime[j]] = 1, mu[i * prime[j]] = (MOD - mu[i]) % MOD;
if (i % prime[j] == 0) { mu[i * prime[j]] = 0; break; }
}
}
return ;
}
uint cur = 0;   // pair sum maintained incrementally by add() / del()
uint sum[M];    // sum[d]: total weight currently inserted at values listed with divisor d
uint Ans[N];    // per-vertex answer, built by DFS1, DFS2 and main()
int sz[N];      // subtree size
int hson[N];    // child with the largest subtree (0 for a leaf)
int dfn[N];     // DFS entry index of a vertex
int R[N];       // largest DFS index inside the vertex's subtree
int dfncnt;     // running DFS index counter
int inv[N];     // inv[dfn[u]] == u
uint W[N];      // W[u]: sum of A over the subtree of u, mod MOD
// First DFS from u (parent f). Assigns dfn/inv/R so that subtree(u) is the
// index range [dfn[u], R[u]], computes sz[] and the heavy child hson[], and
// W[u] = sum of A over subtree(u) (mod MOD).
// Ans[u] becomes the sum of A[x] * A[y] over all unordered pairs of distinct
// vertices {x, y} in subtree(u), mod MOD. The pairs are formed by combining
// each child's subtree with everything merged before it (u itself plus the
// earlier children).
// Relies on Ans[] / hson[] starting at zero, i.e. call once per vertex.
void DFS1(int u, int f) {
	dfn[u] = ++dfncnt;
	R[u] = dfncnt;
	inv[dfncnt] = u;
	sz[u] = 1;
	W[u] = A[u];
	for (int v : G[u]) {
		if (v == f) continue;
		DFS1(v, u);
		if (sz[hson[u]] < sz[v]) hson[u] = v;
		sz[u] += sz[v];
		R[u] = R[v];
		// Pairs with one end in subtree(v) and the other in what u has
		// accumulated so far.
		const unsigned long long cross = 1ull * W[u] * W[v];
		Ans[u] = (Ans[u] + Ans[v] + cross) % MOD;
		W[u] = (W[u] + W[v]) % MOD;
	}
}
// Inserts one element with value u (an index into D, not a vertex) and
// weight k.
// For every d in D[u], the element pairs with the weight already inserted
// at multiples of d. Summing mu[d] * sum[d] * k over those d adds
// k * (weight of earlier elements whose value is coprime to u) to cur, since
// sum_{d | g, mu(d) != 0} mu(d) = [g == 1]. Then k is recorded in sum[d].
void add(int u, uint k) {
	for (const int d : D[u]) {
		const unsigned long long gained = 1ull * mu[d] * sum[d] % MOD * k % MOD;
		cur = (cur + gained) % MOD;
		sum[d] = (sum[d] + k) % MOD;
	}
}
void del(int u, uint k) {
for (int d : D[u]) {
sum[d] = (sum[d] + MOD - k) % MOD;
cur = (cur + 1ull * mu[d] * sum[d] % MOD * (MOD - k) % MOD) % MOD;
} return ;
}
// DSU-on-tree pass. On entry the add/del structure is empty.
// Order of work:
//   - light children are solved first and clear themselves;
//   - the heavy child is solved last and keeps its elements;
//   - the light subtrees and u are then inserted.
// At that point cur is the coprime-pair sum for subtree(u), and it is
// subtracted from Ans[u] (which DFS1 set to the all-pairs sum).
// When kp is false, the whole subtree is removed again before returning.
void DFS2(int u, int f, bool kp) {
	const int heavy = hson[u];
	for (int v : G[u]) {
		if (v == f || v == heavy) continue;
		DFS2(v, u, false);
	}
	if (heavy != 0) DFS2(heavy, u, true);

	for (int v : G[u]) {
		if (v == f || v == heavy) continue;
		for (int t = dfn[v]; t <= R[v]; ++t) {
			const int x = inv[t];
			add(A[x], A[x]);
		}
	}
	add(A[u], A[u]);
	Ans[u] = (Ans[u] + MOD - cur) % MOD;

	if (kp) return;
	for (int t = dfn[u]; t <= R[u]; ++t) {
		const int x = inv[t];
		del(A[x], A[x]);
	}
}
#define getchar getchar_unlocked
#define putchar putchar_unlocked
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it ('-' is skipped too).
// The character is kept as an int: isdigit() is only defined for
// unsigned-char values and EOF, and EOF must be seen to terminate.
// Returns 0 if the input ends before any digit appears; the previous version
// looped forever in that case.
int read() {
	int c = getchar();
	while (c != EOF && !isdigit(c)) c = getchar();
	int x = 0;
	for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
	return x;
}
// Prints x in decimal to stdout with no separator; intended for x >= 0.
// Low-order digits are buffered and emitted most-significant first.
void write(int x) {
	int digits[12];
	int len = 0;
	while (x >= 10) {
		digits[len++] = x % 10;
		x /= 10;
	}
	putchar(x % 10 + 48);
	while (len > 0) putchar(digits[--len] + 48);
}
// Fills D[v] with the divisors d of v having mu[d] != 0, for every value v
// that occurs in the input (vis[v]). A counting pass sizes each vector
// exactly, so filling never reallocates. Requires Sieve() and vis[] first.
static void BuildDivisorLists() {
	for (int i = 1; i <= m; i ++) if (mu[i] != 0)
		for (int j = i; j <= m; j += i) if (vis[j]) cnt[j] ++;
	for (int i = 1; i <= m; i ++) if (vis[i]) D[i].resize(cnt[i]), cnt[i] = 0;
	for (int i = 1; i <= m; i ++) if (mu[i] != 0)
		for (int j = i; j <= m; j += i) if (vis[j]) D[j][cnt[j] ++] = i;
}

// Reads n, the vertex values and the n - 1 tree edges.
// For every vertex u it prints the sum, mod MOD, of A[x] * A[y] over
// unordered pairs of distinct vertices whose LCA is u and whose values have
// gcd != 1.
int main() {
	// freopen closes the original stream even when opening fails, so an
	// unchecked failure would leave read()/write() working on a closed stream.
	// NOTE(review): ".in" / ".ans" look like a stripped problem name -- confirm
	// they match the file names the judge expects.
	if (!freopen(".in", "r", stdin) || !freopen(".ans", "w", stdout)) {
		fprintf(stderr, "cannot open .in / .ans\n");
		return 1;
	}
	n = read(); Sieve();
	for (int i = 1; i <= n; i ++) A[i] = read(), vis[A[i]] = 1;
	for (int i = 1, u, v; i < n; i ++) {
		u = read(), v = read(); G[u].push_back(v); G[v].push_back(u);
	}
	BuildDivisorLists();
	// DFS1: Ans[u] = all distinct pairs in subtree(u);
	// DFS2: subtract the coprime pairs.
	DFS1(1, 0); DFS2(1, 0, true);
	// Pairs with LCA exactly u = pairs in subtree(u) minus pairs lying inside a
	// single child's subtree. Vertices are visited in DFS order and children
	// have larger dfn, so each Ans[child] is still its whole-subtree value.
	for (int i = 1; i <= n; i ++)
		for (int j : G[inv[i]]) if (dfn[j] > i) Ans[inv[i]] = (Ans[inv[i]] + MOD - Ans[j]) % MOD;
	for (int i = 1; i <= n; i ++) write(Ans[i]), putchar('\n');
	return 0;
}