题解 CF125E 【MST Company】

· · 题解

可能更好的阅读体验

Solution

hack

强烈谴责题目数据太水...

首先大部分的凸优化(或者说 wqs/带权二分)的做法都错了。这样只能正确的得到最小生成树的权值,但是要以此为基础构造方案会错。比如下面两组数据:

5 8 3
1 2 7
1 4 8
1 5 7
1 3 3
2 3 1
3 5 1
3 4 10
4 5 3
7 9 3
1 2 8
1 7 4
1 4 3
1 5 6
3 6 1
3 4 4
4 7 1
4 6 3
5 6 1

上面的最小生成树的权值是 19,下面的是 20。绝大多数程序会输出 20 和 -1...

可以证明 在这种策略下做出的最小生成树即使度数超过 k,也必然可以通过换边来得到度数恰好为 k 的生成树,并且权值一样(凸包上多点共线)。然而直接根据排序顺序并限制度数再去构造方案,很可能得到不连通或错误的结果。

破圈

不妨考虑一个暴力一点的思路(n 的范围较小):

首先去掉 1 号点跑最小生成森林。

现在加入 1 号点。设此时除 1 号点外的联通块个数为 x,若 k < x,必然无解。

否则对于每个联通块加入能联通这个联通块的最小的边,这样就得到了一棵强制 x 度生成树。

现在依然考虑经典的破圈算法:

对于一条不在生成树上的与 1 相连的边,若将其加到最小生成树中,其可以替换所产生的环上的一条最大的边,会使权值变小 maxEdge - w,其中 maxEdge 是最大边边权,w 是这条边的边权。

在所有 maxEdge - w 找到其最大值并替换,在所替换的最大边不与 1 相连的前提下,就能得到一棵强制 x + 1 度生成树。

将这个过程进行 k - x 次,就能得到题目所要求的强制 k 度生成树。

找到不与 1 相连的最大边由于数据范围原因可以暴力算。设 maxE _i 表示在当前生成树上从 1 出发到 i 节点的路径上忽略与 1 相连的边的最大边,每次 DFS 一遍 O(n) 算就行。

O(m \log m + nk)

二分

虽然我本人感觉是对的,并且经过了上万组数据的对拍,但是因为数据湿度,这个做法依然可能会假,请谨慎阅读

现在这个构造法应该只能做这个题(强制 k 度),另一道也需要二分的题(恰好有 k 条白色边的最小生成树)无法用这个方法构造。

之前已经说了,二分最后得到的最小生成树中 1 的度数会大于等于 k

先通过二分把这个生成树建出来,并固定 1 号点为根。

考虑枚举生成树之外的所有边用来替换那些与 1 相连的边,直到 1 号点度数为 k。假设当前边为 (u, v),权值为 w,以及 u 归属于 1 号点的儿子中 p 的子树,v 归属于 1 号点的儿子中 q 的子树。

如果 p = q 自然无可替换。否则,如果 p, q 中任意一个点与 1 相连的边权是 w,那么这条边 (1, p/q) 就可以被 (u, v) 替换。不妨设可替换的边是 (1, p),那么 p 子树内的点需要全部接入 q(也即下一次查询 p 子树内的点,得到结果应当为 q)。这是简单并查集操作。

另外边权必须恰好是 w,才可以替换。实际上因为已经是最小生成树了,不可能替换之后边权更小,证明中也提到了一定是边权相同的边导致了 1 的度数超出了 k,所以必须严格等于。

复杂度的话,二分部分为 O(m \log m \log SIZE),其中 SIZE 为值域大小。如果将与 1 相连的边与其他边每次归并排序或者桶排可以优化变为 O(m \log SIZE)。换边构树部分为 O(m \log n)

Code

gen

