P9745 「KDOI-06-S」树上异或 - 简要题解

· · 题解

树形 dp,拆位。

大概就是设 f_u 表示子树 u 的答案,g_{u, i, 0/1} 表示以 u 为根的子树,考虑点 u 所在的连通块,其点权异或和在二进制下第 i 位为 0/1 时,将以 u 为根的子树去掉根所处的连通块后,剩余各部分的 f 在各种方案下的积的总和(即一个 \sum \prod 的形式)。

初始 g_{u, i, {a_u}_i} \gets 1,转移其实挺简单,考虑每次加入一个儿子 v 时,g_{u, i, 0} \gets g_{u, i, 0} \times f_v + g_{u, i, 0} \times g_{v, i, 0} + g_{u, i, 1} \times g_{v, i, 1},对 g_{u, i, 1} 同理。

最终令 f_u \gets \sum \limits_{i = 0} ^ {60} g_{u, i, 1} \times 2 ^ i 即可。

如果不理解为什么这样设计状态,可以先从链的部分分入手。

可以把链理解成对序列计数,f_i 表示考虑前 i 个位置的总答案。设 s 为前缀异或和,那么 f_i = \sum \limits_{j = 0} ^ {i - 1} f_j \times (s_i \oplus s_j),初值 f_0 = 1

显然你不能对整体进行拆位,但是很快我们发现,可以在计算每个 f_u 的时候拆位,具体来讲就是枚举 s_i \oplus s_j 的每一位的情况,然后这样就可以 \log 转移了。

把链上的思想类比到树上来,应该就能比较自然地得到上面那个状态设计。

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

static constexpr int mod = 998244353, N = 5e5 + 10, P = 60;
namespace basic {
    inline int add(int x, int y) {return (x + y >= mod ? x + y - mod : x + y);}
    inline int dec(int x, int y) {return (x - y < 0 ? x - y + mod : x - y);}
    inline void ad(int &x, int y) {x = add(x, y);}
    inline void de(int &x, int y) {x = dec(x, y);}

    inline int qpow(int a, int b) {
        int r = 1;
        while(b) {
            if(b & 1) r = 1ll * r * a % mod;
            a = 1ll * a * a % mod; b >>= 1;
        }
        return r;
    }
    inline int inv(int x) {return qpow(x, mod - 2);}

    int fac[N], ifac[N];
    inline void fac_init(int n = N - 1) {
        fac[0] = 1;
        for(int i = 1; i <= n; i++)
            fac[i] = 1ll * fac[i - 1] * i % mod;
        ifac[n] = inv(fac[n]);
        for(int i = n - 1; i >= 0; i--)
            ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
    }
    int invx[N];
    inline void inv_init(int n = N - 1) {
        invx[1] = 1;
        for(int i = 2; i <= n; i++)
            invx[i] = 1ll * (mod - mod / i) * invx[mod % i] % mod;
    }
    inline int binom(int n, int m) {
        if(n < m || m < 0) return 0;
        return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
    }

    int rev[N];
    inline void rev_init(int n) {
        for(int i = 1; i < n; i++)
            rev[i] = (rev[i >> 1] >> 1) + (i & 1 ? n >> 1 : 0);
    }
}
using namespace basic;

i64 a[N]; int p2[P];
int f[N], g[N][P][2];

int n; vector<int> G[N];

void dfs(int u) {
    for(int i = 0; i < P; i++) {
        g[u][i][a[u] >> i & 1] = 1; 
    }
    for(auto v : G[u]) {
        dfs(v);
        for(int i = 0; i < P; i++) {
            int x = g[u][i][0], y = g[u][i][1]; g[u][i][0] = g[u][i][1] = 0;
            ad(g[u][i][0], 1ll * x * f[v] % mod);
            ad(g[u][i][0], add(1ll * x * g[v][i][0] % mod, 1ll * y * g[v][i][1] % mod));
            ad(g[u][i][1], 1ll * y * f[v] % mod);
            ad(g[u][i][1], add(1ll * x * g[v][i][1] % mod, 1ll * y * g[v][i][0] % mod));
        }
    }
    for(int i = 0; i < P; i++) {
        ad(f[u], 1ll * g[u][i][1] * p2[i] % mod);
    }
}

int main() {
    // freopen("xor.in", "r", stdin);
    // freopen("xor.out", "w", stdout);
    ios::sync_with_stdio(false); 
    cin.tie(nullptr); cout.tie(nullptr);

    p2[0] = 1;
    for(int i = 1; i < P; i++) {
        p2[i] = p2[i - 1] * 2 % mod;
    }

    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for(int i = 2; i <= n; i++) {
        int father; cin >> father;
        G[father].push_back(i);
    }
    dfs(1);
    cout << f[1] << "\n";
}