题解:CF2081D MST in Modulo Graph

· · 题解

这场 div1 D<B<C 啊。

很显然不能 O(n^2) 直接连边,这种题一般得考虑优化连边过程。

我们先把 p 去重排序,然后我们注意到 p 的值域才 5\times10^5,那么我们不妨这样连边:枚举 i,找到每一个 k,在 p 中找到位于 [p_i\times k,p_i\times (k+1)) 中最小的数 y,设其下标为 z,然后在 i\to z 连边权为 y-p_i\times k 的一条边。

为什么这是正确的?我们假设有 p_i[p_i\times k,p_i\times (k+1)) 中最小的数为 b,还有一个 p_c>p_bc\in [p_i\times k,p_i\times (k+1)),我们注意到,如果连 i\to bb\to c 总是比连 i\to bi\to c 好,因为两部分边权分别为 (p_b-p_i)+(p_c-p_b)(p_b-p_i)+(p_c-p_i)。得证。

连完边后跑一遍 Kruskal 即可,这样只有 O(V\ln V) 条边,时间复杂度 O(V\ln V\log V+n\log n),瓶颈在于去重排序和连边 Kruskal。

code:

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define fir first
#define sec second.first
#define trd second.second
#define pb push_back
#define piii pair<int,pair<int,int>>

const int N=5e5+5;
int n,p[N],lst[N],nxt[N],pos[N];
int fa[N],cxx[N],V;
vector<piii> v;
bool cmp(piii x,piii y){
    return x.trd<y.trd;
}
int find(int x){
    if(x==fa[x]) return x;
    else return fa[x]=find(fa[x]);
}

signed main(){
    ios::sync_with_stdio(0),cin.tie(0);
    int T;cin>>T;while(T--){
        cin>>n;V=n;
        for(int i=1;i<=n;i++){cin>>p[i];cxx[p[i]]=1;V=max(V,p[i]);}
        sort(p+1,p+n+1);n=unique(p+1,p+n+1)-p-1;
        for(int i=1;i<=n;i++) fa[i]=i,pos[p[i]]=i;
        for(int i=1;i<=V;i++){
            if(cxx[i]) lst[i]=i;
            else lst[i]=lst[i-1];
        }
        for(int i=V;i>=1;i--){
            if(cxx[i]) nxt[i]=i;
            else nxt[i]=nxt[i+1];
        }
        v.clear();
        for(int i=1;i<=n;i++){
            for(int k=1;k*p[i]<=V;k++){
                if(cxx[k*p[i]]&&k!=1){
                    v.pb({pos[k*p[i]],{i,0}});
                    continue;
                }
                int to=0;
                if(nxt[k*p[i]+1]<=(k+1)*p[i]) to=nxt[k*p[i]+1];
                if(!to) continue;
                v.pb({pos[to],{i,to-k*p[i]}});
            }
        }
        sort(v.begin(),v.end(),cmp);
        int ans=0;
        for(auto i:v){
            if(find(i.fir)==find(i.sec)) continue;
            fa[find(i.fir)]=find(i.sec);
            ans+=i.trd;
        }
        cout<<ans<<endl;
        for(int i=1;i<=V;i++) cxx[i]=0,pos[i]=0,lst[i]=nxt[i]=0;
    }
    return 0;
}