题解:P13818 「LDOI R3」泡泡抗特

· · 题解

废话(可跳过)

somewhere over the rainbow way up high.

there's a land that I heard of once in a lullaby.

only blue. only blue.

愛讓人,好憂鬱。

我的心。我的心。藍藍的。

题目背景出自陶喆 soul power 演唱会版本的《沙滩》。赞美出题人。

调了超级久,最后发现先模再除先除再模是不一样的,卡了两天。。

正文

根据异或零元律的性质推出以下三点:

:::info[零元律] 若 a \oplus b = 0,则 a = b。 :::

一个最直接的想法是枚举 (i,j),然后判断三个数是否符合条件。而这个方法的时间复杂度似乎是 O(n^2) 的。

但实际上不是 O(n^2),或者说不需要 O(n^2)

根据 \mathrm{popcount}(x) 的定义,a_ia_j 在二进制下 1 的个数不多于 2。 进一步地想,所有满足 \mathrm{popcount}(x) \le 2x,其实不超过 C^2_{120} + C^1_{120} = 7260 个。这个证明很简单,相当于二进制随机取两位或一位为 1

那么我们在输入的时候,将满足 \mathrm{popcount}(a_i) \le 2a_i 筛选出来不就好了吗?这样一来,数的总数被我们压缩到 7260 以下。而这样的 a_i 总是可以被拆成 2^p + 2^q 的形式,我们完全可以用两个变量 p,q 表示一个符合条件的 a_i

这时候我们枚举 arr2 就简单多了,m = 7000 时完全可以在 O(m^2) 的时间复杂度下跑过去。

你说这样枚举有啥用?还不是要看每个数的相对位置?

大可不必。注意到我们要满足的三元组仅需要下标递增,也就是说参与运算的数怎么排都行。那我们找到三个符合条件的数时,我们必定能找到一个下标递增的排列方式。基于此,每对 (i,j) 能提供的方案数就是三个数出现次数之积。

具体实现的时候,我用 mp[p][q] 记录一个数的出现次数,每次枚举 ij,然后模拟一遍异或的过程。

手玩发现只有两个 a_ia_j 有相同位置的 1 或都只有一个 1 时,它们异或的结果才满足 \mathrm{popcount}(a_i \oplus a_j) \le 2

模拟的时候分类讨论一下即可。还要注意三元组中每两个数都会被计入 ans,输出答案时应输出 \frac{ans}{3}

# include <bits/stdc++.h>
# define mod 1000000007

using namespace std;

long long mp[125][125];
bool chk[125][125][125][125];
__int128 arr1[300005];

struct new_arr
{
    int id1,id2;
};
struct new_arr arr2[300005];

long long ans;
int newn;

void redi (__int128& ret) {
    ret = 0; int f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') {if (ch == '-') f = -f; ch = getchar();}
    while (ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0', ch = getchar();
    ret *= f;
} // 调用 redi(x) 以读入变量 x。

void add(__int128 x)  // 筛选 
{
    int id1=125,id2=125;  // 第一个1出现的位置,第二个1出现的位置 
    int idx=1; // 当前二进制位置 
    __int128 tmp = x;
    while (x > 0)
    {
        if ((x&1) == 1)
        {
            if (id1 == 125) id1 = idx;
            else if(id2 == 125) id2 = idx;
            else return ;
        } 
        x = x>>1;
        idx++;
    }
    if (!mp[id1][id2])
    {
        arr2[newn].id1 = id1;
        arr2[newn++].id2 = id2;
    }
    mp[id1][id2]++; 
    return ;
}

int main (void)
{
    int T;
    scanf ("%d",&T);

    while (T--)
    {
        int n;
        scanf ("%d",&n);
        for (int i=0;i<n;i++)
        {
            redi(arr1[i]);
            add(arr1[i]);
        }
        for (int i=0;i<newn;i++)
        {
            for (int j=i+1;j<newn;j++)
            {
                int id1,id2;
                if (arr2[i].id1 == arr2[j].id1) // 其中两位相同 
                {
                    id1 = min(arr2[i].id2,arr2[j].id2);
                    id2 = max(arr2[i].id2,arr2[j].id2);                 
                } 
                else if (arr2[i].id1 == arr2[j].id2)
                {
                    id1 = min(arr2[i].id2,arr2[j].id1);
                    id2 = max(arr2[i].id2,arr2[j].id1);                     
                }
                else if (arr2[i].id2 == arr2[j].id1)
                {
                    id1 = min(arr2[i].id1,arr2[j].id2);
                    id2 = max(arr2[i].id1,arr2[j].id2);                     
                }               
                else if (arr2[i].id2 == arr2[j].id2)
                {
                    id1 = min(arr2[i].id1,arr2[j].id1);
                    id2 = max(arr2[i].id1,arr2[j].id1);                     
                }               
                else 
                {
                    continue;
                }

                long long n1 = mp[id1][id2];
                long long n2 = mp[arr2[i].id1][arr2[i].id2];
                long long n3 = mp[arr2[j].id1][arr2[j].id2];

                ans += n1*n2*n3;
            }
        }
        printf ("%lld\n",(ans/3) % mod);
        for (int i=0;i<125;i++) //重置
        {
            for (int j=0;j<125;j++)
            {
                mp[i][j] = 0;
            }
        }
        newn=0;
        ans=0;
    }
    return 0;   
}