题解 CF627E 【Orchestra】
justin_cao
2020-02-05 19:58:47
一道很妙的题。
首先考虑最暴力的解法应该就是枚举上下边界然后用单调栈去求了,复杂度$O(n^3)$。
考虑怎么求优化这个过程,考虑能不能只枚举上边界和下边界然后不用单调栈。
发现给定的点数很少(跟$n,m$同阶),于是考虑从这里下手。
考虑枚举一个上边界$i$,然后考虑将上边界为$i$,下边界为$j$的合法矩形个数算出来。
考虑怎么存它,且还要动态修改(因为要把下边界从$n$往上移)。
一个上边界的点,找到在它左边的第$k$个在$i$到$n$行之间的给定的点,那么它的贡献就是这个点的纵坐标。
于是考虑将第$i$行到第$n$行的点提取出来按照纵坐标排个序按照顺序建成一个双向链表,在对每个点记一个$cnt_i$表示上边界有多少个点的左边第$k$个点是$i$。
然后考虑一行一行地从下面往上面删行,先考虑删除一个点。
画个图就可以发现,删除一个点实际上就是把它到左边的$k$个点集体左移了一位,贡献的减少量也可以非常简单的通过双向链表来实现,同时维护一下$cnt$的变化,然后将这个点删掉即可。
于是运用双向链表可以非常简单的用$O(k)$的时间来删除一个点。
于是用这样的方法我们可以用上一行再删去这一行的点求得这一行的答案。
这样子一行行删去即可。
复杂度$O(n^2k)$。
code:
```cpp
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<iostream>
#include<set>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#define G 3
#define eps 1e-15
#define maxn 3010
#define maxm 100010
#define inf 999999999999999
#define mod 998244353
#define mp(x,y) make_pair(x,y)
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef unsigned int uint;
typedef pair<int,int> pii;
int read()
{
int x=0,f=1;
char ch=getchar();
while(ch-'0'<0||ch-'0'>9){if(ch=='-') f=-1;ch=getchar();}
while(ch-'0'>=0&&ch-'0'<=9){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,m,c,K,top;
struct P{
int x,y;
}p[maxn];
int a[maxn];
vector<int>v[maxn];
bool cmp(int x,int y)
{
return p[x].y<p[y].y;
}
int L[maxn],R[maxn],cnt[maxn];
ll ans,sum;
bool dfs(int x,int las)
{
if(!las||!x) return true;
if(!dfs(L[x],las-1)) cnt[L[x]]=0;
cnt[L[x]]+=cnt[x];
return false;
}
int main()
{
n=read();m=read();c=read();K=read();
for(int i=1;i<=c;i++) p[i].x=read(),p[i].y=read(),v[p[i].x].push_back(i);
for(int i=1;i<=n;i++)
{
top=0;
for(int j=1;j<=c;j++)
if(p[j].x>=i) a[++top]=j;
sort(a+1,a+top+1,cmp);
a[top+1]=c+1;
sum=0;
for(int j=0;j<=c+1;j++) cnt[j]=L[j]=R[j]=0;
int now=0;
for(int j=1;j<=m;j++)
{
while(p[a[now+1]].y==j&&now<top)
{
now++;
L[a[now]]=a[now-1];
R[a[now]]=a[now+1];
}
int tmp=a[now];
for(int k=1;k<=K-1;k++) tmp=L[tmp];
cnt[tmp]++;
}
for(int j=1;j<=top;j++) sum+=p[a[j]].y*cnt[a[j]];
ans+=sum;
for(int j=n;j>=i+1;j--)
{
for(int k=0;k<v[j].size();k++)
{
int x=v[j][k];
for(int l=1;l<=K;l++)
{
sum-=cnt[x]*(p[x].y-p[L[x]].y);
x=L[x];
if(!x) break;
}
dfs(v[j][k],K);
x=v[j][k];
L[R[x]]=L[x];
R[L[x]]=R[x];
}
ans+=sum;
}
}
cout<<ans<<endl;
return 0;
}
```