AT_wtf22_day1_d Welcome to Tokyo!

· · 题解

传送门

原问题为在长度为 n 的序列上选最多 k 个点设为 1,使得给定的 m 个区间中区间和非零的区间数量最大。

也就是求一组 a_i,b_i\in\{0,1\},其中 a_i 表示第 i 个点是否被选,b_i 表示第 i 个区间是否非 0。要使 \sum\limits_{i=1}^mb_i 最大。

最多选 k 个点也就是 \sum\limits_{i=1}^na_i\le k,如果 b_i1,则 \sum\limits_{j=l_i}^{r_i}a_j 至少为 1,也就是 b_i\le \sum\limits_{j=l_i}^{r_i}a_j

总结一下就是:求出一组非负整数 a_i,b_i 满足:

\begin{cases}a_i\le 1\\b_i\le 1\\\sum\limits_{i=1}^na_i\le k\\b_1-\sum\limits_{j=l_1}^{r_1}a_j\le 0\\b_2-\sum\limits_{j=l_2}^{r_2}a_j\le 0\\\dots\\b_m-\sum\limits_{j=l_m}^{r_m}a_j\le 0\end{cases}

a_i,b_i 为实数的时候就是个线性规划问题。

于是猜测这是一个全幺模矩阵,但是并不满足每一列至多两个元素。

考虑哪个不等式是没有必要的。

可以发现当 a_i>1,答案一定不优,于是可以忽略第一个不等式。

于是这就是一个全幺模矩阵,把 a_i,b_i 松弛到实数域答案不变。

考虑转化一下原问题。

先给每个不等式两边乘个系数:

\begin{cases}B_ib_i\le B_i\\C\sum\limits_{i=1}^na_i\le Ck\\D_1b_1-D_1\sum\limits_{j=l_1}^{r_1}a_j\le 0\\D_2b_2-D_2\sum\limits_{j=l_2}^{r_2}a_j\le 0\\\dots\\D_mb_m-D_m\sum\limits_{j=l_m}^{r_m}a_j\le 0\end{cases}

g_{i,j}\in\{0,1\} 表示第 i 个区间是否包含 j

则:

\begin{cases}B_ib_i\le B_i\\C\sum\limits_{i=1}^na_i\le Ck\\D_1b_1-D_1\sum\limits_{j=1}^{n}g_{1,j}a_j\le 0\\D_2b_2-D_2\sum\limits_{j=1}^{n}g_{2,j}a_j\le 0\\\dots\\D_mb_m-D_m\sum\limits_{j=1}^{n}g_{m,j}a_j\le 0\end{cases}

把所有不等式相加:

\sum\limits_{i=1}^na_i(C-\sum\limits_{j=1}^mD_jg_{j,i})+\sum\limits_{i=1}^mb_i(B_i+ D_i)\le \sum\limits_{i=1}^mB_i+Ck

C-\sum\limits _{j=1}^mD_jg_{j,i}\ge 0B_i+D_i\ge 1,则有:

\sum\limits_{i=1}^mb_i\le \sum\limits_{i=1}^na_i(C-\sum\limits_{j=1}^mD_jg_{j,i})+\sum\limits_{i=1}^mb_i(B_i+ D_i)\le \sum\limits_{i=1}^mB_i+Ck

\sum\limits_{i=1}^mb_i\le \sum\limits_{i=1}^mB_i+Ck

于是 \max(\sum\limits_{i=1}^mb_i)=\min(\sum\limits_{i=1}^mB_i+Ck)

于是我们得到了原问题的对偶问题:

求出数组 B,D 和常数 C,满足:

_{j=1}^mD_jg_{j,i} \\B_i+D_i\ge 1\end{cases}

\sum\limits_{i=1}^mB_i+Ck 的最小值。

由于 B_i+D_i\ge 1,所以要么 B_i=1,要么 D_i=1

可以发现 C 的值就是所有满足 D_i=1 的区间覆盖的点中,被覆盖次数的最大值。

也就是要选出来一个区间的集合 S,使得 m-|S|+kx 最小,其中 x 是选出的集合覆盖的所有点被覆盖次数的最大值。

考虑对于每个 x,求出 |S| 的最大值,然后再套个斜率优化即可。

S_i 表示 x=i 时的集合,则 S_i 包含 S_{i-1}

考虑增量构造,每次取 r 最小的和当前已有的区间组合起来没有一个点被覆盖超过 x 次的区间。