from random import randint
from cyaron import *
import os
n = 200
m = randint(n - 1, min(1e5, n * (n - 1) / 2))
k = randint(1, n - 1)
outs = "%d %d %d\n" % (n, m, k)
graph = Graph.UDAG(n, m, weight_limit = 1e2, self_loop=False, repeated_edges=False)
outs += graph.to_str();
print(outs, file = open("tmp.in", "w"))

spj

需搭配 std 和 gen 食用

另外由于 python 玄学原因,每次只能运行 340 次...

#include <cstdlib>
#include <iostream>

using namespace std;
using ptr = FILE*;

const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;

int n, m, _ans, num, k;
int fa[maxN], size[maxN];

struct Edge
{ int u, v, w; }
edge[maxM];

int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }

inline bool Merge(int u, int v)
{
    int fu = Find(u), fv = Find(v);
    if (fu == fv)
        return false;
    if (size[fu] > size[fv])
        swap(fu, fv);
    fa[fu] = fv, size[fv] += size[fu];
    return true;
}

inline void Err(const char* str)
{ fprintf(stderr, "%s\n", str); exit(0); }

int main()
{
    for (int cnt = 1; ; ++cnt)
    {
        system("python3 gen.py && ./tmp && ./std");
        ptr in = fopen("tmp.in", "r"), out = fopen("tmp.out", "r"), ans = fopen("tmp.ans", "r");
        int tmp;
        fscanf(in, "%d %d %d", &n, &m, &k);
        for (int i = 1; i <= n; ++i)
            fa[i] = i, size[i] = 1;
        for (int i = 1; i <= m; ++i)
            fscanf(in, "%d %d %d", &edge[i].u, &edge[i].v, &edge[i].w);
        if (fscanf(out, "%d", &_ans) != 1)
            Err("No ans.");
        if (fscanf(ans, "%d", &num) != 1)
            Err("Invalid output.");
        if (num == -1 and _ans != -1)
            Err("No solution.");
        if (num != -1 and _ans == -1)
            Err("Not.");
        if (num == -1 and _ans == -1)
        {
            printf("AC %d times!\n", cnt);
            continue;
        }
        if (num != n - 1)
            Err("edges' number wrong!");
        int curans = 0, c = 0;
        for (int i = 1; i <= num; ++i)
        {
            if (fscanf(ans, "%d", &tmp) != 1)
                Err("Invalid output.");
            if (!Merge(edge[tmp].u, edge[tmp].v))
                Err("Not tree.");
            curans += edge[tmp].w;
            if (edge[tmp].u == 1 or edge[tmp].v == 1)
                ++c;
        }
        if (curans != _ans)
            Err("Wrong answer!");
        if (c != k)
            Err("1's degree isn't k!");
        printf("AC %d times!\n", cnt);
    }
    return 0;
}

std

破圈

#include<bits/stdc++.h>

#define mp make_pair
#define pb push_back
#define x first
#define y second

using namespace std;

typedef long long LL;
typedef pair<int,int> pii;

template <typename T> inline T read()
{
    T sum=0, fl=1; char ch=getchar();
    for(; !isdigit(ch); ch=getchar()) if(ch=='-') fl=-1;
    for(; isdigit(ch); ch=getchar()) sum=(sum<<3)+(sum<<1)+ch-'0';
    return sum*fl;
}

const int maxn=5000+5;
const int maxm=1e5+5;

int n, m, k, c;
LL mx[maxn]; bool del[maxm<<1], no[maxm];
int fa[maxn], id[maxm<<1]; vector<int>vec, E;
int he[maxn], ne[maxm<<1], to[maxm<<1], w[maxm<<1];
struct EDGE { int u, v, w, id; } e[maxm];

bool operator<(const EDGE&a, const EDGE&b) { return a.w<b.w; }

void add_edge(int u, int v, int val)
{
    ne[++c]=he[u]; he[u]=c; to[c]=v; w[c]=val;
    ne[++c]=he[v]; he[v]=c; to[c]=u; w[c]=val;
}

