题解:AT_arc106_e [ARC106E] Medals

· · 题解

大家都是使用 Hall 定理最后用 SOSDP 求解的,我来发个模拟网络流的题解。

网络流建模

首先将题意转化为二分图匹配问题,左部点为天数,右部点为人,对于天数 i,向所有第 i 天可以颁奖的人连容量为 1 的边。

然后假定我们确定了天数 k,建出前 k 个左部点,跑一边最大流就能判断是否合法。

但是,显然朴素的进行二分 + Dinic 是过不了的。

优化算法

我们枚举天数 d,判断是否合法,如果不合法就让 d 增加 1,然后在原图上增加一个左部点,设其为 u

考虑模拟 Dinic 寻找增广路的过程,你会发现一条增广路一定是从 u 出发,沿正边走向右部点,沿反边走向左部点,沿正边走向右部点,如此往复,一直走到某个和汇点之间仍有流量的点。这样就找到了一条增广路。

继续观察,我们发现以下性质:

  1. 我们这张二分图右部点的数量相当少,只有 O(n) 级别。
  2. 一条增广路重复经过了同一个点一定是不优的。
  3. 对于一个左部点,它和右部点的连边中至多有 1 条反边存在流量。

考虑建一个只有 n 个点的新图,对于一个左部点 u,如果它和一个右部点 v 存在有流量的反边,那么在新图上连边 v\rightarrow v',其中 v'\neq vv' 为和 u 有连边的右部点。

我们在新图上 dfs,如果走到一个在原图上有流量的点,说明找到了一条增广路。然后在回溯时把应该修改的点都修改连边几遍。

具体实现可以使用链表 + 空间回收,否则会 MLE。

不同算法复杂度的对照

Hall 定理 + 高维前缀和:O(2^nn\log(nk))

朴素的 Dinic:O(n^2k\sqrt{nk}\log(nk))

模拟网络流:O(n^3k)

可以发现,本题解提供的解法在本题的数据范围下表现的并不优秀,最慢的点用了 1.6s。但可以通过诸如 n=50,k=1000 的数据,仍是一种不错的解法。

代码

#include<bits/stdc++.h>
using namespace std;
const int N=4e6+5;
const int M=7e7+5;
int n,k;
int a[40],S[N];
int to[N],rs[20];
int f[N][20];
int head[20][20];
struct E{
    int lst,nxt,val;
}e[M];
int sk[M],top;
void add(int id,int v){
    int u=to[id];
    int s=S[id];
    while(s){
        int i=__builtin_ctz(s);
        s-=(1<<i);
        int x=f[id][i];
        if(x){
            if(x==head[u][i]) head[u][i]=e[x].nxt;
            if(e[x].nxt) e[e[x].nxt].lst=e[x].lst;
            if(e[x].lst) e[e[x].lst].nxt=e[x].nxt;
            sk[++top]=x;
        }
        f[id][i]=0;
        if(i!=v){
            int now=sk[top--];
            f[id][i]=now;
            int &h=head[v][i];
            if(h) e[h].lst=now;
            e[now]={0,h,id};
            h=now;
        }
    }
    to[id]=v;
}
int vis[20];
int dfs(int u){
    if(vis[u]) return -1;
    if(rs[u]){rs[u]--;return u;}
    vis[u]=1;
    for(int v=0;v<n;v++){
        if(vis[v]) continue;
        if(head[u][v]&&~dfs(v)){
            int x=e[head[u][v]].val;
            add(x,v);
            return u;
        }
    }
    return -1;
}
void solve(){
    for(int i=1;i<M-5;i++) sk[++top]=i;
    for(int i=0;i<n;i++) rs[i]=k;
    for(int i=1;;i++){
        int s=S[i];
        memset(vis,0,sizeof(vis));
        to[i]=-1;
        for(int j=0;j<n&&to[i]==-1;j++){
            if((s>>j)&1) to[i]=dfs(j);
        }
        if(to[i]!=-1) add(i,to[i]);
        int f0=0;
        for(int i=0;i<n;i++) f0|=rs[i];
        if(!f0){cout<<i;break;}
    }
}
signed main(){
    cin>>n>>k;
    for(int i=0;i<n;i++) cin>>a[i];
    for(int i=0;i<n;i++){
        for(int d=1;d<=n*k*2;d+=2*a[i]){
            for(int j=d;j<d+a[i];j++) S[j]|=1<<i;
        }
    }
    solve();
    return 0;
}