期望线性时间的最近点对

· · 算法·理论

省流:这篇论文的复读 + 代码复现。

记号及约定:\delta(S) 表示二维点集 S 中的最近点对的距离,对于 x\in Sd(x) 表示 xS 中的其他点的最近距离。

筛选最近点对算法

S 表示要求解的点集,先求出 \delta(S) 的一个近似解(近似因子为 3),然后再通过这个近似解求出 \delta(S)

筛选过程

S_i 表示第 i 轮迭代时的点集(初始时 S_1=S),随机选取一个点 x_i\in S_i 并求出 d(x_i)(此时 d 是相对于 S_i 的,后文同理)。

将平面按照 b=d(x_i)/3 分块,若一个点的 8-邻居 块都为空且其自身所在的块仅有他自己,则删去这个点。删去所有符合上述条件的点后,得到 S_{i+1}

记最小的使得 S_{i^{\ast}+1}=\varnothing 的迭代次数为 i^{\ast}d(x_{i^{\ast}}) 满足 \delta(S)\leq d(x_{i^{\ast}})< 3\delta(S)

接下来是正确性证明,首先考虑两个事实:

  1. 满足 d(x)> 2\sqrt{2}b 的点,一定会被删除。
  2. 满足 d(x)< b 的点一定会被保留。

借用原论文中的一张图(其实是我重新画的)

夹在这个范围中间的点可能会被保留也可能会被删除。由于 \dfrac{2\sqrt{2}}{3}<1,所以 d(x)\geq d(x_i)x 都被删除了,因此每轮迭代中的 d(x_i)递减的。

(u,v)S 中的最近点对,j^{\ast} 表示 (u,v) 任一点被删除的最早时间,根据上面的性质我们知道 \delta(S)\geq d(x_{j^{\ast}})/3>d(x_{i^{\ast}})/3

因此 \delta(S)\leq d(x_{i^{\ast}})< 3\delta(S)\Box

求解最近点对

此时我们已经得到了一个 \delta(S) 的近似解 d(x_{i^{\ast}}),此时将平面按照 b=d(x_{i^{\ast}}) 分块,由于块长是 \delta(S) 的常数倍,因此每个块内只能装下 \mathcal{O}(1) 个点。

对于每个点,检查其所在的块以及其 8-邻居 块,用这些块内的点更新答案(其余的块距离该点 >\delta(S)),即求出 \delta(S)

代码实现细节

对于块长为 b 的分块,相当于是对点建立起一个 (x,y)\rightarrow \left(\left\lfloor\frac{x}{b}\right\rfloor,\left\lfloor\frac{y}{b}\right\rfloor\right) 的映射。对 \left(\left\lfloor\frac{x}{b}\right\rfloor,\left\lfloor\frac{y}{b}\right\rfloor\right) 建哈希表,就能将每个点存到对应的块中。查询 8-邻居 和自身块相当于查询块 \left(\left\lfloor\frac{x}{b}\right\rfloor+\Delta x,\left\lfloor\frac{y}{b}\right\rfloor+\Delta y\right),其中 \Delta x,\Delta y\in \{-1,0,1\}

时间复杂度

先分析筛选过程的时间复杂度,检查邻居只需要检查 \mathcal{O}(1) 个点,哈希表的查询复杂度可以认为是期望 \mathcal{O}(1)。因此每轮迭代的时间复杂度是 \mathcal{O}(|S_i|)。特别地,我们有:

\mathbb{E}\left(\sum_{i=1}^{i^{\ast}}|S_i|\right)\leq 2n

证明:给出一个粗略的估计,将 S_i 中的点按照 d(x) 的大小升序排序,其中选取到第 j至多能在 S_{i+1} 中保留 j-1 个点,因此

\mathbb{E}\left(|S_{i+1}|\right)\leq\dfrac{1}{|S_{i}|}\sum_{j=0}^{|S_i|-1}j=\dfrac{|S_{i}|-1}{2}

也就是说每轮迭代 S_i 期望减少一半,得到

\mathbb{E}\left(\sum_{i=1}^{i^{\ast}}|S_i|\right)=\sum_{i=1}^{i^{\ast}}\mathbb{E}\left(|S_i|\right)\leq \sum_{i=1}^{n}\dfrac{n}{2^{i-1}}\leq 2n

即证。\Box

对于第二部分求解 \delta(S) 的过程,由于每个点只会检查严格 \mathcal{O}(1) 个点,因此这部分的时间复杂度也是线性的。

总时间复杂度 \mathcal{O}(n)

如何推广到三维

推广到三维是平凡的,我们只需要照猫画虎地把检查 8-邻居 改成检查 26-邻居 就好了。但是三维中用 b=d(x_{i})/3 会出问题,因为立方体中能塞下的最长距离是 \dfrac{2\sqrt{3}}{3}>1,无法保证 d(x_{i}) 递减。因此我们需要把块长调整为 b=d(x_{i})/4,此时 \dfrac{\sqrt{3}}{2}<1 满足条件。

