[ICPC2019 WF]Directing Rainfall 题解

· · 题解

巨大多恐怖题。

首先要设法给这些挡板排个序,使得水只能从排在前面的流到排在后面的。

直接使用 sort 是不可行的,因为我们定义的是偏序关系,只有 x 区间有交的两个挡板才能比较大小,但由于是偏序,可以考虑用拓扑排序的办法解决。

x 从左到右扫描,用一个 set 按照 y 的顺序维护一个包含当前 x 坐标的挡板集合(由于它们的 x 区间有交,所以是全序),在每一个挡板的端点处,将它此时在 set 中的前驱向它连边,它向 set 中的后继连边。

得到的图中进行一次拓扑排序,就得到了一个可行的顺序。

得到顺序之后考虑 DP:

从下往上扫描挡板,令 f_i 表示仅考虑当前的挡板时,x=i 处的雨水要落到规定区间至少需要的洞数。

现在加入一块覆盖区间 [l,r] 且左边低右边高的挡板,考虑对 f 的影响:

所以这个操作可以概括为:

左高右低同理,不过是求后缀最小值。

我们发现没有什么好的数据结构直接支持取一个区间的前缀最小值这个操作,所以考虑更改一下维护信息。

g_i=f_i-f_{i-1},区间加操作就是 g_i 的单点修改。

而对于区间取前缀最小值操作,我们发现有 f_i\leq f_{i-1},所以 g_i\leq 0,我们需要找到所有 g_i>0i,然后将它修改为 0,同时为了保证后面的信息正确,我们实际上要找到 i 后面的一些 g_j<0j,将 g_jg_i 抵消。

所以在做前缀最小值时,我们顺次找 g_i>0i,然后找 i 后面第一个满足 g_j<0j,如果 |g_j|>|g_i|,就将 g_j 改为 g_j+g_i 并将 g_i 清零;否则就将 g_i 变为 g_i+g_j 并将 g_j 清零。如果这样的 j 已经在区间外,就直接这个 g_i 加到 g_{r+1} 上。

可以用 set 分别维护 g_i>0g_i<0 的情况,由于访问过的元素大多被删除,所以最终复杂度是 O(n\log n)

