P6689 序列

· · 题解

一个显然的结论是,右括号数目相同的括号序列出现的概率相同。

f_{i,j}表示当前参数为iS中右括号的数目是j

则,

f_{i-1,j}=\frac{N-j+1}{N}\sum_{k=j-1}f_{i,k}\prod_{l=j}^{k}\frac{l}{N}

容易用前缀和优化得到全部的f_{0,d}。计算出了最终有d个括号序列的概率,除以个数,即可得到每一种有d个右括号的括号序列的概率。

接下来的问题转化为对长度给定的最长合法括号子序列计数。

考虑怎么求解一个括号序列的最长合法子序列长度,给出结论:

sum_i表示令'('1')' -1的前缀和,则有ans=N-sum_n+2min\space sum_i

证明的话,考虑令x表示sum_x最小的位置,那么在此之前有-sum_x个右括号是不能匹配的,之后又有(sum_n-sum_x)个左括号不能匹配。减去不匹配的,即合法的。

问题即,为长度为N,共有给定个\pm 1,对前缀和不小于某一给定值的序列计数。考虑其组合意义,相当于初始在(0,0),每次横坐标+1,纵坐标\pm 1,求出不经过某一条直线到达某一个点的方案数。折线法即可。

总复杂度O(NK+N^2)

#include<bits/stdc++.h>
using namespace std;
template<typename T>inline T read(){
    T f=0,x=0;char c=getchar();
    while(!isdigit(c)) f=c=='-',c=getchar();
    while(isdigit(c)) x=x*10+c-48,c=getchar();
    return f?-x:x;
}
namespace run{
    const int N=5009,mod=998244353;
    inline int add(int x,int y){return x+y>=mod?x-mod+y:x+y;}
    inline int sub(int x,int y){return x>=y?x-y:x+mod-y;}
    inline int qpow(int x,int y){
        int ret=1;
        while(y){
            if(y&1) ret=1LL*x*ret%mod;
            x=1LL*x*x%mod,y>>=1;
        }
        return ret;
    }
    int fac[N],ifac[N];
    inline int C(int n,int m){
        if(n<0 || m<0 || n<m) return 0;
        return 1LL*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
    }inline int calc(int n,int d,int m){return sub(C(n,n-d-m),C(n,n-d-m+1));}

    int f[N][N],n,k,inv;
    int main(){
        n=read<int>(),k=read<int>(),inv=qpow(n,mod-2);
        fac[0]=ifac[0]=ifac[1]=1;
        for(int i=1;i<=n;i++) fac[i]=1LL*fac[i-1]*i%mod;
        ifac[n]=qpow(fac[n],mod-2);
        for(int i=n-1;i>=1;i--) ifac[i]=1LL*ifac[i+1]*(i+1)%mod;
        assert(1LL*fac[n-1]*ifac[n-1]%mod==1);

        f[0][0]=1;
        for(int i=1;i<=k;i++){
            int sum=f[i-1][n];
            for(int j=min(n,i);j;j--){
                sum=(1LL*sum*j%mod*inv+f[i-1][j-1])%mod;
                f[i][j]=1LL*inv*(n-j+1)%mod*sum%mod;
            }
        }
        int chk=0;
        for(int i=0;i<=n;i++) chk=add(chk,f[k][i]);
        assert(chk==1);

        for(int i=1;i<=n;i++) f[k][i]=1LL*f[k][i]*qpow(C(n,i),mod-2)%mod;
        int ans=0;
        for(int i=1;i<=n;i++)
            for(int j=2;j<=n && min(i,n-i)>=j/2;j+=2)
                ans=(1LL*f[k][i]*calc(n,i,(j-2*i)/2)%mod*j+ans)%mod;
        printf("%d\n",ans);
        return 0;
    }
}
int main(){
#ifdef my
    freopen("sequence.in","r",stdin);
    freopen("sequence.out","w",stdout);
#endif
    return run::main();
}