P10004题解

· · 题解

前言:

这是蒟蒻做的第一道二项式反演,一开始反演的方向反了(致敬反方向的钟

科技:

二项式反演

思路:

直接计算 ans_{i,j} 是不好做的,我们考虑求出钦定在原排列和逆排列种各有 i,j 个上升的连续对的方案数 f_{i,j},则 f_{i,j}=\sum\limits_{a\ge i}\sum\limits_{b \ge j}\binom{a}{i}\binom{b}{j}ans_{a,b},反演得 ans_{i,j}=\sum\limits_{a\ge i}(-1)^{a-i}\binom{a}{i}\sum\limits_{b\ge j}(-1)^{b-j}\binom{b}{j}f_{a,b}。直接算是 O(n^4) 的,但我们可以预处理一个 s_{j,a}=\sum\limits_{b\ge j}(-1)^{b-j}\binom{b}{j}f_{a,b},这样就可以得到 ans_{i,j}=\sum\limits_{a\ge i}(-1)^{a-i}\binom{a}{i}s_{j,a}。两者都是 O(n^3) 的复杂度。
剩下的事就是求 f_{i,j} 了。
我们钦定了原排列、逆排列中各有 i,j 个上升的连续对,那么就相当于钦定原排列、逆排列中各有 n-i,n-j 个连续的上升段。接下来有一个显然的结论:一个原排列中的连续上升段,段内的值的位置一定也是上升的。也就是说,原排列中一段连续上升的元素,是按照其原顺序被划分至逆排列的若干个上升段的。既然顺序已定,我们可以设 c_{x,y} 表示原排列中第 x 个上升段中有 c_{x,y} 个元素划分在逆排列的第 y 个上升段中。那么,我们钦定原排列、逆排列中各有 i,j 个连续的上升段,就相当于要有一个 i\times j 的矩阵,满足其中每一行每一列的总和为正,且整个矩阵总和为 n
g_{i,j} 表示钦定原排列、逆排列中各有 i,j 个连续的上升段时的方案数,也就是合法的大小为 i\times j 的矩阵 c 的数。注意 f_{i,j}=g_{n-i,n-j}
t_{i,j} 表示只满足第二条限制(总和为 n)的矩阵的数量,易得 t_{i,j}=\binom{n+ij-1}{ij-1},且 t_{i,j}=\sum\limits_{a\le i}\binom{i}{a}\sum\limits_{b\le j}\binom{j}{b}g_{i,j},反演得 g_{i,j}=\sum\limits_{a\le i}\binom{i}{a}(-1)^{i-a}\sum\limits_{b\le j}\binom{j}{b}(-1)^{j-b}t_{i,j},用上文提到的方法可优化至 O(n^3)

代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 510,M = 250510;
int n,v,mod,inv[M],fac[M],ifac[M],ans[N][N],g[N][N],t[N][N],s[N][N],c[N][N];
int C(int n,int m){return 1LL * fac[n] * ifac[m] % mod * ifac[n - m] % mod;}
int main()
{
    scanf("%d%d",&n,&mod);
    v = n * n + n;
    inv[1] = fac[0] = ifac[0] = 1;
    for(int i = 2;i <= v;i++) inv[i] = mod - 1LL * (mod / i) * inv[mod % i] % mod;
    for(int i = 1;i <= v;i++)
    {
        fac[i] = 1LL * fac[i - 1] * i % mod;
        ifac[i] = 1LL * ifac[i - 1] * inv[i] % mod;
    }
    for(int i = 1;i <= n;i++)
        for(int j = 1;j <= n;j++)
            t[i][j] = C(n + i * j - 1,i * j - 1);
    for(int i = 0;i <= n;i++)
    {
        c[i][0] = 1;
        for(int j = 1;j <= i;j++) c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
    }
    int tmp;
    for(int j = 1;j <= n;j++)
        for(int a = 1;a <= n;a++)
            for(int b = 1;b <= j;b++)
            {
                tmp = 1LL * c[j][b] * t[a][b] % mod;
                if((j - b) & 1) s[j][a] = (s[j][a] - tmp + mod) % mod;
                else s[j][a] = (s[j][a] + tmp) % mod;
            }
    for(int i = 1;i <= n;i++)
        for(int j = 1;j <= n;j++)
            for(int a = 1;a <= i;a++)
            {
                tmp = 1LL * c[i][a] * s[j][a] % mod;
                if((i - a) & 1) g[i][j] = (g[i][j] - tmp + mod) % mod;
                else g[i][j] = (g[i][j] + tmp) % mod;
            }
    memset(s,0,sizeof(s));
    for(int j = 0;j < n;j++)
        for(int a = 0;a < n;a++)
            for(int b = j;b < n;b++)
            {
                tmp = 1LL * c[b][j] * g[n - a][n - b] % mod;
                if((b - j) & 1) s[j][a] = (s[j][a] - tmp + mod) % mod;
                else s[j][a] = (s[j][a] + tmp) % mod;
            }
    for(int i = 0;i < n;i++)
        for(int j = 0;j < n;j++)
            for(int a = i;a < n;a++)
            {
                tmp = 1LL * c[a][i] * s[j][a] % mod;
                if((a - i) & 1) ans[i][j] = (ans[i][j] - tmp + mod) % mod;
                else ans[i][j] = (ans[i][j] + tmp) % mod;
            }
    for(int i = 0;i < n;i++)
    {
        for(int j = 0;j < n;j++) printf("%d ",ans[i][j]);
        puts("");
    }
    return 0;
}