题解:P6431 [COCI2008-2009#1] KRTICA

· · 题解

思路:

题意:给定一棵树,断开一条边,新加一条边使得新树的直径最小。

来个数据结构暴力求解的题解。

直径的性质:令 F(x) 为点集 x 的直径的两个端点,对于点集 S 与点集 TF(S\cup T)\subseteq F(S)\cup F(T)

说直白点,就是给你两棵树,一棵树的直径端点是 (x1,y1),另一棵是 (x2,y2),加边合并这两棵新树,新直径的端点是四个点中的两个。

证明:若新树的直径不为两个原直径,那么新直径过连接边,两端点是距离连接点最远的两个端点,即两原树的直径端点之一。

所以求 F(x) 是非常好做的,只需要写一个合并函数:

inline Dia merge(Dia x,Dia y)
{
    int lx=x.l,rx=x.r,ly=y.l,ry=y.r;
    int dis[]={Dis(lx,rx),Dis(lx,ly),Dis(lx,ry),Dis(rx,ly),Dis(rx,ry),Dis(ly,ry)};
    int z=max_element(dis,dis+6)-dis;
    if(z==0) return {lx,rx};
    if(z==1) return {lx,ly};
    if(z==2) return {lx,ry};
    if(z==3) return {rx,ly};
    if(z==4) return {rx,ry};
    return {ly,ry};
}

注意到合并的复杂度是和你 LCA 算法有关的,所以建议写 O(1) LCA。

然后就可以上线段树啥的来维护,这里不推荐倍增数组,虽然可以 O(1) 询问,但是空间不够。

本人实现的是线段树和 O(1) LCA,可以 O(\log n) 来查询点集的直径两端点。

然后,我们再看怎样连边直径最小。断边后,原树被分成了两棵树,我们设第一棵树的直径为 l,第二棵树的直径为 r

考虑连接两棵树,假设连接的边的在第一棵树上的端点到第一棵树上其他点的最大距离为 x,在第二棵树上的端点到第二棵树上其他点的最大距离为 y,那么新树直径为 \max(l,r,x+y+1),而 x,y 的上界分别为 \lceil{l\over 2}\rceil,\lceil{r\over 2}\rceil,也就是说连接两直径的中点是最优的,那么答案就是所有断边情况的 \max(l,r,\lceil{l\over 2}\rceil+\lceil{r\over 2}\rceil+1) 的最小值。

剩下的,就是把树拍到区间上去来维护了。

时间复杂度为 O(n\log n),常数极大,且较卡空间。

code:

#include<bits/stdc++.h>
#define all(x) x.begin(),x.end()
#define mset(x,y) memset((x),(y),sizeof((x)))
#define mcpy(x,y) memcpy((x),(y),sizeof((y)))
#define FileIn(x) freopen(""#x".in","r",stdin)
#define FileOut(x) freopen(""#x".out","w",stdout)
#define debug(x) cerr<<""#x" = "<<(x)<<'\n'
#define Assert(x) if(!(x)) cerr<<"Failed: "#x" at line "<<__LINE__,exit(1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 Int;
const int N=3e5+10;
bool StM;
int n,dep[N],dfn[N],From[N];
int siz[N],id[N],k=0;
int low[N<<1],t=0,rt=1;
int Lt[N];
pair<int,int>fa[N<<1][19];
vector<int>G[N];
int LCA(int l,int r)
{
    if(l>r) swap(l,r);
    int x=log2(r-l+1);
    return min(fa[l][x],fa[r-(1<<x)+1][x]).second;
}
struct Dia{int l,r;}g[N<<1];
inline int Dis(int x,int y)
{
    return dep[x]+dep[y]-2*dep[LCA(Lt[x],Lt[y])];
}
inline Dia merge(Dia x,Dia y)
{
    int lx=x.l,rx=x.r,ly=y.l,ry=y.r;
    int dis[]={Dis(lx,rx),Dis(lx,ly),Dis(lx,ry),Dis(rx,ly),Dis(rx,ry),Dis(ly,ry)};
    int z=max_element(dis,dis+6)-dis;
    if(z==0) return {lx,rx};
    if(z==1) return {lx,ly};
    if(z==2) return {lx,ry};
    if(z==3) return {rx,ly};
    if(z==4) return {rx,ry};
    return {ly,ry};
}
#define lc(x) (mid<<1)
#define rc(x) (mid<<1|1)
void build(int l,int r,int x)
{
    if(l==r)
        return void(g[x]={id[l],id[l]});
    int mid=(l+r)>>1;
    build(l,mid,lc(x)),build(mid+1,r,rc(x));
    g[x]=merge(g[lc(x)],g[rc(x)]);
}
Dia query(int p,int q,int l=1,int r=n,int x=1)
{
    if(p<=l&&q>=r)
        return g[x];
    int mid=(l+r)>>1;
    if(q<=mid) return query(p,q,l,mid,lc(x));
    if(p>mid) return query(p,q,mid+1,r,rc(x));
    return merge(query(p,q,l,mid,lc(x)),query(p,q,mid+1,r,rc(x)));
}
void dfs(int now,int from)
{
    id[dfn[now]=++k]=now;
    low[++t]=now;siz[now]=1;
    if(!Lt[now]) Lt[now]=t;
    for(int to:G[now])
    {
        if(to==from) continue;
        dep[to]=dep[now]+1;From[to]=now;
        dfs(to,now);
        low[++t]=now;siz[now]+=siz[to];
    }
    return;
}
int GetMid(int x,int y)
{
    int dis=Dis(x,y),z=LCA(x,y);
    if(dep[y]-dep[z]>dep[x]-dep[z]) swap(x,y);
    while(Dis(x,y)>dis/2) x=From[x];
    return x;
}
void Main()
{
    cin>>n;
    for(int i=1,x,y;i<n;i++)
    {
        cin>>x>>y;
        G[x].push_back(y);
        G[y].push_back(x);
    }
    dfs(rt,0);
    for(int i=1;i<=t;i++) 
        fa[i][0]={dep[low[i]],low[i]};
    for(int i=1;i<=18;i++)
        for(int j=1;j+(1<<i)-1<=t;j++)
            fa[j][i]=min(fa[j][i-1],fa[j+(1<<i-1)][i-1]);
    build(1,n,1);
    Dia x=query(1,n),L={0,0},R={0,0},Mid={0,0};
    int Ans=Dis(x.l,x.r),l,r;
    for(int i=2;i<=n;i++)
    {
        Dia l={1,1},r=query(dfn[i],dfn[i]+siz[i]-1);//第二棵树对应的就是子树的区间
        if(1<dfn[i]) l=merge(l,query(1,dfn[i]-1));
        if(dfn[i]+siz[i]-1<n) l=merge(l,query(dfn[i]+siz[i],n));
        //第一棵树对应的是除这棵子树外的最多两个区间
        int x=Dis(l.l,l.r),y=Dis(r.l,r.r);
        auto dia=max({(x+1)/2+(y+1)/2+1,x,y});
        if(dia<Ans) Ans=dia,L=l,R=r,Mid={i,From[i]};
    }
    cout<<Ans<<'\n';
    cout<<Mid.l<<' '<<Mid.r<<'\n';
    cout<<GetMid(L.l,L.r)<<' '<<GetMid(R.l,R.r)<<'\n';
}
bool EdM;
int main()
{
    cerr<<fabs(&StM-&EdM)/1024.0/1024.0<<" MB\n";
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int StT=clock();
    int T=1;
    while(T--) Main();
    int EdT=clock();
    cerr<<1e3*(EdT-StT)/CLOCKS_PER_SEC<<" ms\n";
    return 0;
}