题解 CF125E 【MST Company】
EndSaH
2019-09-25 22:30:28
[可能更好的阅读体验](https://endsah.cf/blog/CF125E-MST-Company/#more)
# Solution
## hack
强烈谴责题目数据太水...
首先大部分的凸优化(或者说 wqs/带权二分)的做法都错了。这样只能正确的得到最小生成树的权值,但是要以此为基础构造方案会错。比如下面两组数据:
```
5 8 3
1 2 7
1 4 8
1 5 7
1 3 3
2 3 1
3 5 1
3 4 10
4 5 3
```
```
7 9 3
1 2 8
1 7 4
1 4 3
1 5 6
3 6 1
3 4 4
4 7 1
4 6 3
5 6 1
```
上面的最小生成树的权值是 19,下面的是 20。绝大多数程序会输出 20 和 -1...
[可以证明](https://www.cnblogs.com/CreeperLKF/p/9045491.html) 在这种策略下做出的最小生成树即使度数超过 $k$,也必然可以通过换边来得到度数恰好为 $k$ 的生成树,并且权值一样(凸包上多点共线)。然而直接根据排序顺序并限制度数再去构造方案,很可能得到不连通或错误的结果。
## 破圈
不妨考虑一个暴力一点的思路($n$ 的范围较小):
首先去掉 1 号点跑最小生成森林。
现在加入 1 号点。设此时除 1 号点外的联通块个数为 $x$,若 $k < x$,必然无解。
否则对于每个联通块加入能联通这个联通块的最小的边,这样就得到了一棵强制 $x$ 度生成树。
现在依然考虑经典的破圈算法:
对于一条不在生成树上的与 1 相连的边,若将其加到最小生成树中,其可以替换所产生的环上的一条最大的边,会使权值变小 $maxEdge - w$,其中 $maxEdge$ 是最大边边权,$w$ 是这条边的边权。
在所有 $maxEdge - w$ 找到其最大值并替换,在所替换的最大边不与 1 相连的前提下,就能得到一棵强制 $x + 1$ 度生成树。
将这个过程进行 $k - x$ 次,就能得到题目所要求的强制 $k$ 度生成树。
找到不与 1 相连的最大边由于数据范围原因可以暴力算。设 $maxE _i$ 表示在当前生成树上从 1 出发到 $i$ 节点的路径上**忽略与 1 相连的边**的最大边,每次 DFS 一遍 $O(n)$ 算就行。
$O(m \log m + nk)$
## 二分
**虽然我本人感觉是对的,并且经过了上万组数据的对拍,但是因为数据湿度,这个做法依然可能会假,请谨慎阅读**。
现在这个构造法应该只能做这个题(强制 $k$ 度),另一道也需要二分的题(恰好有 $k$ 条白色边的最小生成树)无法用这个方法构造。
之前已经说了,二分最后得到的最小生成树中 1 的度数会大于等于 $k$。
先通过二分把这个生成树建出来,并固定 1 号点为根。
考虑枚举生成树之外的所有边用来替换那些与 1 相连的边,直到 1 号点度数为 $k$。假设当前边为 $(u, v)$,权值为 $w$,以及 $u$ 归属于 1 号点的儿子中 $p$ 的子树,$v$ 归属于 1 号点的儿子中 $q$ 的子树。
如果 $p = q$ 自然无可替换。否则,如果 $p, q$ 中任意一个点与 1 相连的边权是 $w$,那么这条边 $(1, p/q)$ 就可以被 $(u, v)$ 替换。不妨设可替换的边是 $(1, p)$,那么 $p$ 子树内的点需要全部接入 $q$(也即下一次查询 $p$ 子树内的点,得到结果应当为 $q$)。这是简单并查集操作。
另外边权**必须**恰好是 $w$,才可以替换。实际上因为已经是最小生成树了,不可能替换之后边权更小,证明中也提到了一定是边权相同的边导致了 1 的度数超出了 $k$,所以必须严格等于。
复杂度的话,二分部分为 $O(m \log m \log SIZE)$,其中 $SIZE$ 为值域大小。如果将与 1 相连的边与其他边每次归并排序或者桶排可以优化变为 $O(m \log SIZE)$。换边构树部分为 $O(m \log n)$。
# Code
## gen
```python
from random import randint
from cyaron import *
import os
n = 200
m = randint(n - 1, min(1e5, n * (n - 1) / 2))
k = randint(1, n - 1)
outs = "%d %d %d\n" % (n, m, k)
graph = Graph.UDAG(n, m, weight_limit = 1e2, self_loop=False, repeated_edges=False)
outs += graph.to_str();
print(outs, file = open("tmp.in", "w"))
```
## spj
需搭配 std 和 gen 食用
另外由于 python 玄学原因,每次只能运行 340 次...
```cpp
#include <cstdlib>
#include <iostream>
using namespace std;
using ptr = FILE*;
const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;
int n, m, _ans, num, k;
int fa[maxN], size[maxN];
struct Edge
{ int u, v, w; }
edge[maxM];
int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }
inline bool Merge(int u, int v)
{
int fu = Find(u), fv = Find(v);
if (fu == fv)
return false;
if (size[fu] > size[fv])
swap(fu, fv);
fa[fu] = fv, size[fv] += size[fu];
return true;
}
inline void Err(const char* str)
{ fprintf(stderr, "%s\n", str); exit(0); }
int main()
{
for (int cnt = 1; ; ++cnt)
{
system("python3 gen.py && ./tmp && ./std");
ptr in = fopen("tmp.in", "r"), out = fopen("tmp.out", "r"), ans = fopen("tmp.ans", "r");
int tmp;
fscanf(in, "%d %d %d", &n, &m, &k);
for (int i = 1; i <= n; ++i)
fa[i] = i, size[i] = 1;
for (int i = 1; i <= m; ++i)
fscanf(in, "%d %d %d", &edge[i].u, &edge[i].v, &edge[i].w);
if (fscanf(out, "%d", &_ans) != 1)
Err("No ans.");
if (fscanf(ans, "%d", &num) != 1)
Err("Invalid output.");
if (num == -1 and _ans != -1)
Err("No solution.");
if (num != -1 and _ans == -1)
Err("Not.");
if (num == -1 and _ans == -1)
{
printf("AC %d times!\n", cnt);
continue;
}
if (num != n - 1)
Err("edges' number wrong!");
int curans = 0, c = 0;
for (int i = 1; i <= num; ++i)
{
if (fscanf(ans, "%d", &tmp) != 1)
Err("Invalid output.");
if (!Merge(edge[tmp].u, edge[tmp].v))
Err("Not tree.");
curans += edge[tmp].w;
if (edge[tmp].u == 1 or edge[tmp].v == 1)
++c;
}
if (curans != _ans)
Err("Wrong answer!");
if (c != k)
Err("1's degree isn't k!");
printf("AC %d times!\n", cnt);
}
return 0;
}
```
## std
### 破圈
```cpp
#include<bits/stdc++.h>
#define mp make_pair
#define pb push_back
#define x first
#define y second
using namespace std;
typedef long long LL;
typedef pair<int,int> pii;
template <typename T> inline T read()
{
T sum=0, fl=1; char ch=getchar();
for(; !isdigit(ch); ch=getchar()) if(ch=='-') fl=-1;
for(; isdigit(ch); ch=getchar()) sum=(sum<<3)+(sum<<1)+ch-'0';
return sum*fl;
}
const int maxn=5000+5;
const int maxm=1e5+5;
int n, m, k, c;
LL mx[maxn]; bool del[maxm<<1], no[maxm];
int fa[maxn], id[maxm<<1]; vector<int>vec, E;
int he[maxn], ne[maxm<<1], to[maxm<<1], w[maxm<<1];
struct EDGE { int u, v, w, id; } e[maxm];
bool operator<(const EDGE&a, const EDGE&b) { return a.w<b.w; }
void add_edge(int u, int v, int val)
{
ne[++c]=he[u]; he[u]=c; to[c]=v; w[c]=val;
ne[++c]=he[v]; he[v]=c; to[c]=u; w[c]=val;
}
int find(int x) { return fa[x]==x ? x : fa[x]=find(fa[x]); }
void forest()
{
for(int i=1; i<=n; ++i) fa[i]=i;
sort(e+1, e+m+1);
for(int i=1; i<=m; ++i)
{
int u=find(e[i].u), v=find(e[i].v);
if(u==v || u==1 || v==1) continue;
fa[u]=v; E.pb(e[i].id);
add_edge(e[i].u, e[i].v, e[i].w); id[c]=id[c-1]=e[i].id;
}
}
void get_max(int p, int f)
{
for(int i=he[p]; i; i=ne[i])
{
if(to[i]==f || del[i] || to[i]==1) continue;
if(w[i]>=w[mx[p]]) mx[to[i]]=i;
else mx[to[i]]=mx[p];
get_max(to[i], p);
}
}
void Solve()
{
n=read<int>(), m=read<int>(), k=read<int>(); c=1;
for(int i=1; i<=m; ++i) e[i]=(EDGE){read<int>(), read<int>(), read<int>(), i};
forest();
for(int i=1; i<=m; ++i)
{
if(e[i].u!=1 && e[i].v!=1) continue;
int u=find(e[i].u), v=find(e[i].v);
if(u == v) vec.pb(i);
else fa[u]=v, add_edge(e[i].u, e[i].v, e[i].w), E.pb(e[i].id), --k;
}
for(int i=2; i<=n; ++i) if(find(i)!=find(i-1)) return (void) printf("-1");
if(k<0) return (void) printf("-1");
for(int i=1; i<=k; ++i)
{
if(!vec.size()) return (void) printf("-1"); /**/
memset(mx, 0, sizeof(mx));
for(int j=he[1]; j; j=ne[j]) get_max(to[j], 1);
int now=-1, ans=2147483647;
for(int j=vec.size()-1; j>=0; --j)
{
int p=vec[j], var=max(e[p].u, e[p].v);
if(e[p].w-w[mx[var]]<ans) ans=e[p].w-w[mx[var]], now=j;
}
int p=vec[now], var=e[p].u==1?e[p].v:e[p].u;
add_edge(e[p].v, e[p].u, e[p].w); E.pb(e[p].id);
no[id[mx[var]]]=1; del[mx[var]]=del[mx[var]^1]=1;
vec.erase(vec.begin()+now);
}
printf("%d\n", n-1);
for(int i=E.size()-1; i>=0; --i) if(!no[E[i]]) printf("%d ", E[i]);
}
int main()
{
// freopen("125e.in","r",stdin);
// freopen("125e.out","w",stdout);
Solve();
return 0;
}
```
### 二分
```cpp
#include <iostream>
#include <set>
#include <vector>
#include <bitset>
#include <algorithm>
using namespace std;
const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;
int n, m, k, maxw, ans;
int fa[maxN], size[maxN], val[maxN], id[maxN];
set<int> out;
vector<struct Edge> G[maxN];
bitset<maxM> vis;
struct Edge
{
int u, v, w, id;
Edge() { }
Edge(int _u, int _v, int _w, int _id) : u(_u), v(_v), w(_w), id(_id) { }
} edge[maxM];
bool operator<(const Edge& x, const Edge& y)
{ return x.w == y.w ? x.u < y.u : x.w < y.w; }
int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }
inline bool Merge(int u, int v)
{
int fu = Find(u), fv = Find(v);
if (fu == fv)
return false;
if (size[fu] > size[fv])
swap(fu, fv);
fa[fu] = fv, size[fv] += size[fu];
return true;
}
void Init()
{
for (int i = 1; i <= n; ++i)
fa[i] = i, size[i] = 1;
}
bool Check()
{
Init();
int deg = 0;
for (int i = 1; i <= m; ++i)
if (edge[i].u != 1)
Merge(edge[i].u, edge[i].v);
else
++deg;
if (deg < k)
return false;
int cnt = 0;
for (int i = 2; i <= n; ++i)
if (Find(i) == i)
++cnt;
if (cnt > k)
return false;
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
Merge(edge[i].u, edge[i].v);
cnt = 0;
for (int i = 1; i <= n; ++i)
if (Find(i) == i)
++cnt;
return cnt == 1;
}
void DFS(int u, int father)
{
fa[u] = fa[father];
for (const auto& v : G[u]) if (v.u != father)
DFS(v.u, u);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("tmp.in", "r", stdin);
freopen("tmp.ans", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin >> n >> m >> k;
if (n == 1)
{ cout << 0 << endl; return 0; }
for (int i = 1; i <= m; ++i)
{
Edge& o = edge[i];
cin >> o.u >> o.v >> o.w;
o.id = i;
if (o.u > o.v)
swap(o.u, o.v);
maxw = max(maxw, o.w);
}
if (!Check())
{
cout << "-1" << endl;
return 0;
}
int l = -maxw, r = maxw, p, ecnt, cnt;
while (l <= r)
{
p = (l + r) >> 1;
ecnt = cnt = 0;
Init();
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w += p;
sort(edge + 1, edge + m + 1);
for (int i = 1; i <= m; ++i)
if (Merge(edge[i].u, edge[i].v))
{
++ecnt;
if (edge[i].u == 1)
++cnt;
if (ecnt == n - 1)
break;
}
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w -= p;
if (cnt >= k)
ans = p, l = p + 1;
else
r = p - 1;
}
Init(), cnt = 0;
for (int i = 1; i <= m; ++i)
if (edge[i].u == 1)
edge[i].w += ans;
sort(edge + 1, edge + m + 1);
for (int i = 1; i <= m; ++i)
if (Merge(edge[i].u, edge[i].v))
{
vis.set(i);
if (edge[i].u == 1)
++cnt;
out.insert(edge[i].id);
G[edge[i].u].emplace_back(edge[i].v, 0, edge[i].w, edge[i].id);
G[edge[i].v].emplace_back(edge[i].u, 0, edge[i].w, edge[i].id);
}
for (const auto& v : G[1])
{
val[v.u] = v.w, id[v.u] = v.id;
fa[1] = v.u, DFS(v.u, 1);
}
for (int i = 1; i <= m and cnt > k; ++i) if (edge[i].u != 1 and !vis[i])
{
int fu = Find(edge[i].u), fv = Find(edge[i].v);
if (fu == fv)
continue;
if (val[fu] == edge[i].w)
{
fa[fu] = fv, --cnt;
out.insert(edge[i].id);
out.erase(id[fu]);
continue;
}
if (val[fv] == edge[i].w)
{
fa[fv] = fu, --cnt;
out.insert(edge[i].id);
out.erase(id[fv]);
continue;
}
}
if (cnt == k)
{
cout << n - 1 << endl;
for (int i : out)
cout << i << ' ';
cout << endl;
}
else
cout << "-1" << endl;
return 0;
}
```