#include <bits/stdc++.h>
using namespace std;
const int MAXN=1000010,INF=2e9,INFV=1e8;
struct P {
    P (int a=0,int c=0,int d=0) {x=a,bel=c,flg=d;}
    int x,bel,flg;
}p[MAXN*2];
struct Tarp {
    Tarp (int a=0,int b=0,int c=0,int d=0,int e=0) {x1=a,y1=b,x2=c,y2=d,id=e;}
    int x1,x2,y1,y2,id;
}t[MAXN];
int l,r,n,tot,cnt,ans,in[MAXN],sx[MAXN];
vector <int> v[MAXN];
multiset < pair<int,int> > s1,s2;
bool cmp (P a,P b) {return a.x<b.x;}
bool operator < (Tarp a,Tarp b) {
//  cout << a.x1 << "  " << a.x2 << "  " << b.x1 << "  " << b.x2 << endl;
    if (a.x2<b.x1||a.x1>b.x2) {return 0;}
    if (a.x1<b.x1&&b.x1<a.x2) {
//      cout << 1ll*a.y1*(a.x2-a.x1)+1ll*(a.y2-a.y1)*(b.x1-a.x1) << "  " << 1ll*b.y1*(a.x2-a.x1) << endl;
        return 1ll*a.y1*(a.x2-a.x1)+1ll*(a.y2-a.y1)*(b.x1-a.x1)<1ll*b.y1*(a.x2-a.x1);
    } else {
        return 1ll*b.y1*(b.x2-b.x1)+1ll*(b.y2-b.y1)*(a.x1-b.x1)>1ll*a.y1*(b.x2-b.x1);
    }
}
set <Tarp> st;
void add_edge (int a,int b) {
    in[b]++;
    v[a].push_back(b);
    return;
}
void print () {
    multiset< pair<int,int> >::iterator it,it2;
    it=s1.begin(),it2=s2.begin();
    while (it!=s1.end()&&it2!=s2.end()) {
        pair<int,int> a=(*it),b=(*it2);
        if (a.first<=b.first) {
            printf("%d  %d\n",a.first,a.second);
            it++;
        } else {
            printf("%d  %d\n",b.first,b.second);
            it2++;
        }
    }
    while (it!=s1.end()) {
        pair<int,int> a=(*it);
        printf("%d  %d\n",a.first,a.second);
        it++;
    }
    while (it2!=s2.end()) {
        pair<int,int> b=(*it2);
        printf("%d  %d\n",b.first,b.second);
        it2++; 
    }
}
int main () {
    scanf("%d%d%d",&l,&r,&n);
    l*=2,r*=2;
    for (int i=1;i<=n;i++) {
        scanf("%d%d%d%d",&t[i].x1,&t[i].y1,&t[i].x2,&t[i].y2);
        t[i].x1*=2,t[i].x2*=2;
        if (t[i].x1>t[i].x2) {swap(t[i].x1,t[i].x2),swap(t[i].y1,t[i].y2);}
        t[i].id=i;
        p[++tot]=P(t[i].x1,i,0);
        p[++tot]=P(t[i].x2,i,1);
    }
    st.insert(Tarp(-1,0,2e9+1,0,0));
    st.insert(Tarp(-1,1e9+1,2e9+1,1e9+1,n+1));
    sort(p+1,p+tot+1,cmp);
    for (int i=1;i<=tot;i++) {
        if (p[i].flg==1) {
            st.erase(t[p[i].bel]);
        }
        set<Tarp>::iterator it=st.lower_bound(t[p[i].bel]),it2;
        it2=it;
        it2--;
        if ((*it).id!=n+1) {add_edge(p[i].bel,(*it).id);}
        if ((*it2).id!=0) {add_edge((*it2).id,p[i].bel);}
        if (p[i].flg==0) {
            st.insert(t[p[i].bel]);
        }
    }
    queue <int> q;
    for (int i=1;i<=n;i++) {
        if (!in[i]) {q.push(i);}
    }
    while (!q.empty()) {
        int a=q.front();
        q.pop();
        sx[++cnt]=a;
        int len=v[a].size();
        for (int j=0;j<len;j++) {
            if (--in[v[a][j]]==0) {q.push(v[a][j]);} 
        }
    }
    s1.insert(make_pair(r+1,INFV));
    s2.insert(make_pair(l,-INFV));
    multiset< pair<int,int> >::iterator it,it2;
    for (int i=1;i<=n;i++) {
        Tarp tmp=t[sx[i]];
        if (tmp.y1<tmp.y2) {
            s1.insert(make_pair(tmp.x1+1,1));
            s2.insert(make_pair(tmp.x2+1,-1));
            it=s1.lower_bound(make_pair(tmp.x1+1,-INF-1));
            while (it!=s1.end()&&(*it).first<=tmp.x2) {
                //printf("    %d  %d\n",(*it).first,(*it).second);
                int flg=1;
                it2=s2.lower_bound(make_pair((*it).first,-INF-1));
                while (it2!=s2.end()&&(*it2).first<=tmp.x2) {
                    //printf("        %d  %d\n",(*it2).first,(*it2).second);
                    if ((*it).second>-(*it2).second) {
                        pair<int,int> tmp2=(*it);
                        s1.erase(it);
                        s1.insert(make_pair(tmp2.first,tmp2.second+(*it2).second));
                        s2.erase(it2);
                        it=s1.lower_bound(make_pair(tmp2.first,-INF-1));
                        it2=s2.lower_bound(make_pair((*it).first,-INF-1));
                    } else if ((*it).second==-(*it2).second) {
                        s1.erase(it);
                        s2.erase(it2);
                        flg=0;
                        break;
                    } else {
                        pair<int,int> tmp2=(*it2);
                        s2.erase(it2);
                        s2.insert(make_pair(tmp2.first,tmp2.second+(*it).second));
                        s1.erase(it);
                        flg=0;
                        break;
                    }
                }
                if (flg) {
                    s1.insert(make_pair(tmp.x2+1,(*it).second));
                    s1.erase(it);
                    //cout << "             " << i << "  " << tmp.x1+1 << endl;
                }
                it=s1.lower_bound(make_pair(tmp.x1+1,-INF-1));
            }
        } else {
            s1.insert(make_pair(tmp.x1,1));
            s2.insert(make_pair(tmp.x2,-1));
            it=s2.lower_bound(make_pair(tmp.x2,INF+1));
            while (it!=s2.begin()&&(*(--it)).first>=tmp.x1+1) {
                //printf("    %d  %d\n",(*it).first,(*it).second);
                int flg=1;
                it2=s1.lower_bound(make_pair((*it).first,INF+1));
                while (it2!=s1.begin()&&(*(--it2)).first>=tmp.x1+1) {
                    //printf("        %d  %d\n",(*it2).first,(*it2).second);
                    if (-(*it).second>(*it2).second) {
                        pair<int,int> tmp2=(*it);
                        s2.erase(it);
                        s2.insert(make_pair(tmp2.first,tmp2.second+(*it2).second));
                        it=s2.find(make_pair(tmp2.first,tmp2.second+(*it2).second));
                        s1.erase(it2);
                        it2=s1.lower_bound(make_pair((*it).first,INF+1));
                    } else if (-(*it).second==(*it2).second) {
                        s2.erase(it);
                        s1.erase(it2);
                        flg=0;
                        break;
                    } else {
                        pair<int,int> tmp2=(*it2);
                        s1.erase(it2);
                        s1.insert(make_pair(tmp2.first,tmp2.second+(*it).second));
                        s2.erase(it);
                        flg=0;
                        break;
                    }
                }
                if (flg) {
                    s2.insert(make_pair(tmp.x1,(*it).second));
                    s2.erase(it);
                }
                it=s2.lower_bound(make_pair(tmp.x2,INF+1));
            }
        }
        //printf("%d\n",i);
        //print();
        //printf("----------------------------\n");
    }
    int las=-1,val=INFV;
    ans=INF;
    it=s1.begin(),it2=s2.begin();
    while (it!=s1.end()&&it2!=s2.end()) {
        pair<int,int> a=(*it),b=(*it2);
        if (a.first<=b.first) {
            //printf("%d  %d\n",a.first,a.second);
            if (las<=r&&a.first>l) {ans=min(ans,val);}
            val+=a.second;
            las=a.first;
            it++;
        } else {
            //printf("%d  %d\n",b.first,b.second);
            if (las<=r&&b.first>l) {ans=min(ans,val);}
            val+=b.second;
            las=b.first;
            it2++;
        }
    }
    while (it!=s1.end()) {
        pair<int,int> a=(*it);
        //printf("%d  %d\n",a.first,a.second);
        if (las<=r&&a.first>l) {ans=min(ans,val);}
        val+=a.second;
        las=a.first;
        it++;
    }
    while (it2!=s2.end()) {
        pair<int,int> b=(*it2);
        //printf("%d  %d\n",b.first,b.second);
        if (las<=r&&b.first>l) {ans=min(ans,val);}
        val+=b.second;
        las=b.first;
        it2++; 
    }
    printf("%d\n",ans);
    return 0;
}