题解:AT_abc425_g [ABC425G] Sum of Min of XOR

· · 题解

解法

看到异或想到 01-Trie。

考虑对于每一个 x,求出最小的异或值。

我们可以贪心地做。假如我们已经知道了一个数,什么情况下另一个数与这个数的异或值最小呢?对了,就是在这两个数相等的时候。

所以我们在 01-Trie 上从上往下走,每次尽可能走与 x 的这一二进制位一致的数即可。

但是这种方法的时间复杂度依旧过不去,怎么优化呢?

我们换种角度,从 01-Trie 节点的角度来考虑。

从上往下走,记录经过该点的 x 数量 cnt 和深度(从大到小) dep

延续之前的贪心做法,如果两个子节点都存在,那么将所有 x 放进它的下一位对应的子节点里,否则全部放进唯一的子节点里,并且对答案产生贡献。

但有个问题,怎么求出下一位为 01x 个数呢?

很明显,由于每个 x 都是连续的,所以经过这个节点的 x 的下一位一定是从 01 的,由于 0 是一定会先加满才会变成 1,所以下一位为 0 的个数为 \min(cnt,2^{dep-1}),为 1 的个数为 cnt-\min(cnt,2^{dep-1})

可以发现,cnt2^{dep} 的节点会经常出现,所以我们可以先预处理出所有这样的节点的答案。

这样时间复杂度就可以足够通过本题了。

代码

#include <bits/stdc++.h>
#define int long long
const int N = 2e5 + 5;
const int Mod = 1e9 + 7;
using namespace std;
int n, m;
int a[N];
struct node
{
    int ch[2] = {-1, -1};
    int cnt;
    int val;
} trie[N * 30];
int tot;
void insert(int x, int p)
{
    for (int i = 29; i >= 0; i--)
    {
        bool c = x & (1ll << i);
        if (trie[p].ch[c] == -1)
        {
            trie[p].ch[c] = ++tot;
            trie[tot].cnt = 1ll << i;
        }
        p = trie[p].ch[c];
    }
}
// 求走满情况下
void pre(int s)
{
    int ls = trie[s].ch[0];
    int rs = trie[s].ch[1];
    if (ls != -1 && rs != -1)
    {
        pre(ls);
        pre(rs);
        trie[s].val = trie[ls].val + trie[rs].val;
    }
    else if (ls != -1)
    {
        pre(ls);
        trie[s].val = trie[ls].val * 2 + trie[ls].cnt * trie[ls].cnt;
    }
    else if (rs != -1)
    {
        pre(rs);
        trie[s].val = trie[rs].val * 2 + trie[rs].cnt * trie[rs].cnt;
    }
}
int bfs()
{
    int ans = 0;
    queue<pair<int, int>> q;
    q.emplace(0, m);
    while (!q.empty())
    {
        auto [s, cnt] = q.front();
        q.pop();
        if (!cnt)
        {
            continue;
        }
        int ls = trie[s].ch[0];
        int rs = trie[s].ch[1];
        int nl = min(cnt, trie[s].cnt >> 1);
        int nr = cnt - nl;
        if (!nr)
        {
            if (ls != -1)
            {
                q.emplace(ls, nl);
            }
            else
            {
                q.emplace(rs, nl);
                ans += trie[rs].cnt * nl;
            }
        }
        else
        {
            if (ls != -1 && rs != -1)
            {
                ans += trie[ls].val;
                q.emplace(rs, nr);
            }
            else if (ls != -1)
            {
                q.emplace(ls, nr);
                ans += trie[ls].cnt * nr + trie[ls].val;
            }
            else if (rs != -1)
            {
                q.emplace(rs, nr);
                ans += trie[rs].cnt * nl + trie[rs].val;
            }
        }
    }
    return ans;
}
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
        insert(a[i], 0);
    }
    trie[0].cnt = 1ll << 30;
    pre(0);
    cout << bfs();
    return 0;
}