[P4180][BJWC2010]【模板】严格次小生成树

· · 题解

题目

        点这里看题目。

分析

        很容易想到,次小生成树和最小生成树之间在没有边权相同的边情况下有且仅有一条边不同。如果存在多条边不同的话就肯定不是次小的了。
        所以我们只需要考虑改的边到底是哪一条就可以了。先求出最小生成树,然后考虑加入一条新边(u,v,w),则原树上(u,lca(u,v),v)的路径会变成一个环。显然,我们必须要在这个环上删除一条边重新变成一棵树。如果这个环上的最长边删去之后生成树的权值和没变,我们就要删除次长边;否则就删去最长边。假如我们可以快速地求出生成树中一条路径(u,lca(u,v),v)的最长边和次长边,我们扫一遍就可以知道答案了。
        考虑快速求出最长边和次长边。我们模拟倍增,求出从某个点u上跳2^j步经过的边权的最大值和次大值,预处理时间同样是O(n\log_2n);而查询的时候类似于LCA的上跳,中间更新出最大值和次大值就可以了,单次时间O(\log_2n)

代码

#include <cmath>
#include <cstdio>
#include <utility>
#include <algorithm>
using namespace std; 

typedef long long LL;

const int INF = 0x3f3f3f3f;
const int MAXN = 1e5 + 5, MAXM = 3e5 + 5, MAXLOG = 20;

template<typename _T>
void read( _T &x )
{
    x = 0;char s = getchar();int f = 1;
    while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
    while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
    x *= f;
}

template<typename _T>
void write( _T x )
{
    if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
    if( 9 < x ){ write( x / 10 ); }
    putchar( x % 10 + '0' );
}

template<typename _T>
_T MAX( const _T a, const _T b )
{
    return a > b ? a : b;
}

template<typename _T>
_T MIN( const _T a, const _T b )
{
    return a < b ? a : b;
}

typedef pair<int, int> RT;

struct edge
{
    int to, nxt, w;
}Graph[MAXN * 2];

struct Edge
{
    int from, to, w;
    bool operator < ( const Edge & other ) const { return w < other.w; }
}E[MAXM];

RT maxW[MAXN][MAXLOG];
int f[MAXN][MAXLOG];
int fa[MAXN], dep[MAXN];
int head[MAXN];
int N, M, cnt, lg2;
bool used[MAXM];

void addEdge( const int from, const int to, const int W )
{
    cnt ++;
    Graph[cnt].w = W;
    Graph[cnt].to = to;
    Graph[cnt].nxt = head[from];
    head[from] = cnt;
}

void makeSet( const int siz ) { for( int i = 1 ; i <= siz ; i ++ ) fa[i] = i; }
int findSet( const int u ) { return fa[u] ^ u ? ( fa[u] = findSet( fa[u] ) ) : fa[u] ; }
bool unionSet( const int u, const int v )
{
    int r1 = findSet( u ), r2 = findSet( v );
    if( r1 == r2 ) return false;
    fa[r1] = r2;
    return true;
}

void DFS( const int u, const int fat )
{
    f[u][0] = fat;
    dep[u] = dep[fat] + 1;
    int v;
    for( int i = head[u] ; i ; i = Graph[i].nxt )
    {
        v = Graph[i].to;
        if( v ^ fat )
            maxW[v][0] = RT( Graph[i].w, -INF ), 
            DFS( v, u );
    }
}

void maintain( RT &a, const int val )
{
    if( a.first < val ) a.second = a.first, a.first = val;
    else if( a.second < val ) a.second = val;
}

void init()
{
    for( int j = 1 ; j <= lg2 ; j ++ )
        for( int i = 1 ; i <= N ; i ++ )
        {
            f[i][j] = f[f[i][j - 1]][j - 1];
            maxW[i][j] = RT( -INF, -INF );
            maintain( maxW[i][j], maxW[i][j - 1].first ), maintain( maxW[i][j], maxW[i][j - 1].second ),
            maintain( maxW[i][j], maxW[f[i][j - 1]][j - 1].first ), maintain( maxW[i][j], maxW[f[i][j - 1]][j - 1].second );
        }
}

void balance( int &u, RT &a, const int steps )
{
    for( int i = 0 ; ( 1 << i ) <= steps ; i ++ )
        if( steps & ( 1 << i ) )
            maintain( a, maxW[u][i].first ), 
            maintain( a, maxW[u][i].second ), 
            u = f[u][i];
}

RT query( int u, int v )
{
    RT r = RT( -INF, -INF );
    if( dep[u] > dep[v] ) balance( u, r, dep[u] - dep[v] );
    if( dep[v] > dep[u] ) balance( v, r, dep[v] - dep[u] );
    if( u == v ) return r;
    for( int i = lg2 ; ~ i ; i -- )
        if( f[u][i] ^ f[v][i] )
            maintain( r, maxW[u][i].first ), maintain( r, maxW[u][i].second ), 
            maintain( r, maxW[v][i].first ), maintain( r, maxW[v][i].second ),
            u = f[u][i], v = f[v][i];
    maintain( r, maxW[u][0].first ), maintain( r, maxW[u][0].second );
    maintain( r, maxW[v][0].first ), maintain( r, maxW[v][0].second );
    return r;
}

int main()
{
    read( N ), read( M );
    lg2 = log2( N );
    for( int i = 1 ; i <= M ; i ++ )
        read( E[i].from ), read( E[i].to ), read( E[i].w );
    std :: sort( E + 1, E + 1 + M );
    makeSet( N );
    LL MST = 0;
    int tot = 0;
    for( int i = 1 ; i <= M ; i ++ )
    {
        if( unionSet( E[i].from, E[i].to ) )
            MST += E[i].w, tot ++, used[i] = true,
            addEdge( E[i].from, E[i].to, E[i].w ), addEdge( E[i].to, E[i].from, E[i].w );
        if( tot == N - 1 ) break;
    }
    DFS( 1, 0 );
    init();
    RT tmp;
    LL res = 0x3f3f3f3f3f3f3f3f;
    for( int i = 1 ; i <= M ; i ++ )
        if( ! used[i] )
        {
            tmp = query( E[i].from, E[i].to );
            if( MST - tmp.first + E[i].w > MST ) res = MIN( res, MST - tmp.first + E[i].w );
            if( MST - tmp.second + E[i].w > MST ) res = MIN( res, MST - tmp.second + E[i].w );
        }
    write( res ), putchar( '\n' );
    return 0;
}