UVA12633 Super Rooks on Chessboard

· · 题解

UVA12633 Super Rooks on Chessboard

前言

这有黑?纯纯细节题。

题解

很容易想到容斥。设能被行攻击的格子集合为 R,能被列攻击的格子集合为 C,能被对角线攻击的格子集合为 D,那么答案就为 n\times m-|R|-|C|-|D|+|R\cap C|+|C\cap D|+|D\cap R|-|R\cap C\cap D|,唯一的难点是求 |R\cap C\cap D|

对于一条对角线,xy 的差是相同的,我们用 x-y 表示一条对角线的标号。设 r_i 被行攻击, c_i 被列攻击, d_i 被对角线攻击。求 |R\cap C\cap D| 就是求多少个 (i,j,k) 使得 r_i-c_j=d_k,移项为 r_i=d_k+c_j

构造生成函数 F(x)\sum\limits_{i\in d}x^i,G(x)\sum\limits_{i\in c}x^i,答案为 \sum\limits_{i\in r}[x^i]F*G(x)

直接 FFT/NTT 即可

代码

写的比较乱,尽量不要参考

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=4e6+10,mod=998244353;
int f[N],g[N],n,m,l,r[N],vis1[N],vis2[N],vis3[N],sum1[N],sum2[N];
int qpow(int x,int y){
    int ret=1;
    while(y){
        if(y&1)ret=1ll*ret*x%mod;
        x=1ll*x*x%mod;
        y>>=1;
    }
    return ret;
}
void NTT(int *a,int op){
    for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int i=1;i<n;i<<=1){
        int wn=qpow(op==1?3:332748118,(mod-1)/(i<<1));
        for(int j=0;j<n;j+=(i<<1)){
            int w=1;
            for(int k=0;k<i;k++,w=1ll*w*wn%mod){
                int x=a[j+k],y=1ll*w*a[j+k+i]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+i]=(x-y+mod)%mod;
            } 
        }
    }
}
void init(){
    memset(sum1,0,sizeof(sum1));
    memset(sum2,0,sizeof(sum2));
    memset(vis1,0,sizeof(vis1));
    memset(vis2,0,sizeof(vis2));
    memset(vis3,0,sizeof(vis3));
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));l=-1;
}
void solve(int T){
    init();
    int t;
    cin>>n>>m>>t;
    ll ans=-1ll*n*m;
    for(int i=1,x,y;i<=t;i++){
        cin>>x>>y;
        vis1[x]=vis2[y]=vis3[x-y+m]=1;
    }
    for(int i=1;i<=n;i++)if(vis1[i])ans+=m;
    for(int i=1;i<=m;i++)if(vis2[i])ans+=n;
    for(int i=1;i<=n+m;i++)if(vis3[i])ans+=min(i,n)-max(i+1-m,1)+1;
    for(int i=1;i<=n;i++)sum1[i]=vis1[i]+sum1[i-1];
    for(int i=1;i<=m;i++)sum2[i]=vis2[i]+sum2[i-1];
    ans-=1ll*sum1[n]*sum2[m];
    for(int i=1;i<=n+m;i++)if(vis3[i])ans-=sum1[min(i,n)]-sum1[max(i+1-m,1)-1];
    for(int i=1;i<=n+m;i++)if(vis3[i])ans-=sum2[min(n-i+m,m)]-sum2[max(1-i+m,1)-1];
    for(int i=1;i<=m;i++)if(vis2[i])f[i]=1;
    for(int i=1;i<=n+m;i++)if(vis3[i])g[i]=1;
    int n_=n,m_=m;
    m=n+m+m;n=1;
    while(n<=m)n<<=1,l++;
    for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)<<l);
    NTT(f,1),NTT(g,1);
    for(int i=0;i<n;i++)f[i]=1ll*f[i]*g[i]%mod;
    NTT(f,-1);
    int inv=qpow(n,mod-2);
    for(int i=0;i<n;i++)f[i]=1ll*f[i]*inv%mod;
    for(int i=1;i<=n_;i++)ans+=1ll*vis1[i]*f[i+m_];
    cout<<"Case "<<T<<": "<<-ans<<"\n";
}
int main(){
    int T;cin>>T;
    for(int i=1;i<=T;i++)solve(i);
}