题解 CF241B 【Friends】
tommymio
·
·
题解
\texttt{可能更好的阅读体验}
链接
CodeForces 241B
题意
给定一个长度为 N 的序列,求前 k 大的 a_i \operatorname{xor} a_j 之和,答案对 10^9+7 取模。
其中 a_i \operatorname{xor} a_j 与 a_j \operatorname{xor} a_i 被看成是同一对。
## 约定
以下称 $a_i \operatorname{xor} a_j$ 的值为异或值。
$b_v[i..j]$ 为值 $v$ 在 **二进制意义** 下第 $i$ 位至第 $j$ 位的取值,如 $b_3[0..1]=11_2$,$b_5[1..2]=10_2$,$b_5[0,2]=101_2$(下标 $2$ 即表示二进制意义下)。
二进制意义下的最高位为 $mx$。
## 题解
此题是 **[十二省联考2019]异或粽子** 的加强版,解法吊打 $\texttt{std}$。
提供了一种时间复杂度与 $k$ 无关的解法,时间复杂度为 $O(n\omega^2+n\log n)$,其中 $\omega$ 是一个常数,意义为在二进制意义下 $a_i$ 的最高位,在本题中 $\omega \approx 29$。
首先可以想到,要求出前 $k$ 大的异或和,必须先求出第 $k$ 大的异或和。
有一种很暴力的做法,枚举每个 $a_i,a_j$,求出第 $k$ 大的异或值。无疑可以优化,枚举每一个 $a_i$,将上述过程放至 $\texttt{trie}$ 上,从高位贪心,用类似权值线段树求第 $k$ 大的方式在 $\texttt{trie}$ 上二分即可。
更加具体地说,我们将一棵 $\texttt{trie}$ 分为 $n$ 层,第 $i$ 层的指针即 $a_i$ 所对应的 $a_j$ 的前缀。从最高位开始枚举第 $k$ 大的 $a_i \operatorname{xor} a_j$ 的每一位上取 $0/1$,在枚举到第 $x$ 位时,在每一层上计算 $a_i \operatorname{xor} a_j$ 的第 $x$ 位为 $1$ 的方案数,记这个数为 $cnt$。若 $cnt\geq k$,则说明这一位必须取 $1$,否则所求会小于第 $k$ 大。若 $cnt < k$,则说明这一位必须取 $0$ ,否则所求会大于第 $k$ 大。在 $n$ 层 $trie$ 上对应更新指针即可,并统计第 $x$ 位的取值,更新第 $k$ 大的异或值在 $x$ 上的取值即可。
求出第 $k$ 大的 $a_i \operatorname{xor} a_j$ 后,记其为 $val$。若 $val$ 第 $x$ 位上为 $0$,则若存在一个异或值 $t$,令$b_{val}[x+1..mx]=b_t[x+1..mx]$,仅在 $t$ 的第 $x$ 位上为 $1$,$t$ 一定大于 $val$ ,属于前 $k$ 大的取值。
考虑计算这部分值的贡献。仍然是从高位贪心,分成 $n$ 层。每层按照 $val$ 的值在 $\texttt{trie}$ 上递归,枚举值 $val$ 在二进制意义上为 $0$ 的那些位,统计仅在这一位上取 $1$ ,第 $x+1-mx$ 位上与 $val$ 在二进制意义下取值相同的那些异或值对答案造成的贡献。
由于即使确定了 $a_i$,$a_j$ 也是在一棵子树中,对一棵子树统计答案,不太好做。可以在初始时直接对 $a_i$ 排序,这样 **$\texttt{trie}$ 的一棵子树即对应了 $a$ 上一段连续的区间**。统计贡献时按位统计,对每一位上的 $0/1$ 分别维护一个前缀和,这样就可以了。
**Show the Code**
```cpp
#include<cstdio>
#include<algorithm>
typedef long long ll;
const ll mod=1e9+7;
int n,tot=0,rt=0;
int a[50005],tmp[50005];
int sum[35][50005][2];
int ch[1000005][2],size[1000005],l[1000005],r[1000005];
inline ll read() {
register ll x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
return x*f;
}
inline ll Add(ll x,ll y) {return ((x+y)%mod+mod)%mod;}
inline ll pow(ll x,ll p) {
ll res=1;
for(;p;p>>=1) {if(p&1) res=res*x%mod;x=x*x%mod;}
return res;
}
inline void insert(int val,int id) {
if(!rt) rt=++tot,l[rt]=1,r[rt]=n;
int p=rt;++size[rt];
for(register int d=31;d>=0;--d) {
int &nxt=ch[p][val>>d&1];
if(!nxt) {nxt=++tot;l[nxt]=id;}
p=nxt;r[p]=id;++size[p];
}
}
inline int getVal(ll &k) {
int res=0;
for(register int i=1;i<=n;++i) tmp[i]=rt;
for(register int d=31;d>=0;--d) {//若当前取1的个数为cnt,cnt>k
int x=0;
ll cnt=0;
for(register int i=1;i<=n;++i) cnt+=size[ch[tmp[i]][(a[i]>>d&1)^1]];
if(cnt>=k) res|=(1<<d),x=1;//cnt<k
else k-=cnt,x=0;
for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
}
return res;
}
inline ll ask(int l,int r,int val) {
if(!l||!r) return 0;
ll res=0;
for(register int d=31;d>=0;--d) res=Add(res,(ll)(sum[d][r][(val>>d&1)^1]-sum[d][l-1][(val>>d&1)^1])*((1ll<<d)%mod)%mod);
return res;
}
inline ll getSum(ll k) {
int val=getVal(k);
ll res=val*k%mod;
for(register int i=1;i<=n;++i) tmp[i]=rt;
for(register int d=31;d>=0;--d) {
int x=val>>d&1;
if(!x) {for(register int i=1;i<=n;++i) res=Add(res,ask(l[ch[tmp[i]][(a[i]>>d&1)^1]],r[ch[tmp[i]][(a[i]>>d&1)^1]],a[i]));}
for(register int i=1;i<=n;++i) tmp[i]=ch[tmp[i]][(a[i]>>d&1)^x];
}
return res;
}
signed main() {
n=read();
ll k=read()*2ll;//25*1e8=2.5*1e9
for(register int i=1;i<=n;++i) a[i]=read();
std::sort(a+1,a+1+n);
for(register int i=1;i<=n;++i) insert(a[i],i);
for(register int i=1;i<=n;++i) {
for(register int d=31;d>=0;--d) {
sum[d][i][0]=sum[d][i-1][0];
sum[d][i][1]=sum[d][i-1][1];
++sum[d][i][a[i]>>d&1];
}
}
printf("%lld\n",getSum(k)*pow(2,mod-2)%mod);
return 0;
}
```