[NOI2013] 向量内积

· · 题解

传送门

感谢 wzy 提供的思路,这可能是一个比较数学且困难的做法,和其他题解不太一样。

k=2

考虑把这些向量排成一个 n\times d 的矩阵 A,考虑把 AA 的转置 A^T 相乘得到 B,发现如果有解就是 B 不是一个全 1 矩阵,且如果 B_{i,j}=0,答案就是 i,j

思考如何 check 一个全 1 矩阵,随机构造一个 n 行列向量 R。如果 B 是一个全 1 矩阵,那么 B\times R 得到的列向量的每个元素都应该是 R 中的所有元素和,且如果第 i 行不满足,那么肯定存在一组 i,j 是答案。

计算 B\times R 也很简单,B\times R=(A\times(A^T\times R)),时间复杂度 \mathcal{O}(nd)

k=3

因为模 3 意义下 12 的平方都是 1,考虑 check 矩阵 B 的每个值平方后是否是全 1 矩阵,随机生成一个数组 r[],我们要 check 的就是对于每个 i\sum_jB_{i,j}^2r_j 是否等于 \sum_j r_j

考虑构造出矩乘的形式,\sum_jB_{i,j}^2r_j=\sum_{j}B_{i,j}R_{j,j}B^T_{j,i}。其中 R 是一个只有对角线有值的矩阵,且 R_{j,j}=r_j。现在要 check 的是最终得到的矩阵中对角线的值是否都是 \sum_j r_j。如果可以得到最终这个矩阵,构造方案的方法与 k=2 类似。

\begin{aligned} &B\times R\times B^T\\ =&A\times (A^T\times (R \times A))\times A^T\\ \end{aligned}

其中 B^T=(A\times A^T)^T=(A^T)^T\times A^T=A\times A^T

因为 R 只有对角线有值,所以 R\times A 可以快速处理,中间三项乘出来得时间复杂度是 \mathcal{O}(nd^2)。因为我们只关心答案矩阵对角线上的值,左右两个矩阵不用暴力乘上去,考虑答案矩阵的 Ans_{i,i},这时只用取出第一项的第 i 行以及最后一项的第 i 列来乘就好,总的时间复杂度 \mathcal{O}(nd^2)

实现

#include<iostream>
#include<stdio.h>
#include<ctype.h>
#include<bitset>
#include<random>
#include<stdlib.h>
using namespace std;
inline int read(){
    int x=0,f=0; char ch=getchar();
    while(!isdigit(ch)) f|=(ch==45),ch=getchar();
    while(isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return f?-x:x;
}
int n,m,k;
struct matrix{
    vector<vector<int> > a;
    int n,m;
    inline void build(){
        a.resize(n+1);
        for(int i=1;i<=n;++i) a[i].resize(m+1);
    }
    matrix operator * (matrix B){
        matrix A=*this,C;
        C.n=A.n,C.m=B.m;
        C.a.resize(C.n+1);
        for(int i=1;i<=C.n;++i) C.a[i].resize(C.m+1);
        for(int i=1;i<=C.n;++i)
            for(int j=1;j<=C.m;++j){
                C.a[i][j]=0;
                for(int k=1;k<=A.m;++k){
                    C.a[i][j]+=A.a[i][k]*B.a[k][j];
                }
                C.a[i][j]%=k;
            }
        return C;
    }
};
namespace sub2{
    inline void main(){
        mt19937 rnd(114514);
        matrix A,B;
        A.n=n,A.m=m,A.build();
        B.n=m,B.m=n,B.build();
        for(int i=1;i<=n;++i){
            for(int j=1;j<=m;++j){
                B.a[j][i]=A.a[i][j]=read()%k;
            }
        }
        for(int fick=1;fick<=10;++fick){
            matrix R;
            R.n=n,R.m=1,R.build();
            int sum=0;
            for(int i=1;i<=n;++i) R.a[i][1]=(rnd()&1),(sum+=R.a[i][1])%=2;
            R=B*R;
            R=A*R;
            for(int i=1;i<=n;++i){
                if(R.a[i][1]!=sum){
                    for(int j=1;j<=n;++j){
                        if(i==j) continue;
                        int s=0;
                        for(int k=1;k<=m;++k){
                            (s+=A.a[i][k]*A.a[j][k])%=2;
                        }
                        if(s==0){
                            printf("%d %d\n",min(i,j),max(i,j));
                            return;
                        }
                    }
                }
            }
        }
        puts("-1 -1");
    }
}
namespace sub3{
    inline void main(){
        mt19937 rnd(114514);
        matrix A,B,C,D;
        A.n=n,A.m=m,A.build();
        B.n=m,B.m=n,B.build();
        for(int i=1;i<=n;++i){
            for(int j=1;j<=m;++j){
                B.a[j][i]=A.a[i][j]=read()%k;
            }
        }
        C=A,D=B;
        for(int fick=1;fick<=10;++fick){
            int sum=0;
            matrix R=C;
            for(int i=1;i<=n;++i){
                int x=rnd()%3;
                (sum+=x)%=3;
                for(int j=1;j<=m;++j){
                    (R.a[i][j]*=x)%=3;
                }
            }
            R=B*R;
            for(int i=1;i<=n;++i){
                int s=0;
                for(int j=1;j<=m;++j){
                    for(int k=1;k<=m;++k){
                        s+=A.a[i][j]*R.a[j][k]*D.a[k][i];
                    }
                }
                if(s%3!=sum){
                    for(int j=1;j<=n;++j){
                        if(i==j) continue;
                        int tmp=0;
                        for(int k=1;k<=m;++k){
                            (tmp+=A.a[i][k]*A.a[j][k])%=3;
                        }
                        if(tmp==0){
                            printf("%d %d\n",min(i,j),max(i,j));
                            return;
                        }
                    }
                }
            }
        }
        puts("-1 -1");
    }
}
int main(){
    n=read(),m=read(),k=read();
    if(k==2) sub2::main();
    else sub3::main();
    return 0;
}