int find(int x) { return fa[x]==x ? x : fa[x]=find(fa[x]); }
void forest()
{
    for(int i=1; i<=n; ++i) fa[i]=i;
    sort(e+1, e+m+1); 
    for(int i=1; i<=m; ++i)
    {
        int u=find(e[i].u), v=find(e[i].v);
        if(u==v || u==1 || v==1) continue;
        fa[u]=v; E.pb(e[i].id);
        add_edge(e[i].u, e[i].v, e[i].w); id[c]=id[c-1]=e[i].id;
    }
}

void get_max(int p, int f)
{
    for(int i=he[p]; i; i=ne[i])
    {
        if(to[i]==f || del[i] || to[i]==1) continue;
        if(w[i]>=w[mx[p]]) mx[to[i]]=i;
        else mx[to[i]]=mx[p];
        get_max(to[i], p);
    }
}

void Solve()
{
    n=read<int>(), m=read<int>(), k=read<int>(); c=1;
    for(int i=1; i<=m; ++i) e[i]=(EDGE){read<int>(), read<int>(), read<int>(), i};
    forest(); 

    for(int i=1; i<=m; ++i) 
    {
        if(e[i].u!=1 && e[i].v!=1) continue;
        int u=find(e[i].u), v=find(e[i].v);
        if(u == v) vec.pb(i);
        else fa[u]=v, add_edge(e[i].u, e[i].v, e[i].w), E.pb(e[i].id), --k; 
    }

    for(int i=2; i<=n; ++i) if(find(i)!=find(i-1)) return (void) printf("-1");
    if(k<0) return (void) printf("-1");

    for(int i=1; i<=k; ++i)
    {
        if(!vec.size()) return (void) printf("-1"); /**/
        memset(mx, 0, sizeof(mx));
        for(int j=he[1]; j; j=ne[j]) get_max(to[j], 1);
        int now=-1, ans=2147483647;
        for(int j=vec.size()-1; j>=0; --j) 
        {
            int p=vec[j], var=max(e[p].u, e[p].v);
            if(e[p].w-w[mx[var]]<ans) ans=e[p].w-w[mx[var]], now=j;
        }
        int p=vec[now], var=e[p].u==1?e[p].v:e[p].u;
        add_edge(e[p].v, e[p].u, e[p].w); E.pb(e[p].id); 
        no[id[mx[var]]]=1; del[mx[var]]=del[mx[var]^1]=1;
        vec.erase(vec.begin()+now);
    }

    printf("%d\n", n-1);
    for(int i=E.size()-1; i>=0; --i) if(!no[E[i]]) printf("%d ", E[i]);
}

int main()
{

    //  freopen("125e.in","r",stdin);
    //  freopen("125e.out","w",stdout);

    Solve();

    return 0;
}

二分

#include <iostream>
#include <set>
#include <vector>
#include <bitset>
#include <algorithm>

using namespace std;

const int maxN = 5e3 + 5;
const int maxM = 1e5 + 5;

int n, m, k, maxw, ans;
int fa[maxN], size[maxN], val[maxN], id[maxN];
set<int> out;
vector<struct Edge> G[maxN];
bitset<maxM> vis;

struct Edge
{
    int u, v, w, id;

    Edge() { }

    Edge(int _u, int _v, int _w, int _id) : u(_u), v(_v), w(_w), id(_id) { }
} edge[maxM];

bool operator<(const Edge& x, const Edge& y)
{ return x.w == y.w ? x.u < y.u : x.w < y.w; }

int Find(int x)
{ return fa[x] == x ? x : fa[x] = Find(fa[x]); }

inline bool Merge(int u, int v)
{
    int fu = Find(u), fv = Find(v);
    if (fu == fv)
        return false;
    if (size[fu] > size[fv])
        swap(fu, fv);
    fa[fu] = fv, size[fv] += size[fu];
    return true;
}

void Init()
{
    for (int i = 1; i <= n; ++i)
        fa[i] = i, size[i] = 1;
}

