题解 P10142 [USACO24JAN] Mooball Teams III P

· · 题解

P10142 题解

题目大意

n 只奶牛,第 i 只坐标为 x_i,y_ix,y 是两个排列),你要从中选出一些奶牛分为红队和蓝队(即两个不交的非空奶牛集合),问有多少种选择方案能够用一条平行于横轴或纵轴的直线区分红队和蓝队,对 10^9+7 取模。

题目分析

容斥,总方案数为能被横线分开的能被竖线分开的减掉既能被横线也能被竖线分开的

以下部分只考虑红队最大的 x 小于蓝队最小的 x(左红右蓝)的情况,显然最后方案数 \times2 即可。

前两部分显然一样,下面只讨论能被竖线分开的部分。显然可以直接枚举 i 表示 x 最大的红队奶牛的 x 坐标为 i,方案数为 2^{i-1}(2^{n-i}-1)x 小于 i 的可以选是否加入红队,大于 i 的可以选是否加入蓝队,但蓝队不能空),直接计算 O(n)。接下来是既能被横线也能被竖线分开的方案数。

考虑在 x 上扫描线,假设扫到了 i,将奶牛按 x 排序,前 i 头奶牛只能参加红队(记为红牛),后 n-i 头奶牛只能参加蓝队(记为蓝牛),这样就可以被竖线分开,我们现在对于每一头牛只有参加和不参加两种选择。把 n 头奶牛按 y 排序后,则要求选择参加的红牛全部在参加的蓝牛前或全部在后,考虑放在线段树上,每个线段树节点记录如下信息(均不考虑是否是空集):

线段树 pushup 时记左右节点为 l,r,则:

假设当前扫描线到的牛 iyj,为了不算重,强制 i 参加红队,记 1\sim j-1 构成的区间为 Lj+1\sim n 构成的区间为 R,对答案的贡献为:pa_L\times(sa_R-pa_R)+(sb_L-pa_L)\times pa_R(注意去掉蓝队为空的部分)。

扫描线移动时只需要将某一只蓝牛变为红牛,线段树上修改即可,合并两个线段树节点复杂度 O(1),总复杂度 O(n\log n)

边界等细节看代码。

代码

#include<bits/stdc++.h>
using namespace std;
const int MAXN=200001,mod=1e9+7;
int n;
struct node{
    int x,y;
}a[MAXN];
bool cmp(node i,node j){
    return i.x<j.x;
}
int p[MAXN];
struct data{
    int sa,sb,pa,pb;
}t[MAXN*4];
data operator + (data a,data b){
    data c;
    c.sa=(1ll*a.sa*b.pb+1ll*a.pa*b.sa+1ll*(mod-a.pa)*b.pb)%mod;
    c.sb=(1ll*a.sb*b.pa+1ll*a.pb*b.sb+1ll*(mod-a.pb)*b.pa)%mod;
    c.pa=1ll*a.pa*b.pa%mod;
    c.pb=1ll*a.pb*b.pb%mod;
    return c;
}
#define mid (l+r>>1)
#define ls (p<<1)
#define rs (p<<1|1)
void modify(int p,int l,int r,int x,bool v){
    if(l==r){
        if(v)t[p]=(data){2,2,1,2};
        else t[p]=(data){2,2,2,1};
        return;
    }
    if(x<=mid)modify(ls,l,mid,x,v);
    else modify(rs,mid+1,r,x,v);
    t[p]=t[ls]+t[rs];
}
data query(int p,int l,int r,int x,int y){
    if(x>y)return (data){1,1,1,1};
    if(x<=l&&r<=y)return t[p];
    data v=(data){1,1,1,1};
    if(x<=mid)v=v+query(ls,l,mid,x,y);
    if(mid<y)v=v+query(rs,mid+1,r,x,y);
    return v;
}
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0);
    cin>>n;
    p[0]=1;
    for(int i=1;i<=n;i++){
        cin>>a[i].x>>a[i].y;
        p[i]=p[i-1]*2%mod;
    }
    int ans=0;
    for(int i=1;i<n;i++)ans=(ans+1ll*p[i-1]*(p[n-i]-1))%mod;
    ans=ans*2%mod;
    sort(a+1,a+n+1,cmp);
    for(int i=1;i<=n;i++)modify(1,1,n,i,1);
    for(int i=1;i<n;i++){
        int j=a[i].y;
        modify(1,1,n,j,0);
        data l=query(1,1,n,1,j-1),r=query(1,1,n,j+1,n);
        int res=(1ll*l.pa*(r.sa+mod-r.pa)+1ll*(l.sb+mod-l.pa)*r.pa)%mod;
        ans=(ans+mod-res)%mod;
    }
    cout<<ans*2%mod;
    return 0;
}

谢谢观看!