题解:P11799 【MX-X9-T3】『GROI-R3』Powerless

· · 题解

提供一种好想好写的做法!

\sum_{i=1}^n\sum_{j=1}^n\sum_{k=0}^m \min(a_i \oplus k, a_j \oplus k)

要拆掉 \min

现在我们要计算

2\sum_{i=1}^n \sum_{k=0}^m(a_i \oplus k)[a_i\oplus k < a_j \oplus k] + c_x^2\sum\limits_{k=0}^m x\oplus k

先考虑第一项:

2\sum_{i=1}^n \sum_{k=0}^m(a_i \oplus k)[a_i\oplus k < a_j \oplus k]

\textrm{msb}(x) 表示 x 的最高有效位,x_k 表示 x 在二进制下的第 k 位。

发现影响 a_i \oplus ka_j \oplus k 大小的是 \textrm{msb}(a_i \oplus k)\textrm{msb}(a_j \oplus k),也就取决于两数在第 \textrm{msb}\left((a_i \oplus k) \oplus (a_j \oplus k)\right) = \textrm{msb}\left(a_i \oplus a_j\right) 位的值。和 k 无关。

c = \textrm{msb}(a_i \oplus a_j),这意味着 {a_i}_c \neq {a_j}_c,所以能贡献的 k 一定满足 k_c = {a_i}_c,这样,(a_i\oplus k)_c = 0, (a_j \oplus k)_c = 1,所以一定有 a_i\oplus k < a_j \oplus k

枚举 a_ic,这部分的贡献可以写成

2\left(\sum_{j=1}^n[\textrm{msb}(a_i \oplus a_j) = c]\right)\left(\sum_{k=0}^m(a_i \oplus k)[{a_i}_c = k_c]\right)

左边括号的值,相当于求 (c, 29] 位都与 a_i 相同,第 c 位不同的数的个数。可以把所有 a 插入 01 trie,就变成求子树大小。记为 sz

右边括号的值,拆位,对于第 j 位,所有满足 {a_i}_j \neq k_j{a_i}_c = k_ck 都能产生 2^j 的贡献。

考虑预处理数组 f_{i, j, 0/1, 0/1} 表示 [0, m] 中满足 k_i = 0/1, k_j = 0/1k 的个数。可以用数位 dp 求出。

2\left(\sum_{j=1}^n[\textrm{msb}(a_i \oplus a_j) = c]\right)\left(\sum_{j=0}^{29}2^jf_{j, c, {a_i}_j\oplus 1, {a_i}_c}\right)

再考虑第二项:

c_x^2\sum\limits_{k=0}^m x\oplus k

这就容易多了,一样的思路,拆位,对于第 j 位,所有满足 x_j \neq k_jk 都能产生 2^j 的贡献。利用预处理好的 f 数组。

c_x^2\sum\limits_{j=0}^{29} 2^jf_{j, j, x_j \oplus 1, x_j\oplus 1}

加起来就是最终的答案:

\boxed{2\sum_{i=1}^n\sum_{c=0}^{29}sz\left(\sum_{j=0}^{29}2^jf_{j, c, {a_i}_j\oplus 1, {a_i}_c}\right) + \sum_xc_x^2\sum\limits_{j=0}^{29} 2^jf_{j, j, x_j \oplus 1, x_j\oplus 1}} :::success[实现] ```cpp line-numbers #include <bits/stdc++.h> using namespace std; using ll = long long; using pii = pair<int, int>; #ifdef ONLINE_JUDGE #define debug(...) 0 #else #define debug(...) fprintf(stderr, __VA_ARGS__), fflush(stderr) #endif constexpr int N = 2e5 + 5, mod = 998244353; int n, m; int a[N]; unordered_map<int, int> mp; int tr[N * 29][2], idx; int siz[N * 29]; void insert(int x) { int p = 0; for (int i = 29; i >= 0; i--) { int v = x >> i & 1; if (!tr[p][v]) tr[p][v] = ++idx; p = tr[p][v]; } siz[p]++; } void dfs(int u) { if (!u) return; for (int i = 0; i <= 1; i++) { dfs(tr[u][i]); siz[u] += siz[tr[u][i]]; } } int f[30][30][2][2]; // f[i][j][0 / 1][0 / 1] k \in [0, m] 且使得 k 第 i 位为 0/1,第 j 位为 0/1 的 k 的数量 int dp[30]; int dfs(int pos, int p1, int p2, int v1, int v2, int lim) { if (pos == -1) return 1; if (!lim && ~dp[pos]) return dp[pos]; int res = 0; for (int i = 0; i <= (lim ? m >> pos & 1 : 1); i++) { if (pos == p1 && i != v1) continue; if (pos == p2 && i != v2) continue; res += dfs(pos - 1, p1, p2, v1, v2, lim && (i == (m >> pos & 1))); } return lim ? res : dp[pos] = res; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> n >> m; for (int i = 1; i <= n; i++) { cin >> a[i]; mp[a[i]]++; insert(a[i]); } sort(a + 1, a + 1 + n); for (int i = 0; i <= 29; i++) for (int j = 0; j <= 29; j++) for (int v1 = 0; v1 <= 1; v1++) for (int v2 = 0; v2 <= 1; v2++) { memset(dp, -1, sizeof dp); f[i][j][v1][v2] = dfs(29, i, j, v1, v2, 1); } for (int i = 0; i <= 1; i++) if (tr[0][i]) dfs(tr[0][i]); int ans = 0; for (int i = 1; i <= n; i++) { int x = a[i]; for (int c = 0; c <= 29; c++) { int p = 0; for (int j = 29; j > c; j--) p = tr[p][x >> j & 1]; int sz = siz[tr[p][!(x >> c & 1)]]; int con = 0; for (int j = 0; j <= 29; j++) con = (con + 1ll * (1 << j) * f[j][c][!(x >> j & 1)][x >> c & 1]) % mod; ans = (ans + 1ll * sz * con) % mod; } } ans = 1ll * ans * 2 % mod; n = unique(a + 1, a + 1 + n) - a - 1; for (int i = 1; i <= n; i++) { int res = 0; for (int j = 0; j <= 29; j++) { int v = a[i] >> j & 1; res = (res + 1ll * (1 << j) * f[j][j][v ^ 1][v ^ 1]) % mod; } ans = (ans + 1ll * mp[a[i]] * mp[a[i]] % mod * res) % mod; } cout << ans << "\n"; return 0; } ``` :::