实现了三维最近点对的代码,写得比较丑。

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define poly vector<p3>
#define int ll
#define i128 __int128
#define i3 ll,ll,ll
typedef long long ll;
typedef long double ld;
using namespace std;
const int N=1000010,lpw=1000003,B=19260817,V=1e11,INF=0x3f3f3f3f3f3f3f3f;
const ld eps=1e-9,inf=1e100;
const int dx[27]={-1,-1,-1,-1,-1,-1,-1,-1,-1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1};
const int dy[27]={-1,-1,-1,0,0,0,1,1,1,-1,-1,-1,0,0,0,1,1,1,-1,-1,-1,0,0,0,1,1,1};
const int dz[27]={-1,0,1,-1,0,1,-1,0,1,-1,0,1,-1,0,1,-1,0,1,-1,0,1,-1,0,1,-1,0,1};
mt19937_64 rnd(chrono::steady_clock::now().time_since_epoch().count());
struct p3{
    ll x,y,z;//no double hiahia
    p3(ll xx=0,ll yy=0,ll zz=0){x=xx,y=yy,z=zz;}
    bool operator ==(const p3 &a)const{return x==a.x&&y==a.y&&z==a.z;}
    bool operator !=(const p3 &a)const{return x!=a.x||y!=a.y||z!=a.z;}
    p3 &operator +=(const p3 &b){x+=b.x,y+=b.y,z+=b.z;return *this;}
    p3 &operator -=(const p3 &b){x-=b.x,y-=b.y,z-=b.z;return *this;}
    friend istream &operator >>(istream &is,p3 &a){return is>>a.x>>a.y>>a.z;}
    friend ostream &operator <<(ostream &os,p3 &a){return os<<'('<<a.x<<','<<a.y<<','<<a.z<<')';}
}O(0,0,0);
p3 operator +(const p3 &a,const p3 &b){return p3(a.x+b.x,a.y+b.y,a.z+b.z);}
p3 operator -(const p3 &a,const p3 &b){return p3(a.x-b.x,a.y-b.y,a.z-b.z);}
p3 operator *(const ld &x,const p3 &b){return p3((ld)x*b.x+eps,(ld)x*b.y+eps,(ld)x*b.z+eps);}
p3 operator /(const p3 &a,const ld &x){return p3((ld)a.x/x+eps,(ld)a.y/x+eps,(ld)a.z/x+eps);}
ld operator *(const p3 &a,const p3 &b){return (ld)a.x*b.x+(ld)a.y*b.y+(ld)a.z*b.z;}
ld dis(p3 a,p3 b){return sqrtl((a-b)*(a-b));}
int head[N],edge[N],Next[N],tot;
int T,n;
p3 ver[N];
poly p,q,s;
ld ans;
int hsh(p3 y){
    y.x=(y.x%lpw+lpw)%lpw;
    y.y=(y.y%lpw+lpw)%lpw;
    y.z=(y.z%lpw+lpw)%lpw;
    int x=(y.x*B%lpw*B%lpw+y.y*B%lpw+y.z)%lpw;
    return (x+lpw)%lpw;
}
void clr(ld d){
    tot=0;
    for(p3 z:p)head[hsh(4.0*z/d)]=0;
}
void add(p3 y,ld d){
    int x=hsh(4.0*y/d);
    ver[++tot]=y,edge[tot]=1,Next[tot]=head[x],head[x]=tot;
}
void upd(p3 y,ld d){
    int flg=0,x=hsh(4.0*y/d);
    for(int i=head[x];i;i=Next[i]){
        if(ver[i]!=y)continue;
        edge[i]++;
        flg=1;
        break;
    }
    if(!flg)add(y,d);
}
ld solve(){
    for(;;){
        q.clear();
        int j=rnd()%p.size();
        ld d=inf;
        for(int i=0;i<(int)p.size();i++)
            if(i!=j)d=min(d,dis(p[i],p[j]));
        if(fabs(d)<eps)return 0;
        for(int i=0;i<(int)p.size();i++)upd(p[i],d);
        for(int i=0;i<(int)p.size();i++){
            for(int k=0,flg=1;k<27;k++){
                p3 o=(4.0*p[i]/d)+p3(dx[k],dy[k],dz[k]);
                int x=hsh(o),val=0;
                for(int j=head[x];j;j=Next[j]){
                    if((4.0*ver[j]/d)!=o)continue;
                    val+=edge[j];
                    if(val>1)break;
                }
                flg&=(k==13)?val==1:val==0;
                if(!flg){
                    q.pb(p[i]);
                    break;
                }
            }
        }
        clr(d);
        if(q.size()<=1)return d;
        p.swap(q);
    }
    return 0;
}
void _solve(){
    cin>>n;
    p.resize(n);
    s.resize(n);
    ans=inf;
    for(int i=0;i<n;i++){
        cin>>p[i];
        p[i]+=p3(V,V,V);
        s[i]=p[i];
    }
    ld d=solve();
    assert(!tot);
    if(fabs(d)<eps){
        cout<<"0.00\n";
        return;
    }
    for(int i=0;i<n;i++)
        upd(s[i],4.0*d);
    for(p3 z:s){
        p3 o=(1.0*z/d);
        int x=hsh(o);
        for(int i=head[x];i;i=Next[i]){
            if(ver[i]==z){
                if(edge[i]>1){
                    p=s,clr(4.0*d);
                    cout<<"0.00\n";
                    return;
                }
                continue;
            }
            ans=min(ans,dis(z,ver[i]));
        }
        for(int k=0;k<27;k++){
            if(k==13)continue;
            int x=hsh(o+p3(dx[k],dy[k],dz[k]));
            for(int i=head[x];i;i=Next[i])
                ans=min(ans,dis(z,ver[i]));
        }
    }
    p=s,clr(4.0*d);
    cout<<fixed<<setprecision(2)<<ans<<'\n';
}
signed main(){
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    cin>>T;
    while(T--)_solve();
    return 0;
}