题解:P9563 [SDCPC 2023] Be Careful 2

· · 题解

不考虑容斥,从上往下扫描线。

将所有点按照横坐标排序,从大到小加入每个点,每次只统计下边界横坐标在 [x_{i-1},x_i) 的所有正方形,考虑其上边界所在的位置。

先将加入的所有点按 y 排序(可以插入排序)。画图发现有两种情况:

  1. 对一个点 (x_i,y_i),单调栈求出每个点左边第一个 x<x_i 的点(设为 (x_j,y_j))、右边第一个 x<x_i 的点(设为 (x_k,y_k)),上边界在 \max(x_j,x_k)+1\le x\le x_i,y_j\le y\le y_k 这样一个矩形中(注意不要算重);

  2. 考虑相邻两个点 (x_i,y_i),(x_{i+1},y_{i+1}),若 y_i<y_{i+1},则上边界在 \max(x_i,x_{i+1})+1\le x\le n,y_i\le y\le y_{i+1} 这样一个矩形中。

转化为求解 f(n,m,d,u)n\times m 的矩形内,上边界横坐标在 [u,m],下边界横坐标在 [0,d] 的正方形个数。

简单容斥一下:设 g(n,m) 表示 n\times m 的矩形内的正方形个数,f(n,m,d,u)=g(n,m)-g(u-1,m)-g(n-d-1,m)+g(u-d-2,m)。而 g(n,m) 是好求的:

g(n,m)=\sum_{d=1}^{\min(n,m)}(n-d+1)(m-d+1)d^2

是一个四次多项式前缀和的形式,拆成平方和、三次方和、四次方和即可。放一下比较冷门的四次方和公式:

\sum_{i=1}^ni^4=\dfrac{n^5}5+\dfrac{n^4}2+\dfrac{n^3}3-\dfrac{n}{30}

复杂度 O(k^2)。代码实现略有不同。

#include<bits/stdc++.h>
#define N 5005
#define ll long long
#define mod 998244353
#define inv2 499122177
#define inv3 332748118
#define inv5 598946612
#define inv6 166374059
using namespace std;
int n,m,k;
struct point{
    int x,y;
}a[N];
bool cmpx(point a,point b){
    return a.x<b.x;
}
bool cmpy(point a,point b){
    return a.y<b.y||a.y==b.y&&a.x<b.x;
}
ll qmi(ll a,ll b){
    ll ans=1;
    for(;b;b>>=1,a=a*a%mod) if(b&1) ans=ans*a%mod;
    return ans;
}
ll S1(ll n){
    return n*(n+1)%mod*inv2%mod;
}
ll S2(ll n){
    return n*(n+1)%mod*(n*2+1)%mod*inv6%mod;
}
ll S3(ll n){
    return S1(n)*S1(n)%mod;
}
ll S4(ll n){
    return (qmi(n,5)*inv5%mod+qmi(n,4)*inv2%mod+qmi(n,3)*inv3%mod-n*inv6%mod*inv5%mod+mod)%mod;
}
ll calc0(ll a2,ll a3,ll a4,ll n){
    return (a2*S2(n)+a3*S3(n)+a4*S4(n))%mod;
}
ll calc1(int n,int m){
    if(n<=0||m<=0) return 0;
    return calc0((ll)(n+1)*(m+1)%mod,(mod-(n+m+2)%mod)%mod,1,min(n,m));
}
ll calc2(int n,int m,int d,int u){
    if(n<=0||m<=0||u-d>m||u>n) return 0;
    return ((calc1(n,m)-calc1(u-1,m)-calc1(n-d-1,m)+calc1(u-d-2,m))%mod+mod)%mod;
}
ll ans;
int st[N],top;
int pl[N],pr[N],d[N],tag[N];
int main(){
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=k;i++) scanf("%d%d",&a[i].x,&a[i].y);
    sort(a+1,a+k+1,cmpx);
    ans=calc1(n-a[k].x,m);
    for(int i=k;i>=1;i--){
        if(a[i].x==a[i-1].x) continue;
        int d1=a[i-1].x,d2=a[i].x-1-a[i-1].x;
        sort(a+i,a+k+1,cmpy);
        top=0;
        for(int j=i;j<=k;j++){
            tag[j]=0,d[j]=d1;
            while(top&&a[st[top]].x>a[j].x) top--;
            if(top&&a[st[top]].x==a[j].x) tag[j]=1,top--;
            if(top) pl[j]=a[st[top]].y,d[j]=max(d[j],a[st[top]].x);
            else pl[j]=0;
            st[++top]=j;
        }
        top=0;
        for(int j=k;j>=i;j--){
            while(top&&a[st[top]].x>=a[j].x) top--;
            if(top) pr[j]=a[st[top]].y,d[j]=max(d[j],a[st[top]].x);
            else pr[j]=m;
            st[++top]=j;    
        }
        for(int j=i;j<=k;j++){
            if(tag[j]) continue;
            ans=(ans+calc2(a[j].x-d1,pr[j]-pl[j],d2,d[j]+1-d1));
        }
        ans=(ans+calc2(n-d1,a[i].y,d2,a[i].x+1-d1))%mod;
        ans=(ans+calc2(n-d1,m-a[k].y,d2,a[k].x+1-d1))%mod;
        for(int j=i;j<k;j++) ans=(ans+calc2(n-d1,a[j+1].y-a[j].y,d2,max(a[j+1].x,a[j].x)+1-d1))%mod;
    }
    printf("%lld\n",ans);
    return 0;
}