可以开一个线段树维护每个点被覆盖的次数,再开一个线段树维护区间,在线段树上二分,总复杂度 O(n\log^2 n)

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5;
int n,m,rs[N],q[N],hd,tl;
struct ndP{
    int l,r;
    friend bool operator<(ndP a,ndP b){
        return a.r<b.r;
    }
}w[N];
namespace tr{
    struct node{
        int l,r,laz,sm;
    }p[N<<2];
    void upset(int x){
        p[x].sm=max(p[x<<1].sm,p[x<<1|1].sm);
    }
    void add(int x,int sm){
        p[x].laz+=sm;
        p[x].sm+=sm;
    }
    void reset(int x,int l,int r){
        p[x].l=l,p[x].r=r;
        if(l==r)return;
        int mid=l+r>>1;
        reset(x<<1,l,mid);
        reset(x<<1|1,mid+1,r);
    }
    void dnset(int x){
        if(p[x].laz){
            add(x<<1,p[x].laz);
            add(x<<1|1,p[x].laz);
            p[x].laz=0;
        }
    }
    void add(int x,int l,int r,int sm){
        if(l<=p[x].l&&r>=p[x].r){
            add(x,sm);
            return;
        }
        int mid=p[x].l+p[x].r>>1;
        dnset(x);
        if(l<=mid)add(x<<1,l,r,sm);
        if(r>mid)add(x<<1|1,l,r,sm);
        upset(x);
    }
    int gets(int x,int l,int r){
        if(l>r)return 0;
        if(l==0)return 1e9;
        if(l<=p[x].l&&r>=p[x].r)return p[x].sm;
        dnset(x);
        int mid=p[x].l+p[x].r>>1;
        if(r<=mid)return gets(x<<1,l,r);
        if(l>mid)return gets(x<<1|1,l,r);
        return max(gets(x<<1,l,r),gets(x<<1|1,l,r)); 
    }
}
struct node{
    int l,r,sm;
}p[N<<2];
void upset(int x){
    p[x].sm=max(p[x<<1].sm,p[x<<1|1].sm);
}
void reset(int x,int l,int r){
    p[x].l=l,p[x].r=r;
    if(l==r){
        p[x].sm=w[l].l;
        return;
    }
    int mid=l+r>>1;
    reset(x<<1,l,mid);
    reset(x<<1|1,mid+1,r);
    upset(x);
}
void sets(int x,int d,int sm){
    if(p[x].l==p[x].r){
        p[x].sm=sm;
        return; 
    }
    if(d<=(p[x].l+p[x].r>>1))sets(x<<1,d,sm);
    else sets(x<<1|1,d,sm);
    upset(x); 
}
int gets(int x,int k){
    if(p[x].l==p[x].r)return p[x].l;
    //cout<<x<<" "<<p[x].l<<" "<<p[x].r<<" "<<p[x<<1].sm<<" "<<k<<endl;
    if(tr::gets(1,p[x<<1].sm,n)<k)return gets(x<<1,k);
    return gets(x<<1|1,k);
}
inline int read(){
    int res=0;char c;
    do{
        c=getchar();
    }while(!isdigit(c));
    while(isdigit(c)){
        res=res*10+c-'0';
        c=getchar();
    }
    return res;
}
void print(int x){
    if(x>=10)print(x/10);
    putchar(x%10+'0');
}
signed main(){
    cin>>n>>m;
    for(int i=1;i<=m;i++)w[i].l=read(),w[i].r=read();
    sort(w+1,w+m+1);
    w[m+1].l=1e9;
    reset(1,1,m+1);
    tr::reset(1,1,n);
    for(int i=1;i<=m;i++){
        rs[i]=rs[i-1];
        int x=gets(1,i);
        for(int j=x;j<=m;j=gets(1,i)){
            rs[i]++;
            sets(1,j,0);
            tr::add(1,w[j].l,w[j].r,1);
        //  cout<<j<<" "<<w[j].l<<" "<<w[j].r<<endl;
        }
    //  cout<<i<<" "<<rs[i]<<endl;
    }
    hd=1;
    for(int i=0;i<=m;i++){
        while(hd<tl&&1ll*(rs[q[tl]]-rs[q[tl-1]])*(i-q[tl])<=1ll*(rs[i]-rs[q[tl]])*(q[tl]-q[tl-1]))tl--;
        q[++tl]=i;
    }
    vector<int>qs;
    for(int i=n;i>=1;i--){
        int res=1e9;
        while(hd<tl&&1ll*i*q[hd]-rs[q[hd]]>1ll*i*q[hd+1]-rs[q[hd+1]])hd++;
        res=m-rs[q[hd]]+1ll*i*q[hd];
        qs.push_back(res);
    }
    while(qs.size())print(qs.back()),putchar('\n'),qs.pop_back();
}