bool Check()
{
    Init();
    int deg = 0;
    for (int i = 1; i <= m; ++i)
        if (edge[i].u != 1)
            Merge(edge[i].u, edge[i].v);
        else
            ++deg;
    if (deg < k)
        return false;
    int cnt = 0;
    for (int i = 2; i <= n; ++i)
        if (Find(i) == i)
            ++cnt;
    if (cnt > k)
        return false;
    for (int i = 1; i <= m; ++i)
        if (edge[i].u == 1)
            Merge(edge[i].u, edge[i].v);
    cnt = 0;
    for (int i = 1; i <= n; ++i)
        if (Find(i) == i)
            ++cnt;
    return cnt == 1;
}

void DFS(int u, int father)
{
    fa[u] = fa[father];
    for (const auto& v : G[u]) if (v.u != father)
        DFS(v.u, u);
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("tmp.in", "r", stdin);
    freopen("tmp.ans", "w", stdout);
#endif
    ios::sync_with_stdio(false);
    cin >> n >> m >> k;
    if (n == 1)
    { cout << 0 << endl; return 0; }
    for (int i = 1; i <= m; ++i)
    {
        Edge& o = edge[i];
        cin >> o.u >> o.v >> o.w;
        o.id = i;
        if (o.u > o.v)
            swap(o.u, o.v);
        maxw = max(maxw, o.w);
    }
    if (!Check())
    {
        cout << "-1" << endl;
        return 0;
    }
    int l = -maxw, r = maxw, p, ecnt, cnt;
    while (l <= r)
    {
        p = (l + r) >> 1;
        ecnt = cnt = 0;
        Init();
        for (int i = 1; i <= m; ++i)
            if (edge[i].u == 1)
                edge[i].w += p;
        sort(edge + 1, edge + m + 1);
        for (int i = 1; i <= m; ++i)
            if (Merge(edge[i].u, edge[i].v))
            {
                ++ecnt;
                if (edge[i].u == 1)
                    ++cnt;
                if (ecnt == n - 1)
                    break;
            }
        for (int i = 1; i <= m; ++i)
            if (edge[i].u == 1)
                edge[i].w -= p;
        if (cnt >= k)
            ans = p, l = p + 1;
        else
            r = p - 1;
    }
    Init(), cnt = 0;
    for (int i = 1; i <= m; ++i)
        if (edge[i].u == 1)
            edge[i].w += ans;
    sort(edge + 1, edge + m + 1);
    for (int i = 1; i <= m; ++i)
        if (Merge(edge[i].u, edge[i].v))
        {
            vis.set(i);
            if (edge[i].u == 1)
                ++cnt;
            out.insert(edge[i].id);
            G[edge[i].u].emplace_back(edge[i].v, 0, edge[i].w, edge[i].id);
            G[edge[i].v].emplace_back(edge[i].u, 0, edge[i].w, edge[i].id);
        }
    for (const auto& v : G[1])
    {
        val[v.u] = v.w, id[v.u] = v.id;
        fa[1] = v.u, DFS(v.u, 1);
    }
    for (int i = 1; i <= m and cnt > k; ++i) if (edge[i].u != 1 and !vis[i])
    {
        int fu = Find(edge[i].u), fv = Find(edge[i].v);
        if (fu == fv)
            continue;
        if (val[fu] == edge[i].w)
        {
            fa[fu] = fv, --cnt;
            out.insert(edge[i].id);
            out.erase(id[fu]);
            continue;
        }
        if (val[fv] == edge[i].w)
        {
            fa[fv] = fu, --cnt;
            out.insert(edge[i].id);
            out.erase(id[fv]);
            continue;
        }
    }
    if (cnt == k)
    {
        cout << n - 1 << endl;
        for (int i : out)
            cout << i << ' ';
        cout << endl;
    }
    else
        cout << "-1" << endl;
    return 0;
}