题解 P10868【[HBCPC2024] Points on the Number Axis B】

· · 题解

Points on the Number Axis B

很牛逼的一道题。

由期望的线性性,我们可以单独考虑每个 x_i 对答案的贡献,这个贡献一定是 k\cdot x_i 的形式,而 k 只跟 i 有关,考虑求出每个 k 的期望。

f_{i,j} 表示当前数前面有 i 个数,后面有 j 个数时系数的期望。按照不断消除的过程,最后一定只剩一个数,所以初始状态为 f_{0,0}=1

由于对称性,先考虑消除的是前面的数。若选到前 i-1 个数,则系数不变。若选到第 i 个数,则系数乘 \dfrac{1}{2}

消除后半部分同理,可以得到转移:

f_{i,j}\leftarrow \dfrac{i-1}{i+j}\,f_{i-1,j}+\dfrac{1}{2}\cdot\dfrac{1}{i+j}\,f_{i-1,j}+\dfrac{j-1}{i+j}\,f_{i,j-1}+\dfrac{1}{2}\cdot\dfrac{1}{i+j}\,f_{i,j-1}

化简一下可以得到:

f_{i,j}\leftarrow \dfrac{1}{i+j}\left(i-\dfrac{1}{2}\right)\,f_{i-1,j} +\dfrac{1}{i+j}\left(j-\dfrac{1}{2}\right)\,f_{i,j-1}

发现这个 i+j 在做无用功,不妨设 g_{i,j}=(i+j)!\,f_{i,j},可以得到:

g_{i,j}\leftarrow \left(i-\dfrac{1}{2}\right)\,g_{i-1,j} +\left(j-\dfrac{1}{2}\right)\,g_{i,j-1}

(i,j) 当成网格图上的点,发现 g_{i,j} 本质上就是网格图路径计数。具体地说,把 \left(i-\dfrac{1}{2}\right)\left(j-\dfrac{1}{2}\right) 当成网格图上的边权,对于一个点 (i,j),到这个点的所有合法路径上的边权积都是:

\left(\prod_{k=1}^{i}\left(k-\dfrac{1}{2}\right)\right)\left(\prod_{k=1}^j\left(k-\dfrac{1}{2}\right)\right)

而方案数是组合数 \dbinom{i+j}{i},所以 g_{i,j} 可以直接计算:

g_{i,j}=\dbinom{i+j}{i}\left(\prod_{k=1}^{i}\left(k-\dfrac{1}{2}\right)\right)\left(\prod_{k=1}^j\left(k-\dfrac{1}{2}\right)\right)

只需要预处理 n!\prod\left(k-\dfrac{1}{2}\right) 就可以 \mathcal{O}(n) 计算答案了。

#include<bits/stdc++.h>
typedef long long ll;
typedef long double ld;
using namespace std;
const int N=1000010,lpw=998244353,inv2=499122177;
inline int max(int x,int y){return x>y?x:y;}
inline int min(int x,int y){return x<y?x:y;}
inline void swap(int &x,int &y){x^=y^=x^=y;}
int n,ans,fac[N],inv[N],ffac[N];
int qpow(int x,int k){
    int res=1;
    while(k){
        if(k&1)res=1ll*res*x%lpw;
        k>>=1;x=1ll*x*x%lpw;
    }
    return res;
}
int c(int n,int k){
    if(n<k)return 0;
    return 1ll*fac[n]*inv[n-k]%lpw*inv[k]%lpw;
}
int main(){
    scanf("%d",&n);
    fac[0]=inv[0]=ffac[0]=1;
    for(int i=1;i<N;i++){
        ffac[i]=1ll*ffac[i-1]*(i-inv2+lpw)%lpw;
        fac[i]=1ll*fac[i-1]*i%lpw;
    }
    inv[N-1]=qpow(fac[N-1],lpw-2);
    for(int i=N-2;i>=1;i--)
        inv[i]=1ll*inv[i+1]*(i+1)%lpw;
    for(int i=1;i<=n;i++){
        int x,res;
        scanf("%d",&x);
        res=c(n-1,i-1);
        res=1ll*res*ffac[i-1]%lpw;
        res=1ll*res*ffac[n-i]%lpw;
        res=1ll*res*inv[n-1]%lpw;
        ans=(ans+1ll*res*x%lpw)%lpw;
    }
    printf("%d\n",ans);
    return 0;
}