P11307 [COTS 2016] 建造费 Pristojba 题解

· · 题解

无聊的 Boruvka 模板题。

显然考虑 Boruvka,现在仅需考虑一轮内的情况。转化为,每个点有颜色 bl_i 和点权 a_i,我们希望求出每种颜色最小的异色边。

边是区间的形式,考虑线段树,容易发现我们的要求仅是不与某个颜色相同,那只要维护最小和次小,并保证颜色不同就一定能取到最优了。边显然有两种形式,一个是作为给出三元组起点连到区间内,一个是被上面这种连。只需要写一个线段树支持区间合并查询,再写一个线段树支持区间修改单点查询即可。封装一下可以合并到一起。

每轮复杂度 O(n\log n),而每次至少使连通块数减半,复杂度为 O(n\log^2 n)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ll long long
#define inf 1000000000
using namespace std;

int n,m;
int a[100005];
int bl[100005];
int fa[100005];
int vl[100005];
int to[100005];
vector <pair <int,int>> g[100005];
struct Edge{int u,v,w;}e[400005];
struct node{int mn1,mn2,co1,co2;};
node t[400005];
node tg[400005];
node bt[100005];

inline void in(int &n){
    n=0;
    char c=getchar();
    while(c<'0' || c>'9') c=getchar();
    while(c>='0'&&c<='9') n=n*10+c-'0',c=getchar();
    return ;
}

node operator + (node p,node q){
    pair <int,int> a[5];
    a[1]={p.mn1,p.co1};
    a[2]={q.mn1,q.co1};
    a[3]={p.mn2,p.co2};
    a[4]={q.mn2,q.co2};
    sort(a+1,a+1+4);
    node tmp;
    tmp.mn1=a[1].first,tmp.co1=a[1].second;
    if(a[2].second!=a[1].second){
        tmp.mn2=a[2].first,tmp.co2=a[2].second;
        return tmp;
    }
    if(a[3].second!=a[1].second){
        tmp.mn2=a[3].first,tmp.co2=a[3].second;
        return tmp;
    }
    if(a[4].second!=a[1].second){
        tmp.mn2=a[4].first,tmp.co2=a[4].second;
        return tmp;
    }
    tmp.mn2=inf,tmp.co2=0;
    return tmp;
}

inline void build(int u,int l,int r){
    tg[u]={inf,inf,0,0};
    if(l==r){t[u]={a[l],inf,bl[l],0};return ;}
    int mid=(l+r)>>1;
    build(u<<1,l,mid);
    build(u<<1|1,mid+1,r);
    t[u]=t[u<<1]+t[u<<1|1];
    return ;
}

inline node qry(int u,int l,int r,int L,int R){
    if(L<=l&&r<=R) return t[u];
    int mid=(l+r)>>1;
    if(L>mid) return qry(u<<1|1,mid+1,r,L,R);
    if(R<=mid) return qry(u<<1,l,mid,L,R);
    return qry(u<<1,l,mid,L,R)+qry(u<<1|1,mid+1,r,L,R);
}

inline void upd(int u,int l,int r,int L,int R,node x){
    if(L<=l&&r<=R){tg[u]=tg[u]+x;return ;}
    tg[u<<1]=tg[u<<1]+tg[u];
    tg[u<<1|1]=tg[u<<1|1]+tg[u];
    int mid=(l+r)>>1;
    if(L<=mid) upd(u<<1,l,mid,L,R,x);
    if(R>mid) upd(u<<1|1,mid+1,r,L,R,x);
    return ;
}

inline node qry(int u,int l,int r,int k){
    if(l==r) return tg[u];
    tg[u<<1]=tg[u<<1]+tg[u];
    tg[u<<1|1]=tg[u<<1|1]+tg[u];
    int mid=(l+r)>>1;
    if(k<=mid) return qry(u<<1,l,mid,k);
    else return qry(u<<1|1,mid+1,r,k);
}

inline int Find(int x){return fa[x]==x?x:fa[x]=Find(fa[x]);}

int main(){
    in(n),in(m);
    for(int i=1;i<=n;i++) in(a[i]),fa[i]=bl[i]=i;
    while(m--){
        int u,l,r;
        in(u),in(l),in(r);
        g[u].push_back({l,r});
    }
    ll s=0;
    while(1){
        bool ok=1;
        for(int i=1;i<=n;i++) ok&=(Find(i)==Find(1));
        if(ok) break;
        for(int i=1;i<=n;i++) bl[i]=Find(i),bt[i]={inf,inf,0,0},vl[i]=1e9;
        build(1,1,n);
        int m=0;
        for(int i=1;i<=n;i++){
            node tmp={inf,inf,0,0};
            for(auto tt:g[i]){
                tmp=tmp+qry(1,1,n,tt.first,tt.second);
                upd(1,1,n,tt.first,tt.second,{a[i],inf,bl[i],0});
            }
            if(tmp.co1!=bl[i]&&tmp.co1){
                if(vl[bl[i]]>tmp.mn1+a[i]) vl[bl[i]]=tmp.mn1+a[i],to[bl[i]]=tmp.co1;
            }
            if(tmp.co2!=bl[i]&&tmp.co2){
                if(vl[bl[i]]>tmp.mn2+a[i]) vl[bl[i]]=tmp.mn2+a[i],to[bl[i]]=tmp.co2;
            }
        }
        for(int i=1;i<=n;i++){
            auto tmp=qry(1,1,n,i);
            tmp.mn1+=a[i],tmp.mn2+=a[i];
            bt[bl[i]]=bt[bl[i]]+tmp;
        }
        for(int i=1;i<=n;i++){
            if(Find(i)==i){
                if(bt[i].co1!=i&&bt[i].co1){
                    if(vl[i]>bt[i].mn1) vl[i]=bt[i].mn1,to[i]=bt[i].co1;
                }
                if(bt[i].co2!=i&&bt[i].co2){
                    if(vl[i]>bt[i].mn2) vl[i]=bt[i].mn2,to[i]=bt[i].co2;
                }
                e[++m]={i,to[i],vl[i]};
            }
        }
        sort(e+1,e+1+m,[](Edge p,Edge q){return p.w<q.w;});
        for(int i=1;i<=m;i++){
            int u=Find(e[i].u),v=Find(e[i].v);
            if(u!=v) fa[u]=v,s+=e[i].w;
        }
    }
    printf("%lld\n",s);

    return 0;
}