题解 CF364E 【Empty Rectangles】

· · 题解

题意

给定一个 nm 列的 01 矩阵,求有多少个子矩阵满足和为 k

\texttt{Data Range:}1\leq n,m\leq 2500,1\leq k\leq 6

题解

二维分治。

考虑一维上的问题:给定一个长度为 n01 序列,求有多少子段和为 k

这个题如果暴力枚举是 O(n^2) 的,但是有一个很好的性质:如果将这个序列划分成两个子段的话,那么可以 O(n+k) 统计出跨越这个划分点的答案。如果写过 CF755G 的倍增做法的话应该知道,也就是记录一下两边和为 k 的有多少然后类似于卷积转移,但是只要求一项所以可以 O(k)

将答案拆成三个部分:子段全部在左边的,全部在右边的和跨越划分点的。跨越划分点的已经说过了,而全部在一边的与原问题等价,于是可以分治,时间复杂度 O(nk\log n)

现在拓展到二维。还是先对一维进行分治,然后统计跨越划分点的贡献的时候,枚举子矩形的上下端点然后类似于一维做。

举个栗子,假设现在是横着划分,枚举 l,r 为左右端点的同时,记录 up_{i} 表示以这个端点为宽,以 mid 为底且子矩阵和 \leq i 的矩形中,上端点最远到达哪里。同时记录 dn_i,类似于 up_i,只不过表示下端点。转移的时候,固定一个 l,当 r 往右移动的时候 up_i 单调不增,dn_i 单调不降,并且每枚举一个 r 的时候,算出对应的 updn,然后差分求出答案。

但是时间复杂度为 O(nm^2\log n),无法通过。

这个时候,注意到第二维跟第一维需要处理的是一样的,说明我们可以同时对两维进行分治。这个实现类似于 K-D Tree 的建树,可以交替分治。

但是,这里存在一种不劣与交替分治的做法:每一次分治的时候,计算矩形两维的跨度,选择跨度大的进行分治。这个过程类似于 K-D Tree 中找方差大的一维建树,时间复杂度更加优秀。

目前按照跨度大的分治是你谷第二优解,比交替分治快了 3.6s,可以看出优化还是比较明显的。

代码

交替分治:

#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=3e3+51;
ll n,m,kk;
li res;
ll pr[MAXN][MAXN],up[MAXN],dn[MAXN];
char ch[MAXN];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
inline ll calc(ll xl,ll xr,ll yl,ll yr)
{
    return pr[xr][yr]-pr[xr][yl]-pr[xl][yr]+pr[xl][yl];
}
inline void calc(ll xl,ll xr,ll yl,ll yr,ll d)
{
    if(xl==xr||yl==yr)
    {
        return;
    }
    if(xr==xl+1&&yr==yl+1)
    {
        return (void)(res+=calc(xl,xr,yl,yr)==kk);
    }
    if(d==0)
    {
        ll mid=(xl+xr)>>1;
        calc(xl,mid,yl,yr,1),calc(mid,xr,yl,yr,1);
        for(register int l=yl;l<=yr;l++)
        {
            up[0]=dn[0]=mid;
            for(register int i=1;i<=kk+1;i++)
            {
                up[i]=xl,dn[i]=xr;
            }
            for(register int r=l+1;r<=yr;r++)
            {
                for(register int i=1;i<=kk+1;i++)
                {
                    while(calc(up[i],mid,l,r)>=i)
                    {
                        up[i]++;
                    }
                    while(calc(mid,dn[i],l,r)>=i)
                    {
                        dn[i]--;
                    }
                }
                for(register int i=0;i<=kk;i++)
                {
                    res+=(li)(up[i]-up[i+1])*(dn[kk-i+1]-dn[kk-i]);
                }
            }
        }
    }
    if(d==1)
    {
        ll mid=(yl+yr)>>1;
        calc(xl,xr,yl,mid,0),calc(xl,xr,mid,yr,0);
        for(register int l=xl;l<=xr;l++)
        {
            up[0]=dn[0]=mid;
            for(register int i=1;i<=kk+1;i++)
            {
                up[i]=yl,dn[i]=yr;
            }
            for(register int r=l+1;r<=xr;r++)
            {
                for(register int i=1;i<=kk+1;i++)
                {
                    while(calc(l,r,up[i],mid)>=i)
                    {
                        up[i]++;
                    }
                    while(calc(l,r,mid,dn[i])>=i)
                    {
                        dn[i]--;
                    }
                }
                for(register int i=0;i<=kk;i++)
                {
                    res+=(li)(up[i]-up[i+1])*(dn[kk-i+1]-dn[kk-i]);
                }
            }
        }
    }
}
int main()
{
    n=read(),m=read(),kk=read();
    for(register int i=1;i<=n;i++)
    {
        scanf("%s",ch+1);
        for(register int j=1;j<=m;j++)
        {
            pr[i][j]=pr[i-1][j]+pr[i][j-1]-pr[i-1][j-1]+(ch[j]&1);
        }
    }
    calc(0,n,0,m,0),printf("%lld\n",res);
}

