题解:P11152 [THUWC 2018] 七彩序列

· · 题解

第一步我就没注意到,其实很关键:

同时有前缀和后缀不好做,我们 把后缀的限制给弄到前缀上

具体的,称一个前缀不合法,当且仅当这个前缀每个数出现的次数为 (i,i,\cdots ,i),i\in [1,l] 或者 (a_1+i-l,\cdots ,a_n+i-l),i\in [0,l-1]

后者意义是填完这个前缀会让剩下的后缀不合法。

现在是一个标准的容斥形式,转换题意:

求从 $(0,0,\cdots ,0)$ 走到 $(a_1,a_2,\cdots ,a_n)$ 的方案数。 朴素容斥思路就是记 $f_i$ 表示走到 $i$ 号关键点不经过其他关键点的方案数。 然后转移就是扣掉被 $i$ 偏序 $j$ 的 $f_j$ 乘上组合数。 --- 但是这里我们点很有特殊性质,可以稍微写好看一点: $A_i=(i,i,\cdots ,i),i\in [1,l]$,$f_i$ 表示方案数。 $B_i=(a_1+i-l,\cdots ,a_n+i-l),i\in [0,l]$,$g_i$ 表示方案数。 - 方案数和上面一样,表示它是被经过的第一个关键点的方案数。 定义到 $g_l$ 是因为此时有 $ans=g_l$。 分类讨论下 $f,g$ 是被谁转移的,然后算容斥贡献: $$ \begin{aligned} f_i&=\begin{cases}-f_j\times w_1(i-j)\quad\quad (j\le i-1)\\-g_j\times w_2(i-j)\quad\quad (j\le i-(r-l))\end{cases} \\ \\ g_i&=\begin{cases}-g_j\times w_1(i-j)\quad\quad (j\le i-1)\\-f_j\times w_3(i-j)\quad\quad (j\,{\color{red}{\le}}\,i)\end{cases} \end{aligned} $$ 定义 $X(a_1,a_2,\cdots,a_k)=\dfrac{(\sum a_i)!}{\prod a_i!}.

w_1(t)=X(t,\cdots t),w_2(t)=X(\{t-(a_k-l)\}),w_3(t)=X(\{t+a_k-l\}).

写一下: $$ \begin{aligned} f_i&=w_1(i)-\sum\limits_{j<i} f_jw_1(i-j)-\sum\limits_{j<i} g_jw_2(i-j) \\ \\ g_i&=w_3(i)-\sum\limits_{j<i} g_jw_1(i-j)-\sum\limits_{j{\color{red}{\le}} i} f_jw_3(i-j) \end{aligned} $$ >注意 $w_1(0)=w_2(0)=0$ 并非是 $1$,因为 **dp** 转移的时候不能取这一项。然后 $w_3(0)$ 是要按定义算的。 > >这里直接定义成 $i\in [0,l]$ 即可,满足 $f_0=0,g_0=1$,十分正确。 > > 上述边界情况的讨论是为了下面写生成函数。 > > 这样生成函数写起来就不需要考虑 **Corner Case**,比如 $j<i$ 等条件直接丢掉,由于 $w_1(0)=0$ 仍然是对的。 写出生成函数,大力解方程一下。 $$ \begin{aligned} F&=W_1-W_1F-W_2G \\ G&=W_3-W_1G-W_3F \end{aligned} $$ 这里直接没有常数项了,归功于刚才定义的好。解得: $$ \begin{aligned} F &= \frac{W_1 + W_1^2 - W_2W_3}{1 + 2W_1 + W_1^2 - W_2W_3} \\ \\ G &= \frac{W_3}{1 + 2W_1 + W_1^2 - W_2W_3} \end{aligned} $$ 我们只需要算 $G$ 得到 $g_l$,直接多项式乘法加求逆算即可。复杂度 $\mathcal{O}(nV+V\log V).

:::info[代码]

