[COCI2021-2022#1] Set - 题解

· · 题解

这里给一个三进制快速沃尔什变换的做法,可能做得比较麻烦。

本题题意比较抽象,我们需要从题目信息中寻找一些三元组的性质。这里有一个性质可供使用:对于同一位上的三个数字 b_{i,z},b_{j,z},b_{k,z},其所有合法的组合对应的数字加起来都是 3 的倍数,即 (b_{i,z}+b_{j,z}+b_{k,z})\operatorname{mod} 3=0。 我们发现如果把题目中的 1,2,3 看作 0,1,2 的话,这个性质不会改变,而它同时又被赋予了三进制异或的性质。所谓 k 进制异或其实就是按位相加(不进位)并按位对 k 取模。

考虑到本质不同三元组个数为 3^m,设 c_i 表示 i 三进制下是否有对应的 m 元组,如果有值为 1,否则为 0。题中要的是 (b_{i,z}+b_{j,z}+b_{k,z})\operatorname{mod} 3=0,其实就是 (b_{i,z}\oplus_3 b_{j,z} \oplus_3 b_{k,z})=0,这里的 \oplus_k 表示 k 进制异或。我们发现这个和 FWT 的过程十分相似,即求:

F_r=\sum_{i\oplus_3 j \oplus_3 k=r}c_ic_jc_k

我们在本题中只需要对 c 做三次 FWT,取出常数项 F_0 即可得到答案 \frac {F_0-n}{6}。取出常数项的理由很简单,我们要的 F_0 其实就是

F_0=\sum_{i\oplus_3 j \oplus_3 k=0}c_ic_jc_k

关于减 n 除以 6 的问题,其实是去掉重复选取的三元组(题中要的是本质不同三元组)。减 n 是因为任意整数在三进制异或自己三次之后都为零,一共 n 个三元组,所以要减去这 n 个三元组自己三进制异或三次之后的贡献。除以 6 是因为 3 个整数调换顺序进行三进制取模不影响异或结果,而 3 个数字有 6 种不同的排列方法,我们要的本质不同三元组和顺序无关,也就是说这 6 种情况本质是一种选法,要除以 6。

关于 k 进制 FWT,这个知识和 FFT 以及二进制 FWT 很相似,我在这里放一个 k 进制 FWT 的转移矩阵,大家可以参考这个矩阵实现任意进制 FWT。

FWT and

转移矩阵:

1 & 1 & 1 & \cdots & 1 \\ 0 & 1 & 1 & \cdots & 1 \\ 0 & 0 & 1 & \cdots & 1 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 1 \end{bmatrix}

FWT and 相当于做后缀和,逆运算是差分。

FWT or

转移矩阵:

1 & 0 & 0 & \cdots & 0 \\ 1 & 1 & 0 & \cdots & 0 \\ 1 & 1 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & 1 & \cdots & 1 \end{bmatrix}

FWT or 相当于做前缀和,逆运算是差分。

FWT xor

转移矩阵:

1 &1 & 1 & \cdots & 1 \\ 1 & \omega_n & \omega_n^2 & \cdots & \omega_n^{n-1} \\ 1 & \omega_{n}^2 & \omega_n^{4} & \cdots & \omega_n^{2(n-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \cdots & \omega_n^{(n-1)(n-1)} \end{bmatrix}

其中 \omega 是单位复根,单位复根的概念可以参考 FFT 中的概念。

逆转移矩阵:

1 &1 & 1 & \cdots & 1 \\ 1 & \omega_{n}^{-1} & \omega_n^{-2} & \cdots & \omega_n^{-(n-1)} \\ 1 & \omega_{n}^{-2} & \omega_n^{-4} & \cdots & \omega_n^{-2(n-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{-(n-1)} & \omega_n^{-2(n-1)} & \cdots & \omega_n^{-(n-1)(n-1)} \end{bmatrix}

本题做的三进制 FWT 中,\omega=-\frac 1 2+ \frac {\sqrt 3i} 2\omega^2=-\frac 1 2-\frac {\sqrt 3 i} 2,转移矩阵就和逆转移矩阵分别是:

1 & 1 & 1 \\ 1 & \omega & \omega^2 \\ 1 & \omega^2 & \omega \end{bmatrix}, A^{-1}=\frac 1 3\begin{bmatrix} 1 & 1 & 1 \\ 1 & \omega^{-1} & \omega^{-2} \\ 1 & \omega^{-2} & \omega^{-1} \end{bmatrix}= \frac 1 3 \begin{bmatrix} 1 & 1 & 1 \\ 1 & \omega^{2} & \omega^{} \\ 1 & \omega^{} & \omega^{2} \end{bmatrix}

我们在 IFWT 的过程中,每一层是需要对点值多项式每一项除以 3 的。当然,由于 IFWT 一共 \log_3 n 层,每层都要除以 3,所以也可以写成最后都除以 n,不必在过程中做除法,下面给个三进制 FWT xor 的参考代码。

每次都除以 3 的写法:

void FWT3XOR(int limit,complex *arr,int sign){
    for(int l=1;l<limit;l*=3){
        for(int cx=0;cx<limit;cx+=3*l){
            for(int cy=0;cy<l;++cy){
                complex tmp0=arr[cx+cy+l*0];
                complex tmp1=arr[cx+cy+l*1];
                complex tmp2=arr[cx+cy+l*2];
                if(sign==1){
                    arr[cx+cy+l*0]=tmp0+tmp1+tmp2;
                    arr[cx+cy+l*1]=tmp0+tmp1*w+tmp2*w2;
                    arr[cx+cy+l*2]=tmp0+tmp1*w2+tmp2*w;
                }else{
                    arr[cx+cy+l*0]=(tmp0+tmp1+tmp2)/3;
                    arr[cx+cy+l*1]=(tmp0+tmp1*w2+tmp2*w)/3;
                    arr[cx+cy+l*2]=(tmp0+tmp1*w+tmp2*w2)/3;
                }
            }
        }
    }
}

最后除以 n 的写法:

void FWT3XOR(int limit,complex *arr,int sign){
    for(int l=1;l<limit;l*=3){
        for(int cx=0;cx<limit;cx+=3*l){
            for(int cy=0;cy<l;++cy){
                complex tmp0=arr[cx+cy+l*0];
                complex tmp1=arr[cx+cy+l*1];
                complex tmp2=arr[cx+cy+l*2];
                if(sign==1){
                    arr[cx+cy+l*0]=tmp0+tmp1+tmp2;
                    arr[cx+cy+l*1]=tmp0+tmp1*w+tmp2*w2;
                    arr[cx+cy+l*2]=tmp0+tmp1*w2+tmp2*w;
                }else{
                    arr[cx+cy+l*0]=(tmp0+tmp1+tmp2);
                    arr[cx+cy+l*1]=(tmp0+tmp1*w2+tmp2*w);
                    arr[cx+cy+l*2]=(tmp0+tmp1*w+tmp2*w2);
                }
            }
        }
    }
    if(sign==-1){
        for(int cx=0;cx<limit;++cx)
            arr[cx]=arr[cx]/limit;
    }
}

两种写法等价,但后者应该更快(做除法次数少)。

本题完整代码如下,仅供参考。

#include <stdio.h>
#include <ctype.h>
#include <cmath>
#include <algorithm>
#include <numeric>
#define qaq inline
const int sz=1e6+19;
using ll=long long;
int n,k,tup[sz];
ll ans=0;
struct complex{
    long double real,imag;
    complex()=default;
    complex(long double a,long double b){
        real=a,imag=b;
    }
    qaq friend complex operator+(complex a,complex b){
        return complex(a.real+b.real,a.imag+b.imag);
    }
    qaq friend complex operator-(complex a,complex b){
        return complex(a.real-b.real,a.imag-b.imag);
    }
    qaq friend complex operator*(complex a,complex b){
        complex res;
        res.real=a.real*b.real-a.imag*b.imag;
        res.imag=a.real*b.imag+a.imag*b.real;
        return res;
    }
    qaq friend complex operator*(complex a,long double b){
        return complex(a.real*b,a.imag*b);
    }
    qaq friend complex operator/(complex a,long double b){
        return complex(a.real/b,a.imag/b);
    }
}pl[sz];
const complex w=complex(-0.5,0.5*std::sqrt(3));
const complex w2=complex(-0.5,-0.5*std::sqrt(3));
void FWT3XOR(int limit,complex *arr,int sign){
    for(int l=1;l<limit;l*=3){
        for(int cx=0;cx<limit;cx+=3*l){
            for(int cy=0;cy<l;++cy){
                complex tmp0=arr[cx+cy+l*0];
                complex tmp1=arr[cx+cy+l*1];
                complex tmp2=arr[cx+cy+l*2];
                if(sign==1){
                    arr[cx+cy+l*0]=tmp0+tmp1+tmp2;
                    arr[cx+cy+l*1]=tmp0+tmp1*w+tmp2*w2;
                    arr[cx+cy+l*2]=tmp0+tmp1*w2+tmp2*w;
                }else{
                    arr[cx+cy+l*0]=(tmp0+tmp1+tmp2)/3;
                    arr[cx+cy+l*1]=(tmp0+tmp1*w2+tmp2*w)/3;
                    arr[cx+cy+l*2]=(tmp0+tmp1*w+tmp2*w2)/3;
                }
            }
        }
    }
}
int main(){
    scanf("%d%d",&n,&k);
    for(int cx=0,u;cx<n;++cx){
        u=0;
        char ch=getchar();
        while(!isdigit(ch)) ch=getchar();
        while(isdigit(ch)) u=u*3+ch-'1',ch=getchar();
        tup[u]++,pl[u]=complex(1,0);
    }
    int lim,limbit3;
    for(lim=1,limbit3=0;limbit3<k;lim*=3,++limbit3);
    FWT3XOR(lim,pl,1);
    for(int cx=0;cx<lim;++cx)
        pl[cx]=pl[cx]*pl[cx]*pl[cx];
    FWT3XOR(lim,pl,-1);
    ans=pl[0].real+0.5;
    printf("%lld\n",(ans-n)/6);
    return 0;
}