P8923 『MdOI R5』Many Minimizations

· · 题解

先考虑给定 a 怎么算出答案。

f_{i,j} 表示当前考虑到 i,且 b_i=j 的答案。

f_{i,j}=\min\limits_{k\le j}\{f_{i-1,k}\}+|a_i-j|

可以发现对于每一个 if 是一个下凸函数,那么可以考虑一个经典套路:用堆维护所有斜率增加 1 的断点。

同时设 mn_i=\min\{f_{i,j}\}

每次操作相当于是往堆中加入两个 a_i 然后弹出最大值 mx,可以发现 mn_i=mn_{i-1}+mx-a_i

去掉一些可以快速计算的东西之后,我们实际上只需要考虑过程结束时堆中剩下的数之和。

我们考虑拆开每个数的贡献。具体地,钦定一个 x\in [0,m],并把 \le x 的数设为 0>x 的数设为 1

现在只需要对所有 x 算出答案之和即可。

对于一个 x,设 g_{i,j} 表示当前考虑到 i,当前堆中有 j1 的方案数。

每次操作有 x 种方案减少一个 1,有 m-x 种方案增加一个 1。特别地,任意时刻 1 的个数会对 0\max

可以发现,每一个 g_{i,j} 都是一个关于 x 的不超过 n 次的多项式,直接暴力维护这些多项式即可做到 O(n^3) 的复杂度。

但这还不够,我们需要进一步优化这个算法。

考虑把这个问题转化为格路计数。

对于一条当前意义下的路径,每次找到第一个对 0\max 的位置并把它前面的部分全部往上平移一格。

这样我们就把对 0\max 的要求去掉了。可以发现这两种路径是一一对应关系。

此时问题转化为了:

设走到 (n,i) 的方案数为 w_i。我们要求的是 \sum\limits_i i\times w_i

路径上至少有一个点纵坐标为 0 这个限制比较麻烦,我们将其容斥掉。

w_i=w_{0,i}-w_{1,i},其中 w_0w_1 分别描述纵坐标都 \ge 0 和纵坐标都 \ge 1 的情况。

根据经典的反射容斥方法可以得到:

w_{0,i}=\sum\limits_{i+2j\ge n}x^j(m-x)^{n-j}\left(\dbinom{n}{j}-\dbinom{n}{i+j+1}\right) w_{1,i}=\sum\limits_{i+2j\ge n}x^j(m-x)^{n-j}\left(\dbinom{n}{j}-\dbinom{n}{i+j}\right) w_i=w_{0,i}-w_{1,i}=\sum\limits_{i+2j\ge n}x^j(m-x)^{n-j}\left(\dbinom{n}{i+j}-\dbinom{n}{i+j+1}\right)

代回我们要求的式子里可以得到:

\sum\limits_{i\ge 0} w_i\times i=\sum\limits_{i\ge 0} i\sum\limits_{i+2j\ge n}x^j(m-x)^{n-j}\left(\dbinom{n}{i+j}-\dbinom{n}{i+j+1}\right) =\sum\limits_{i\ge 0} x^i(m-x)^{n-i}\sum\limits_{j\ge\max\{i,n-i\}} (j-i)\left(\dbinom{n}{j}-\dbinom{n}{j+1}\right)

推到这里,我们已经可以非常简单地在 O(n^2) 的时间复杂度内求出这个多项式。

再用你喜欢的方式求出自然数幂和带进多项式计算,时间复杂度 O(n^2)

进一步地,可以用卷积求出这个多项式,然后将自然数幂和代入计算。时间复杂度 O(n\log n)

参考代码(O(n^2)):

#include <bits/stdc++.h>
using namespace std;
#define N 5005
int n,m,MOD,ans,pw[N],inv[N],z[N],bn[N][N],o[N][N];
void W(int &x,int y) {x+=y;if(x>=MOD) x-=MOD;}
int add(int x,int y) {x+=y;return x<MOD?x:x-MOD;}
int qPow(int x,int y)
{
    int res=1;
    for(;y;y/=2,x=1ll*x*x%MOD) if(y&1)
        res=1ll*res*x%MOD;return res;
}
void init(int n)
{
    pw[0]=o[0][0]=1;
    for(int i=1;i<=n;++i) pw[i]=1ll*pw[i-1]*m%MOD;
    for(int i=1;i<=n+1;++i)
        inv[i]=i>1?1ll*inv[MOD%i]*(MOD-MOD/i)%MOD:1;
    for(int i=0;i<=n;++i) for(int j=0;j<=i;++j)
        bn[i][j]=j?add(bn[i-1][j],bn[i-1][j-1]):1;
    for(int i=1;i<=n;++i) for(int j=1;j<=i;++j)
        o[i][j]=(1ll*o[i-1][j]*j+o[i-1][j-1])%MOD;
}
int main()
{
    scanf("%d %d %d",&n,&m,&MOD);init(n);
    for(int i=0,t1,t2;i<=n;++i)
    {
        t1=0;
        for(int j=max(i,n-i);j<=n;++j)
            W(t1,1ll*(j-i)*(bn[n][j]-bn[n][j+1]+MOD)%MOD);
        for(int j=0;j<=n-i;++j)
        {
            t2=1ll*t1*bn[n-i][j]%MOD*pw[n-i-j]%MOD;
            W(z[i+j],(j&1?MOD-t2:t2));
        }
    }
    for(int i=0,t1,t2;i<=n;++i)
    {
        t1=0;t2=1;
        for(int j=0;j<=i;++j)
        {
            t2=1ll*t2*(m-j+1)%MOD;
            W(t1,1ll*t2*o[i][j]%MOD*inv[j+1]%MOD);
        }W(ans,1ll*t1*z[i]%MOD);
    }ans=add(MOD-ans,1ll*m*(m+1)/2%MOD*n%MOD*qPow(m,n-1)%MOD);
    printf("%d\n",ans);return 0;
}