题解:AT_abc395_g [ABC395G] Minimum Steiner Tree 2

· · 题解

前言

之前没怎么正经学习这个算法,这次考到了就无语了。所以经过我的痛定思痛啊,参考了很多文章,打算写一篇详细的题解来讲解最小斯坦纳树以及本题相关做法,如果你已经掌握了模板题的做法请略过前面的介绍部分。

什么是最小斯坦纳树

最小斯坦纳树问题与最小生成树问题其实非常相似,把问题形式化如下:

给你一个无向连通图 G=(V,E),以及一个包含 k 个节点的点集 S,包含点集 S 的联通子图就是斯坦纳树。但一般在算法竞赛中我们都只在意包含点集 S 的最小联通子图,也就是最小斯坦纳树

先预告一下,这个问题是 NP 困难问题,所以没有正经的多项式解法,但是存在多项式的求近似解做法。

求解问题

这里先给出一张无向图,加粗的点为点集 S=\{2,4,5,7\},要求求出点集 S 的最小斯坦纳树。

其中最小斯坦纳树的值为 11,整棵树长这样:

你会发现这棵树不仅包含了点集 S,还包含了图上另外一个点 6。所以我们发现,最小斯坦纳树上不一定只有 S 中的点。像点 6 这样添加后能使答案最小化的点我们称之为斯坦纳点

接下来我们先介绍两个简单的性质:

  1. 最小斯坦纳树只包含点集 S 与斯坦纳点,这个从定义就能说明,所以不多赘述。

  2. 答案子图一定是一颗树

【证明】 如果答案子图 G' 包含至少一个环,那么把环上边权最大的那一条边删除也能使联通,所以 G' 中不包括环,证毕。这也是它的名字由来。

所以实际上我们考虑自下而上确定这颗树,具体的,我们设 f_{s} 表示联通状态包括 s 的最小代价。

你可能想当然的以为:“转移很显然啊,枚举 S 中的点集与 s 中的点计算最短路转移即可。”

这样的转移是对的吗?显然是错的。我们没有考虑斯坦点带来的收益,导致没有枚举到正确状态使答案变大,这里还是图例解释。

刚刚这一张图如果按照刚刚的 dp 设计那么答案应该是 11。因为它会钦定 1-4,4-7,1-5 这三条路径的答案。而最小斯坦纳树只需要考虑把点 6 加入集合即可,取到最小边权和为 4+1+2+3=10

所以我们需要改进刚刚的状态设计,不妨直接记一个点,表示树根,设 f_{i,s} 表示以 i 为根节点且包含 s 的树的最小边权和,其中如果 i 不属于 S 那么它就是一个斯坦纳点。

考虑怎么转移,因为我们不知道哪个点作为斯坦点能使答案变优,所以我们只能枚举一个 j,考虑当前状态是从 ji 转移而来。

