[SNOI2017]炸弹:线段树优化建图

· · 题解

update:现在这是一个时间复杂度正确的做法
update2:后来发现代码有锅,现已修复。

蒟蒻刚看到有线段树优化建图这种东西,于是想试着自己写一发,调了好久终于过了。。
收获不少,写篇题解报复造福社会qwq

这题要我们算每个炸弹能引爆的炸弹个数。
我们很容易想到从每个炸弹向它爆炸范围内的炸弹连边,然后从每个点搜索,能抵达的节点数量即是答案。

这样做是很不优的,时空复杂度都达到了 \text{O}(n^2) 的级别,那如何优化呢?

我们来考虑这样一种情况:
如图,有一个点要向另外 7 个点连边

注意到每个点连向的点一定都是在一段区间里的,所以我们可以这样:

最下面一排是原来的点,这样连边,和原图的效果是一样的。
你说这样反而复杂了?

原来要从那个点连出去 7 条边,现在只需要连出去 3 条,也就是 \lceil\log_2 7\rceil 条边了。
推而广之,用这种方法给 n 个点向区间连边,就只需要 \text O(n \log n) 条边!

如此,我们就把建边的时空复杂度降下来了。

关于具体的代码实现,这里再说一下:
首先我们要建树,这里和普通的线段树没什么两样:

void build(int l,int r,int u){
    if(l==r){
        id[l] = u //id[x]表示x节点在线段树上的标号
        return;
        //左右端点相等,则表示递归到底层
    }
    int mid = (l+r)>>1;
    build(l,mid,u<<1); //建左半边
    build(mid+1,r,u<<1|1); //建右半边
    add_edge(u,u<<1);
    add_edge(u,u<<1|1); //u向自己左右儿子连边
}

然后就是要实现点向区间连边,代码如下:

void link(int v,int nl,int nr,int l,int r,int u){
    //v向区间[nl,nr]连边  
    if(nl<=l&&r<=nr){
        if(v==u) return; //判自环
        add_edge(v,u);
        return;
    }
    int mid = (l+r)>>1;
    if(nl<=mid) link(v,nl,nr,l,mid,u<<1);
    if(nr>mid) link(v,nl,nr,mid+1,r,u<<1|1);
}

建图完毕,接下来要统计答案。
为了方便后面做,需要先跑个 \texttt{tarjan} 缩点,注意缩点后每个点的点权为 对应的强连通分量 中节点个数。

缩点之后,原图就变成 \text{DAG} 了,统计每个点能抵达的节点 权值 和,但是这个东西在一般的 \text{DAG} 上是个世纪难题。

不过现在它到了区间上,可以在 \texttt{tarjan} 的过程中记录每个新点(即强连通分量 )能到达的左右端点,最后对每个点直接取就是答案了。

不过要注意跑完缩点还要再 \text{dfs} 一遍,用 u 能到的所有节点 v 来更新它能到的左右端点。

代码如下:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ll long long
#define N 500003
#define M 2000003
#define reg register
#define p 1000000007
#define mid ((l+r)>>1)
#define ls (u<<1)
#define rs (u<<1|1)
using namespace std;

struct node{ 
    int l,r;
    inline node(int l=0,int r=0):l(l),r(r){}
}a[M];

vector<int> adj[M],G[M];
ll x[N],r[N];
int clr[M],low[M],dfn[M],stk[M];
int id[M],left[M],right[M];
bool vis[M];
int n,cnt,scc,tp,nd,ans;

void tarjan(int u){
    dfn[u] = low[u] = ++cnt;
    stk[++tp] = u;
    vis[u] = true;
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(!dfn[v]){
            tarjan(v);
            low[u] = min(low[u],low[v]);
        }else if(vis[v])
            low[u] = min(low[u],dfn[v]);
    }
    if(low[u]!=dfn[u]) return;
    ++scc;
    do{
        v = stk[tp];
        clr[v] = scc; 
        vis[v] = false;
        left[scc] = min(left[scc],a[v].l);  //更新这个强连通分量中能到的左右端点
        right[scc] = max(right[scc],a[v].r);
        --tp;
    }while(stk[tp+1]!=u);
}

void connect(int v,int nl,int nr,int l,int r,int u){ //线段树上连边
    if(nl<=l&&r<=nr){
        if(v==u) return;
        adj[v].push_back(u);
        return;
    }
    if(nl<=mid) connect(v,nl,nr,l,mid,ls);
    if(nr>mid) connect(v,nl,nr,mid+1,r,rs);
}

void build(int l,int r,int u){
    a[u] = node(l,r);
    nd = max(nd,u);
    if(l==r){
        id[l] = u; 
        return;
    }
    build(l,mid,ls);
    build(mid+1,r,rs);
    adj[u].push_back(ls);
    adj[u].push_back(rs);
}

inline int query(int x){ 
    int u = clr[id[x]];
    return right[u]-left[u]+1;
}

void dfs(int u){ //最后一遍 dfs 更新
    vis[u] = true;
    int v,l = G[u].size();
    for(int i=0;i!=l;++i){
        v = G[u][i];
        if(vis[v]){
            left[u] = min(left[u],left[v]);
            right[u] = max(right[u],right[v]);
            continue;
        }
        dfs(v);
        left[u] = min(left[u],left[v]);
        right[u] = max(right[u],right[v]);
    }
}

int main(){
    int L,R,ln,u;
    scanf("%d",&n);
    for(reg int i=1;i<=n;++i)
        scanf("%lld%lld",&x[i],&r[i]);
    memset(left,0x3f,sizeof(left));    
    x[n+1] = 0x3f3f3f3f3f3f3f3fll;
    build(1,n,1); 
    for(reg int i=1;i<=n;++i){
        if(!r[i]) continue;
        L = lower_bound(x+1,x+1+n,x[i]-r[i])-x;
        R = upper_bound(x+1,x+1+n,x[i]+r[i])-x-1;
        connect(id[i],L,R,1,n,1);
        a[id[i]] = node(L,R);
    }    
    tarjan(1);  
    for(reg int i=1;i<=nd;++i){
        ln = adj[i].size();
        for(reg int j=0;j!=ln;++j){
            u = adj[i][j];
            if(clr[u]==clr[i]) continue;
            G[clr[i]].push_back(clr[u]);
        }
    }
    for(reg int i=1;i<=scc;++i){ //把重复连的边去掉
        sort(G[i].begin(),G[i].end());
        unique(G[i].begin(),G[i].end());
    }
    memset(vis,0,sizeof(vis));
    for(reg int i=1;i<=scc;++i)
        if(!vis[i]) dfs(i);
    for(reg int i=1;i<=n;++i)
        ans = (ans+(ll)query(i)*i)%p;
    printf("%d",ans);
    return 0;
}