P9745 「KDOI-06-S」树上异或 - 简要题解
树形 dp,拆位。
大概就是设
初始
最终令
如果不理解为什么这样设计状态,可以先从链的部分分入手。
可以把链理解成对序列计数,
显然你不能对整体进行拆位,但是很快我们发现,可以在计算每个
把链上的思想类比到树上来,应该就能比较自然地得到上面那个状态设计。
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
static constexpr int mod = 998244353, N = 5e5 + 10, P = 60;
namespace basic {
inline int add(int x, int y) {return (x + y >= mod ? x + y - mod : x + y);}
inline int dec(int x, int y) {return (x - y < 0 ? x - y + mod : x - y);}
inline void ad(int &x, int y) {x = add(x, y);}
inline void de(int &x, int y) {x = dec(x, y);}
inline int qpow(int a, int b) {
int r = 1;
while(b) {
if(b & 1) r = 1ll * r * a % mod;
a = 1ll * a * a % mod; b >>= 1;
}
return r;
}
inline int inv(int x) {return qpow(x, mod - 2);}
int fac[N], ifac[N];
inline void fac_init(int n = N - 1) {
fac[0] = 1;
for(int i = 1; i <= n; i++)
fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[n] = inv(fac[n]);
for(int i = n - 1; i >= 0; i--)
ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
}
int invx[N];
inline void inv_init(int n = N - 1) {
invx[1] = 1;
for(int i = 2; i <= n; i++)
invx[i] = 1ll * (mod - mod / i) * invx[mod % i] % mod;
}
inline int binom(int n, int m) {
if(n < m || m < 0) return 0;
return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int rev[N];
inline void rev_init(int n) {
for(int i = 1; i < n; i++)
rev[i] = (rev[i >> 1] >> 1) + (i & 1 ? n >> 1 : 0);
}
}
using namespace basic;
i64 a[N]; int p2[P];
int f[N], g[N][P][2];
int n; vector<int> G[N];
void dfs(int u) {
for(int i = 0; i < P; i++) {
g[u][i][a[u] >> i & 1] = 1;
}
for(auto v : G[u]) {
dfs(v);
for(int i = 0; i < P; i++) {
int x = g[u][i][0], y = g[u][i][1]; g[u][i][0] = g[u][i][1] = 0;
ad(g[u][i][0], 1ll * x * f[v] % mod);
ad(g[u][i][0], add(1ll * x * g[v][i][0] % mod, 1ll * y * g[v][i][1] % mod));
ad(g[u][i][1], 1ll * y * f[v] % mod);
ad(g[u][i][1], add(1ll * x * g[v][i][1] % mod, 1ll * y * g[v][i][0] % mod));
}
}
for(int i = 0; i < P; i++) {
ad(f[u], 1ll * g[u][i][1] * p2[i] % mod);
}
}
int main() {
// freopen("xor.in", "r", stdin);
// freopen("xor.out", "w", stdout);
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
p2[0] = 1;
for(int i = 1; i < P; i++) {
p2[i] = p2[i - 1] * 2 % mod;
}
cin >> n;
for(int i = 1; i <= n; i++) {
cin >> a[i];
}
for(int i = 2; i <= n; i++) {
int father; cin >> father;
G[father].push_back(i);
}
dfs(1);
cout << f[1] << "\n";
}