浅谈 KM 算法

· · 题解

KM 算法,全名 Kuhn-Munkres 算法,可以在 O(n^3)​ 时间内求出二分图的最大权完美匹配。

该算法的核心思想是给每个点一个顶标 l_i,使得 \forall(u,v),l_u+l_v\ge w_{u,v},匹配时只考虑满足 l_u+l_v=w_{u,v} 的边 u,v,这样可以使得匹配时仅考虑这些边,不用考虑边权跑出来的完美匹配一定是最大权完美匹配。证明考虑任意一个完美匹配权值为 \sum w(u,v)\le\sum l,而我们按照上述方法跑出来的匹配权值为 \sum l

考虑如何实现。我们先初始化一组顶标,使得其满足上述条件,一种可行的初始化方案是令右部点的顶标 l_v=0,左部点的顶标 l_u=\max_v w_{u,v}。这之后,我们只考虑所有点和满足 l_u+l_v=w_{u,v}​ 的边,我们称这样的一个子图为“相等子图”,考虑在其中找匹配。

仿照找最大匹配的方法,我们考虑每次新加一个左部点进入匹配。加入一个左部点 u 时,我们从 u 开始跑一遍匈牙利算法找匹配,若找到匹配则继续考虑下一个左部点,若未找到匹配则说明当前相等子图已经无法找到完美匹配,此时我们要设法扩大它,于是我们考虑调整顶标。

注意到此时从 u 开始通过走未匹配、匹配、未匹配...的交错路我们可以到达部分左部点,将这些路径提出来可以得到一棵树,我们称之为交错树。定义在交错树上的左/右部点分别为点集 S,T,不在交错树上的点类似定义为 S',T'​。

下图是一个例子,其中红色点是 u​​​,红色边是匹配中的边,所有边都是相等子图中的边。在这个例子中,交错树是一条链(左 4 到左 3)。

我们发现:一定没有 S-T' 的边,不然交错树会增长或者甚至找到匹配;S'-T​​ 的边一定是非匹配边,否则该左部点将可以被 u 走到,从而进入 S​。

考虑一种调整的操作:设定一个常数 a​​,给 S​​ 中的顶标 -a​​,T​​ 中的顶标 +a​​​,则该操作会造成以下效果:

