题解:P11918 [PA 2025] 考试 / Egzamin

· · 题解

背景知识请先阅读 https://www.luogu.com.cn/article/xs2uucwg

n 个随机变量 X_1,X_2,\ldots,X_n,其中 X_ip_i 的概率为 1,否则为 -1。你选择一个 k,再选择其中的 k 个随机变量,求最优决策下你选择的随机变量的和 \ge t 的概率。

首先,可以贪心地选随机变量,因为希望和尽可能大,所以优先选 p_i 大的随机变量。

将所有 p_i 排完序后,最终答案一定是一段前缀。

可以设计一个简单的 DP,记 f_{i,j} 为选长度为 i 的前缀,和为 j 的概率。

每一次和只会 +1 或者 -1,则有转移式:

f_{i,j}=p_i \cdot f_{i-1,j-1} + (1-p_i) \cdot f_{i-1,j+1}

这样做是 \Theta(n^2) 的。

考虑优化,令 \varepsilon 为误差,我们断言,对于每一个 i,使得 f_{i,j} > \varepsilon 的位置不会很多。

转移时只需舍去 f_{i,j} \le \varepsilon 的状态,精细实现即可通过本题。

下面给出这个做法的复杂度证明。

考虑对于每一个 i 分析复杂度。事实上,事件 \{所有随机变量取值为 -11,且和 \ge t\} 等价于事件 \{所有随机变量取值为 01,且和 \ge \frac{i+t}{2}\},这样我们就把问题转化为了若干次的 \text{Bernoulli} 分布,此时和的期望 \mu=\sum \limits_{k=1}^{i} p_i

容易注意到的是,满足 f_{i,j} > \varepsilonj 应该是一段连续区间,不妨设这一段区间为 [(1-B_1)\mu,(1+B_2)\mu],且 B_1,B_2>0

代入 \text{Multiplicative Chernoff Bound} 可以得到:

\textbf{Pr}\big(X \ge (1+B_2) \mu \big) \le e^{-\frac{B_2^2 \mu}{B_2+2}} \le \varepsilon \textbf{Pr}\big(X \le (1-B_1) \mu \big) \le e^{-\frac{B_1^2 \mu}{2}} \le \varepsilon

不等式两边同时取对数的相反数:

\frac{B_2^2 \mu}{B_2+2} \ge \ln \varepsilon^{-1} \frac{B_1^2 \mu}{2} \ge \ln \varepsilon^{-1}

解二次不等式组得:

B_2 \ge \frac{\ln \varepsilon^{-1} + \sqrt{\ln^2 \varepsilon^{-1}-8\mu \ln \varepsilon^{-1}}}{2\mu} B_1 \ge \sqrt{\frac{2\ln \varepsilon^{-1}}{\mu}}

舍去常数,由于 0 \le \mu \le i,所以可以视为 n,\mu 同阶,且 \ln \varepsilon^{-1} 相对于 n 是小量,所以:

B_1,B_2 \ge \sqrt{\frac{\ln \varepsilon^{-1}}{n}}

取等号时最优,得到满足条件的区间为 [\mu - \sqrt{n \ln \varepsilon^{-1}},\mu + \sqrt{n \ln \varepsilon^{-1}}],即对于每个 i,合法状态数为 \Theta(\sqrt{n \ln \varepsilon^{-1}}) 个。

综上,总时间复杂度为 \Theta(n \sqrt{n \ln \varepsilon^{-1}})

经过测试,实际 \varepsilon 的值大约取到 10^{-11},此时总时间复杂度不超过 6 \times 10^7

核心代码:

constexpr int N=5e4+5;
constexpr double eps=1e-11;
namespace Junounly
{
    int n,K;
    double p[N];
    vector<pair<int,double> > q[2];
    void main()
    {
        scanf("%d%d",&n,&K);
        for(int i=1;i<=n;i++) scanf("%lf",&p[i]);
        sort(p+1,p+n+1,greater<double>());
        q[0].emplace_back(0,1);
        double res=0;
        for(int i=1,op=0;i<=n;i++,op^=1)
        {
            for(auto [X,Y]:q[op])
            {
                int x=X-1;double y=Y*(1-p[i]);
                if(y>eps)
                {
                    if(q[op^1].size()&&q[op^1][q[op^1].size()-1].first==x) q[op^1][q[op^1].size()-1].second+=y;
                    else if(q[op^1].size()>1&&q[op^1][q[op^1].size()-2].first==x) q[op^1][q[op^1].size()-2].second+=y;
                    else q[op^1].emplace_back(x,y);
                }
                x=X+1;y=Y*p[i];
                if(y>eps) q[op^1].emplace_back(x,y);
            }
            q[op].clear();
            double now=0;
            for(auto [X,Y]:q[op^1])
                if(X>=K) now+=Y;
            res=max(res,now);
        }
        printf("%.11lf\n",res);
    }
}