这样的转移并不完备,因为它构造的树没有枚举**树的形态**,也就是说这样转移只会类似一条链,因为它把 $j$ 换根移动到了 $i$,而没有考虑儿子更复杂的情况。 为了充分的枚举到所有状态集合。我们考虑枚举 $s$ 的子集即可。为什么这样是对的?可以这样理解:实际上 $s$ 的子集 $T$ 与 $s-T$ 可以看作是两个以 $i$ 为根的子图,只不过我们把它们的边取了一个并集而已,也就枚举到了所有集合。 $f_{i,s}=\min(f_{i,s},f_{i,T}+f_{j,s-T})$。 如果你真的看懂了斯坦纳树是怎么构造的,那么一定挥发两种转移是有**先后顺序**的,因为我们在换根的时候一定要保证当前状态最优,所以我们需要先枚举子集。也就是说枚举子集之后才去换根转移。并且实际上以最小斯坦纳树上的哪个点为根都是最优解。 ## 实现以及细节 这里我们需要枚举点集 $S$ 中的每个点 $x$,初始化以 $x$ 为根的子树以及对应的联通状态为 $0$。 然后正如刚才所说,枚举状态集合,先转移子集再转移换根。 时间复杂度 $O(n \times 3^k + m \log m \times 2^k)$。$k$ 是点集大小。 以刚才讲解为例,这里提供 [P6192 【模板】最小斯坦纳树](https://www.luogu.com.cn/problem/P6192) 的代码。 ```cpp #include<bits/stdc++.h> using namespace std; const int N=1e2+5; int u,v,w,n,m,k,a[N],f[N][(1<<11)+5],dis[N]; struct Point{ int v,val; }; struct cmp{ bool operator()(Point x,Point y){ return x.val>y.val; } }; priority_queue <Point,vector<Point>,cmp> q; vector <Point> e[N]; bool vis[N]; void dij(int state){ for(int i = 0;i < n;i++) if(f[i][state]!=0x3f3f3f3f) q.push((Point){i,f[i][state]}); memset(vis,false,sizeof(vis)); while(!q.empty()){ int head=q.top().v; q.pop(); if(vis[head])continue; vis[head]=1; for(int i = 0;i < e[head].size();i++){ int v=e[head][i].v,val=e[head][i].val; if(f[v][state]>f[head][state]+val){ f[v][state]=f[head][state]+val; q.push((Point){v,f[v][state]}); } } } } int main(){ ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); cin >> n >> m >> k; for(int i = 1;i <= m;i++){ cin >> u >> v >> w;u--,v--; e[u].push_back((Point){v,w}); e[v].push_back((Point){u,w}); } memset(f,0x3f,sizeof(f)); for(int i = 0;i < k;i++){ cin >> a[i];a[i]--; f[a[i]][1<<i]=0; } for(int i = 1;i < (1<<k);i++){ for(int j = i&(i-1);j;j=(j-1)&i){ if(j<(i^j))break; for(int now = 0;now < n;now++)f[now][i]=min(f[now][i],f[now][j]+f[now][i-j]); } dij(i); } int ans=2e9; for(int i = 0;i < n;i++)ans=min(ans,f[i][(1<<k)-1]); cout << ans; return 0; } ``` 这是我去年七月学的时候照着题解敲的,实际上有很多地方可以优化,比如可以预处理点对之间的最短路,方便换根时的转移。 ## 本题做法 现在默认你至少会了模板题,现在这个问题就很简单了。 这个题实际上就是固定点集 $S$ 然后每次询问等价于往里面多丢两个数。 因为每次重新做复杂度爆炸,所以我们不妨预处理答案,直接枚举然后加入一个数 $i$,暴力把 $i$ 计入状态。 然后跑 $n$ 次即可,时间复杂度 $O(3^k \times n^2+2^k \times n^2)$。这里预处理了最短路。 ```cpp #include<bits/stdc++.h> using namespace std; #define int long long const int N=85,M=9,inf=1e18; int n,m,K,a[N][N],x,y,f[N][N][1<<M]; signed main(){ ios::sync_with_stdio(0); cin.tie(0),cout.tie(0); cin >> n >> K; for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)cin >> a[i][j]; for(int k = 1;k <= n;k++) for(int i = 1;i <= n;i++) for(int j = 1;j <= n;j++) a[i][j]=min(a[i][j],a[i][k]+a[k][j]); memset(f,0x3f,sizeof(f)); for(int i = 1;i <= n;i++){ for(int j = 1;j <= K;j++)f[i][j][1<<j-1]=0; f[i][i][1<<K]=0; } for(int state = 1;state < (1<<K+1);state++){ for(int i = 1;i <= n;i++){ for(int j = 1;j <= n;j++){ for(int k = state&(state-1);k;k=(k-1)&state){ if(k<(state^k))break; f[i][j][state]=min(f[i][j][state],f[i][j][k]+f[i][j][k^state]); } } for(int j = 1;j <= n;j++)for(int k = 1;k <= n;k++)f[i][j][state]=min(f[i][j][state],f[i][k][state]+a[k][j]); } }cin >> m;// while(m--){ cin >> x >> y; cout << f[x][y][(1<<K+1)-1] << "\n"; } return 0; } ```