P12692 BZOJ3784 树上的路径 题解

· · 题解

目前题解区都是 O(n\log^2 n) 做法,我来一个 O(n\log n) 做法吧。

Description

题目

给定一个 n 个结点的树,结点用正整数 1 \sim n 编号。每条边有一个正整数权值 c。用 d(a, b) 表示从结点 a 到结点 b 路边上经过边的权值。其中要求 a < b。将这 \frac{n \times (n-1)}{2} 个距离从大到小排序,输出前 m 个距离值。

n \leq 5 \times 10^4$,$m \leq \min(3 \times 10^5, \frac{n \times (n-1)}{2})$,$a, b \leq n$,$c \leq 10^4

Solution

先考虑前 k 优的一般性方法。

类似于P2048 [NOI2010] 超级钢琴,我们尝试对于每一个点用堆维护到其它点的最远距离。

考虑直径引理,像P2056 [ZJOI2007] 捉迷藏那样用线段树维护编号连续的一段区间的直径两个端点。

于是,对于点 i 到区间 [l,r] 的最远点对就可以 O(\log n) 求出来。

具体地,初始化,对于每个点 i 将距离区间 i+1\sim n 的最远点加入堆(强制连接到右边的点是为了避免重复点对)。然后不断取出堆顶元素并且分裂区间。

#include<bits/stdc++.h>
#define N 50005
#define pb push_back
using namespace std;
#define getc (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<15,stdin),cs==ct)?0:*cs++)
char cb[1<<15],*cs,*ct;
inline void read(auto &num){
    char ch;while(!isdigit(ch=getc));
    for(num=ch-'0';isdigit(ch=getc);num=num*10+ch-'0');
}
int n,m,st[16][N],rk[N],lg[N],dn,dep[N];
vector<pair<int,int>>g[N];
inline void dfs(int u,int fa){
    st[0][rk[u]=++dn]=fa;
    for(auto [v,w]:g[u])if(v!=fa)dep[v]=dep[u]+w,dfs(v,u);
}
inline int get(int x,int y){return rk[x]<rk[y]?x:y;}
inline int lca(int u,int v){
    if(u==v)return u;
    if((u=rk[u])>(v=rk[v]))swap(u,v);
    int g=lg[v-u++];
    return get(st[g][u],st[g][v-(1<<g)+1]);
}
inline int dis(int u,int v){return u==-1||v==-1?-1:dep[u]+dep[v]-2*dep[lca(u,v)];}
struct node{int u,v;}t[N<<2];
#define ls (p<<1)
#define rs (p<<1|1)
#define u1 x.u
#define u2 y.u
#define v1 x.v
#define v2 y.v
#define D dis(r.u,r.v)
#define mid ((l+r)>>1)
inline node merge(node x,node y){
    if(u1==-1&&v1==-1)return y;
    if(u2==-1&&v2==-1)return x;
    auto r=(dis(u1,v1)>dis(u2,v2))?x:y;
    if(dis(u1,u2)>D)r={u1,u2};
    if(dis(u1,v2)>D)r={u1,v2};
    if(dis(v1,u2)>D)r={v1,u2};
    if(dis(v1,v2)>D)r={v1,v2};
    return r;
}
inline void push_up(int p){
    t[p]=merge(t[ls],t[rs]);
}
inline void addsum(int p,int l,int r,int x){
    if(l==r)return (t[p].u==-1&&t[p].v==-1)?t[p].u=t[p].v=x:t[p].u=t[p].v=-1,void();
    if(x<=mid)addsum(ls,l,mid,x);
    if(x>mid)addsum(rs,mid+1,r,x);
    push_up(p);
}
inline void build(int p,int l,int r){
    if(l==r)return t[p].u=t[p].v=l,void();
    build(ls,l,mid),build(rs,mid+1,r);
    push_up(p);
}
inline node getsum(int p,int l,int r,int L,int R){
    if(L<=l&&r<=R)return t[p];
    if(R<=mid)return getsum(ls,l,mid,L,R);
    if(L>mid)return getsum(rs,mid+1,r,L,R);
    return merge(getsum(ls,l,mid,L,R),getsum(rs,mid+1,r,L,R));
}
struct info{
    int w,i,x,l,r;
};
struct ifcp{
    inline bool operator()(const info& a, const info& b){return a.w < b.w;}
};
priority_queue<info,vector<info>,ifcp>q;
inline int query(int x,int l,int r){
    auto res=getsum(1,1,n,l,r);
    return dis(x,res.u)>dis(x,res.v)?res.u:res.v;
}
inline void ps(int i,int l,int r){
    int res=query(i,l,r);
    q.push({dis(i,res),i,res,l,r});
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    read(n),read(m);
    lg[0]=-1;
    for(int i=1;i<=n;i++)lg[i]=lg[i>>1]+1;
    for(int u,v,w,i=1;i<n;i++)read(u),read(v),read(w),g[u].pb({v,w}),g[v].pb({u,w});
    dfs(1,0);
    for(int i=1;i<=lg[n];i++)for(int j=1;j+(1<<i)-1<=n;j++)st[i][j]=get(st[i-1][j],st[i-1][j+(1<<(i-1))]);
    build(1,1,n);
    for(int i=1;i<n;i++)ps(i,i+1,n);
    while(m--){
        auto [w,i,x,l,r]=q.top();q.pop();
        cout<<w<<'\n';
        if(x!=l)ps(i,l,x-1);
        if(x!=r)ps(i,x+1,r);
    }
}