// 洛谷 P11152
// https://www.luogu.com.cn/problem/P11152
#include<bits/stdc++.h>
#define LL long long
#define fr(x) freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);
using namespace std;
const int N=2e7+5,M=1<<21|5,mod=998244353;
int n,l,r,u,a[105],jc[N],inv[N],_[M],w1[M],w2[M],w3[M],w[M],ans,U,f[M],g[N];
inline int bger(int x){return x|=x>>1,x|=x>>2,x|=x>>4,x|=x>>8,x|=x>>16,x+1;}
inline int md(int x){return x>=mod?x-mod:x;}
inline int ksm(int x,int p){int s=1;for(;p;(p&1)&&(s=1ll*s*x%mod),x=1ll*x*x%mod,p>>=1);return s;}
inline void init(int U)
{
    for(int i=1,j,k;i<U;i<<=1)
        for(w[j=i]=1,k=ksm(3,(mod-1)/(i<<1)),j++;j<(i<<1);j++)
            w[j]=1ll*w[j-1]*k%mod;
}
inline void DNT(int *a,int U)
{
    for(int i,j,k=U>>1,L,*W,*x,*y,z;k;k>>=1)
        for(L=k<<1,i=0;i<U;i+=L)
            for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
                *y=1ll*(*x+mod-(z=*y))* *W%mod,*x=md(*x+z);
}
inline void IDNT(int *a,int U)
{
    for(int i,j,k=1,L,*W,*x,*y,z;k<U;k<<=1)
        for(L=k<<1,i=0;i<U;i+=L)
            for(j=0,W=w+k,x=a+i,y=x+k;j<k;j++,W++,x++,y++)
                z=1ll* *W* *y%mod,*y=md(*x+mod-z),*x=md(*x+z);
    reverse(a+1,a+U);
    for(int inv=ksm(U,mod-2),i=0;i<U;i++) a[i]=1ll*a[i]*inv%mod;
}
void INV(int num,int *a,int *b)
{
    if(num==1) return b[0]=ksm(a[0],mod-2),void();
    INV((num+1)>>1,a,b);int U=bger(num<<1);static int c[N];
    for(int i=0;i<num;i++) c[i]=a[i];for(int i=num;i<U;i++) c[i]=0;
    DNT(c,U);DNT(b,U);
    for(int i=0;i<U;i++) b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
    IDNT(b,U);for(int i=num;i<U;i++) b[i]=0;
}
int main()
{
    cin.tie(0)->sync_with_stdio(0);cin>>n;
    for(int i=1;i<=n;i++) cin>>a[i],u+=a[i];
    sort(a+1,a+1+n);l=a[1],r=a[n];
    if(l==r) return cout<<0,0;
    for(int i=*jc=1;i<=u;i++) jc[i]=1ll*jc[i-1]*i%mod;
    inv[u]=ksm(jc[u],mod-2);for(int i=u;i;i--) inv[i-1]=1ll*inv[i]*i%mod;
    for(int i=0,s,S;i<=l;i++)
    {
        _[i]=w1[i]=1ll*jc[i*n]*ksm(inv[i],n)%mod;
        if(i>=r-l)
        {
            s=0,S=1;
            for(int j=1;j<=n;j++) S=1ll*S*inv[i+l-a[j]]%mod,s+=i+l-a[j];
            w2[i]=1ll*S*jc[s]%mod;
        }
        s=0,S=1;
        for(int j=1;j<=n;j++) S=1ll*S*inv[i-l+a[j]]%mod,s+=i-l+a[j];
        w3[i]=1ll*S*jc[s]%mod;
    }w1[0]=0;
    init(U=bger((l+5)<<1));
    DNT(w1,U);DNT(w2,U),DNT(w3,U);
    for(int i=0;i<U;i++) f[i]=(1ll*w1[i]*w1[i]+1ll*(mod-w2[i])*w3[i])%mod;
    IDNT(f,U);fill(f+l+1,f+U,0);f[0]++;
    for(int i=1;i<=l;i++) f[i]=(f[i]+2ll*_[i])%mod;
    INV(l+1,f,g);DNT(g,U);
    for(int i=0;i<U;i++) g[i]=1ll*g[i]*w3[i]%mod;
    IDNT(g,U);
    return cout<<g[l],0;
}

:::