P5633

· · 题解

Solution

稍微讲一下 wqs 二分 . 如果仍看不懂的话可以看一下这一篇 , 蛮不错的 : Link

f(x) 表示点 s 的出度为 x 的时候的答案 .

考虑从 x-1 转移过来 . 我们应当找 s 的一条边 , 给它填上去 . 这时候肯定有一个环 . 我们删掉环上面的原先最大值 . 新的答案减去了那个最大值 , 加上了当前的值 .

然后证明这个东西有凹凸性 . 考虑当前新加入一条边 , 它新产生的贡献肯定比之前变大 .

如果变小的话 , 那我们可以先加入这条边 , 那么 f(x-1) 就会更小 , 与定义不符 . 这样还是不大严谨 , 需要考虑连通性的问题 . 可以不深究 . 如果非要深究的 , 那么 x-1 次加边后 , 再 x 加边 , x 加的边所在的环上可能有上一次加的边 , 但是把这两次加边调换一下 , 后面那次加边删去的最大值肯定还能删 , 甚至可以删更大的 , 所以此时 f(x-1) 会更优 .

wqs 二分是二分斜率 , 也就是 f(x)-f(x-1) . 把 (x,f(x)) 这个东西放到坐标系上 , 每次找到谁的斜率是 c . 画几个图可以发现 , 如果每个点画一个斜率是 c 的直线 , 那么它们截距的最小的那个就是所求 .

考虑截距的实际含义 : b=f(x)-cx , 相当于每条出边的大小加上 -c . 那我们把每条出边加上 -c , 是不是就可以跑 MST , 求出所求点 .

然后根据所求点和 k 的大小关系 , 可以调整斜率 , 直到找到了 k . 最后我们就可以算出来 f(k) 加上某些数之后的值了 , 然后把多余的减掉 .

判断 Impossible : 这非常恶心 . 一下情况不可以 :

3 2 1 1
1 2 1
1 3 1

我们的程序很难判断 , 所以从右边逼近它 (具体说 , 就是找出度大于等于 k的最小值) , 然后更具定义 , 如果能达到 k , 那么将找到的斜率减小一点 , 就一定对应一个小于 k 的数 . 这时候判断就可以啦 .

并且注意 , 为了降低复杂度 , 我们归并排序 .

code :

#include<bits/stdc++.h>
#define ffor(i,a,b) for(int i=(a);i<=(b);i++)
#define roff(i,a,b) for(int i=(a);i>=(b);i--)
#define int long long
using namespace std;
const int MAXN=3e5+10;
int n,m,s,k,cnt,res,ccnt,fa[MAXN];
int find(int k) {
    return (fa[k]==k)?k:(fa[k]=find(fa[k]));    
}
vector<pair<int,pair<int,int>>> a,b;
vector<pair<int,pair<int,int>>> merge(vector<pair<int,pair<int,int>>> a,vector<pair<int,pair<int,int>>> b) {
    vector<pair<int,pair<int,int>>> res;
    int l=0,r=0;
    while(l<a.size()&&r<b.size()) {
        if(a[l]<b[r]) res.push_back(a[l]),l++;
        else res.push_back(b[r]),r++;
    }
    ffor(i,l,(int)(a.size())-1) res.push_back(a[i]);
    ffor(i,r,(int)(b.size())-1) res.push_back(b[i]);
    return res;
}
int get_ans(int k) {
    vector<pair<int,pair<int,int>>> c;
    ffor(i,0,(int)a.size()-1) c.push_back({a[i].first+k,a[i].second});
    vector<pair<int,pair<int,int>>> tmp=merge(c,b);
    ffor(i,1,n) fa[i]=i; cnt=0,res=0,ccnt=0;
    for(auto V:tmp) {
        int w=V.first,u=V.second.first,v=V.second.second;
        if(find(u)==find(v)) continue;
        if(u==s||v==s) cnt++;
        u=find(u),v=find(v);
        res+=w,ccnt++,fa[u]=v;  
    }
    if(ccnt<n-1) return -1;
    return cnt;
}
signed main() {
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    cin>>n>>m>>s>>k;
    ffor(i,1,m) {
        int u,v,w;
        cin>>u>>v>>w;
        if(u==s||v==s) a.push_back({w,{u,v}});
        else b.push_back({w,{u,v}});    
    }
    if(a.size()<k) {cout<<"Impossible";return 0;}
    sort(a.begin(),a.end()),sort(b.begin(),b.end());
    int l=-INT_MAX,r=INT_MAX,ans=-LONG_LONG_MAX;
    if(get_ans(l)<0) {cout<<"Impossible";return 0;}
    while(l<=r) {
        int mid=(l+r>>1),v=get_ans(-mid);
        if(v>=k) ans=mid,r=mid-1;
        else l=mid+1;
    }
    if(ans==-LONG_LONG_MAX) {cout<<"Impossible";return 0;}
    ans--;get_ans(-ans);
    if(cnt>k) {cout<<"Impossible";return 0;}
    ans++;
    get_ans(-ans);
    cout<<res+ans*k;
    return 0;   
}