按照跨度最大的一维分治:

#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=3e3+51;
ll n,m,kk;
li res;
ll pr[MAXN][MAXN],up[MAXN],dn[MAXN];
char ch[MAXN];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
inline ll calc(ll xl,ll xr,ll yl,ll yr)
{
    return pr[xr][yr]-pr[xr][yl]-pr[xl][yr]+pr[xl][yl];
}
inline void calc1(ll xl,ll xr,ll yl,ll yr)
{
    if(xl==xr||yl==yr)
    {
        return;
    }
    if(xr==xl+1&&yr==yl+1)
    {
        return (void)(res+=calc(xl,xr,yl,yr)==kk);
    }
    ll d=xr-xl<yr-yl;
    if(d==0)
    {
        ll mid=(xl+xr)>>1;
        calc1(xl,mid,yl,yr),calc1(mid,xr,yl,yr);
        for(register int l=yl;l<=yr;l++)
        {
            up[0]=dn[0]=mid;
            for(register int i=1;i<=kk+1;i++)
            {
                up[i]=xl,dn[i]=xr;
            }
            for(register int r=l+1;r<=yr;r++)
            {
                for(register int i=1;i<=kk+1;i++)
                {
                    while(calc(up[i],mid,l,r)>=i)
                    {
                        up[i]++;
                    }
                    while(calc(mid,dn[i],l,r)>=i)
                    {
                        dn[i]--;
                    }
                }
                for(register int i=0;i<=kk;i++)
                {
                    res+=(li)(up[i]-up[i+1])*(dn[kk-i+1]-dn[kk-i]);
                }
            }
        }
    }
    if(d==1)
    {
        ll mid=(yl+yr)>>1;
        calc1(xl,xr,yl,mid),calc1(xl,xr,mid,yr);
        for(register int l=xl;l<=xr;l++)
        {
            up[0]=dn[0]=mid;
            for(register int i=1;i<=kk+1;i++)
            {
                up[i]=yl,dn[i]=yr;
            }
            for(register int r=l+1;r<=xr;r++)
            {
                for(register int i=1;i<=kk+1;i++)
                {
                    while(calc(l,r,up[i],mid)>=i)
                    {
                        up[i]++;
                    }
                    while(calc(l,r,mid,dn[i])>=i)
                    {
                        dn[i]--;
                    }
                }
                for(register int i=0;i<=kk;i++)
                {
                    res+=(li)(up[i]-up[i+1])*(dn[kk-i+1]-dn[kk-i]);
                }
            }
        }
    }
}
int main()
{
    n=read(),m=read(),kk=read();
    for(register int i=1;i<=n;i++)
    {
        scanf("%s",ch+1);
        for(register int j=1;j<=m;j++)
        {
            pr[i][j]=pr[i-1][j]+pr[i][j-1]-pr[i-1][j-1]+(ch[j]&1);
        }
    }
    calc1(0,n,0,m),printf("%lld\n",res);
}