然而此时我们选的这个 a 需要使得我们做完这次操作后仍然对于所有边有 l_u+l_v\ge w_{u,v}。第一种影响不用考虑,第二种影响的 l_u+l_v​​​​ 只增不减,于是也不用考虑,只用考虑第三种影响。为了使调整后的顶标仍然满足条件,我们取 a=\min\{l_u+l_v-w_{u,v}\mid u\in S,v\in T'\} 即可。

调整后会有新右部点加入,我们考虑这个右部点:

可以发现,在最多 n 次操作后,我们一定可以找到一个未匹配点,于是此时匹配完成,继续考虑下一个左部点即可。

实现时直接维护交错树即可,因为交错树中的边是只增不删的。

除此之外,为了保证复杂度,我们还需要对每个 T'​​ 中的节点维护一个松弛量 slack_v=\min \{l_u+l_v-w_{u,v}\mid u\in S\}​​,每次求 a​​ 的值时暴力枚举每一个 T'​​​ 中的点,找到最小的 slack_v,令 a 等于其值,并把 v​ 和其匹配点(如果有)加入交错树即可,而无需再将能从之直接到达的点加入:因为若 v 的匹配点 x 还能到达其他右部点 y,则说明有 l_x+l_y-w_{x,y}=0,根据定义有 slack_y=0​,于是在接下来的操作中 y 会先被加入。

时间复杂度:需要执行 O(n) 次加入左部点操作,每次操作需要跑一遍 O(n^2) 的匈牙利算法,每次需要将 O(n) 个左部点加入交错树中,每加入一个左部点需要 O(n) 用当前加入的点对每个右部点计算 slack;每次操作要调整 O(n) 次顶标,每调整一次顶标要修改 O(n) 个点的顶标和 slack,于是时间复杂度是严格 O(n^3)

对于更一般的情况的处理:

code(洛谷模板题):

#include <bits/stdc++.h>
using namespace std;
int n , m , x , y , z , fp[1100] , par[1100] , fa[1100] , mn[1100] , ok = 0;
int fst[1100] , nex[550000] , v[550000] , tot , mxpar[1100] , f[550][550];
long long l[1100] , slack[1100] , ans , val[550000];
vector< int > id , lp , rp;
void add( int a , int b , long long c )
{
    nex[++tot] = fst[a]; fst[a] = tot;
    v[tot] = b; val[tot] = c;
}
int match( int u )
{
    id.push_back(u);
    for(int i = fst[u] ; i ; i = nex[i] )
    {
        if(l[u] + l[v[i]] != val[i] || fa[v[i]]) continue;
        if(!fp[v[i]]) 
        {
            par[u] = v[i]; par[v[i]] = u; fp[u] = fp[v[i]] = 1;
            ans += l[v[i]];
            return 1;
        }
        else
        {
            if(fa[v[i]] = u , fa[par[v[i]]] = v[i] , match(par[v[i]])) 
            {
                par[v[i]] = u; par[u] = v[i]; fp[u] = fp[v[i]] = 1;
                return 1;
            }
        }
    }
    return 0;
}
long long KM( vector< int > &lp , vector< int > &rp )
{
    ans = 0;
    int t = n + n;
    while(lp.size() < rp.size()) 
    {
        lp.push_back(++t);
        for(int i : rp) add(t , i , 0);
    }
    while(lp.size() > rp.size()) 
    {
        rp.push_back(++t);
        for(int i : lp) add(i , t , 0);
    }
    for(int i : lp) fp[i] = l[i] = par[i] = 0;
    for(int i : rp) fp[i] = l[i] = par[i] = 0;
    for(int i : lp)
        for(int e = fst[i] ; e ; e = nex[e] ) l[i] = max(l[i] , (long long)val[e]);
    for(int i : lp)
    {
        for(int u : lp) fa[u] = 0 , slack[u] = 1e16;
        for(int u : rp) fa[u] = 0 , slack[u] = 1e16;
        slack[0] = 1e16;
        fa[i] = -1; 
        if(match(i)) 
        { 
            ans += l[i] , id.clear(); 
            continue; 
        }
        while(1)
        {
            for(int j : id)
                for(int e = fst[j] ; e ; e = nex[e] )
                    if(slack[v[e]] > l[j] + l[v[e]] - val[e])
                        slack[v[e]] = l[j] + l[v[e]] - val[e] , mn[v[e]] = j;
            id.clear();
            int u = 0;
            for(int j : rp)
                if(slack[j] < slack[u] && !fa[j]) u = j;
            int w = slack[u];
            for(int j : lp)
                if(fa[j]) l[j] -= w , ans -= w;
            ans += w;
            for(int j : rp)
            {
                if(fa[j]) l[j] += w , ans += w;
                else slack[j] -= w;
            }
            if(!fp[u])
            {
                ans += l[u] + l[i]; 
                int las = u; u = mn[u]; 
                while(u != -1)
                {
                    if(las)
                    {
                        par[las] = u; par[u] = las; fp[u] = fp[las] = 1;
                        las = 0;
                    }
                    else las = u;
                    u = fa[u];
                }
                break;
            }
            fa[u] = mn[u]; int v = par[u]; fa[v] = u;
            id.push_back(v);
        }
    }
    return ans;
}
int main()
{
//  freopen("1.in" , "r" , stdin);
//  freopen("1.out" , "w" , stdout);
    scanf("%d%d" , &n , &m);
    for(int i = 1 ; i <= m ; i++ )
    {
        scanf("%d%d%d" , &x , &y , &z); f[x][y] = 1; z += 19980731;
        add(x , y + n , z); 
    }
    for(int i = 1 ; i <= n ; i++ )
        for(int j = n + 1 ; j <= n + n ; j++ )
            if(!f[i][j - n]) add(i , j , -1e12);
    for(int i = 1 ; i <= n ; i++ ) lp.push_back(i);
    for(int i = n + 1 ; i <= n + n ; i++ ) rp.push_back(i);
    printf("%lld\n" , KM(lp , rp) - 19980731ll * n);
    for(int i = n + 1 ; i <= n + n ; i++ ) 
    {
        if(par[i] <= n && f[par[i]][i - n]) printf("%d " , par[i]);
        else printf("0 ");
    }
    return 0;
}
/*
*/