P10867 [HBCPC2024] Points on the Number Axis A 题解

· · 题解

简要题意

有一个序列 a_1,a_2,...,a_n,问若每次等概率选择 2 个数 a_i,a_j,删去这 2 个数,并插入 \frac{a_i+a_j}2,操作 n-1 后剩下的那个数期望是?

Sol

猜一手结论:平均数。

感性地理解一下说讲的很模糊?来理论地证一下!(证明在下面)

有理数的取模:\frac pq \bmod 998244353 就是 p 乘上 q998244353 的逆元。

逆元:因为除法不满足 (a\div b) \bmod p=(a \bmod p) \div (b \bmod p),所以要将除法转换为乘法。根据费马小定理有 a^{p-1} \equiv 1 \bmod p,所以 a 在模 p 意义下的逆元为 a^{p-2}

Proof

由 lzyqwq 大佬的思路启发,只是解释的啰嗦一点:

不妨设答案记作 Sum(a,n)

a 数组只有 2 项,即 n=2,显然成立。

a 数组有 3a_1,a_2,a_3,期望为

\begin{aligned} Sum(a,3) &=\frac 13 \times(\frac{\frac {a_1+a_2}2+a_3}2+ \frac{\frac {a_1+a_3}2+a_2}2+ \frac{\frac {a_2+a_3}2+a_1}2)\\ &=\frac 13 \times (a_1+a_2+a_3) \end{aligned}

就是平均数。

a 数组有 4a_1,a_2,a_3,a_4,期望为

\begin{aligned} Sum(a,4) &=\frac 16\times (Sum(\Set {\frac {a_1+a_2}2,a_3,a_4},3)+Sum(\Set {\frac {a_1+a_3}2,a_2,a_4},3)+Sum(\Set {\frac {a_1+a_4}2,a_2,a_3},3)+Sum(\Set {\frac {a_2+a_3}2,a_1,a_4},3)+Sum(\Set {\frac {a_2+a_4}2,a_1,a_3},3)+Sum(\Set {\frac {a_3+a_4}2,a_1,a_2},3))\\ &=\frac {1}{18} \times (\frac 92 \times (a_1+a_2+a_3+a_4))\\ &=\frac 14 \times (a_1+a_2+a_3+a_4) \end{aligned}

也是平均数。

那么若 a 数组有 n 项,期望为

\begin{aligned} Sum(a,n) &=\frac 1{\frac {n\times (n-1)}2}\times (Sum(\Set {\frac {a_1+a_2}2,a_3,...},n-1)+Sum(\Set {\frac {a_1+a_3}2,a_2,...},n-1)+Sum(\Set {\frac {a_1+a_4}2,a_2,...},n-1)+...+Sum(\Set {\frac {a_{n-1}+a_n}2,a_1,...},n-1))\\ &=\frac {1}{\frac {n\times (n-1)^2}2} \times ((n-1 \times \frac12+(\frac {n\times (n-1)}2-(n-1)) \times 1 ) \times (a_1+a_2+a_3+a_4+...+a_n))\\ &=\frac {1}{\frac {n\times (n-1)^2}2} \times (\frac {(n-1)^2}2) \times (a_1+a_2+a_3+a_4+...+a_n))\\ &=\frac 1n \times (a_1+a_2+a_3+a_4+...+a_n) \end{aligned}

还是平均数。证毕!

Code

因为跟 seanli1008 组队的,用他的账号交的代码,这个 Code 是我写的,没有作弊,特此声明。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define mod 998244353
int qmi(int a,int b,int p){
    int res=1;
    while(b){
        if(b&1) res=res*a%p;
        b>>=1;
        a=a*a%p;
    }
    return res;
}
signed main(){
    int n;
    cin>>n;
    int x,ans=0;
    for(int i=1;i<=n;i++){
        cin>>x;
        ans+=x;
    }
    ans=ans%mod;
    ans=ans*qmi(n,mod-2,mod)%mod;
    cout<<ans;
}