题解:P11873 最大拟合

· · 题解

最大拟合

题意

只有小学数学知识的我看不懂一点。

记叫声为 0, 1 以及结束标志为 2

简单来说就是构造一些三元组 (u, v, w) (u\in[0, 1], v\in[0, 1], w\in[0, 2]),使其值不小于 0,不大于 1,且 (u, v, 0)+(u, v, 1)+(u, v, 2)=1。即 cnt_{u, v, w}ss_i=u, s_{i+1}=v, s_{i+2}=w (i\in[1, n-1]) 的个数,要求最大化 w=\prod_{u\in\{0, 1\}, v\in\{0, 1\}, w\in\{0, 1, 2\}, cnt_{u, v, w}>0} (u, v, w)^{cnt_{u, v, w}},输出 \ln w

解法

如果直接乘,精度损失也会导致算出了答案求不出对数。而题中要求输出对数就是一个提示,即把乘法操作换成对数操作。则 w=\sum cnt_{u, v, w}\ln(u, v, w)。而 cnt 是容易统计的,因此重点在考虑对数操作。首先对于每一对 u, v,有 \sum (u, v, w)=1,因此考虑对每一对 u, v 分别计算贡献。若 (u, v, w) 非最后的三元组,则只有 w01(u, v, w) 才有可能大于 0。先考虑这种情况。

我们可以把它写成 \max_{a>0, b>0, 0<p<1} a\ln(p)+b\ln(1-p)。对 \max 里的式子求导得 \frac a{p}-\frac b{(1-p)},使其为 0,解得 p=\frac a{a+b}。而观察函数图像,是中间高两边低的,所以导数为 0 时取到最大值。即 p=\frac a{a+b} 时最大。

三元时同理,结论与二元类似。设其系数为 a, b, c,则 p_a=\frac a{a+b+c}, p_b=\frac b{a+b+c}, p_c=\frac c{a+b+c}

Code

#include <iostream>
#include <algorithm>
#include <string.h>
#include <iomanip>
#include <bitset>
#include <math.h>
#include <string>
#include <vector>
#include <queue>
#include <set>
#include <map>
#define fst first
#define scd second
#define db double
#define ll long long
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector <int>
#define pii pair <int, int>
#define sz(x) ((int)x.size())
#define ms(f, x) memset(f, x, sizeof(f))
#define L(i, j, k) for (int i=(j); i<=(k); ++i)
#define R(i, j, k) for (int i=(j); i>=(k); --i)
#define ACN(i, H_u) for (int i=H_u; i; i=E[i].nxt)
using namespace std;
template <typename INT> void rd(INT &res) {
    res=0; bool f=false; char ch=getchar();
    while (ch<'0'||ch>'9') f|=ch=='-', ch=getchar();
    while (ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch^48), ch=getchar();
    res=(f?-res:res);
}
template <typename INT, typename...Args>
void rd(INT &x, Args &...y) { rd(x), rd(y...); }
//dfs
const int maxn=1e5;
const int N=maxn+10;
char s[N]; int n, a[N], cnt[2][2][3];
//wmr

//incra

//lottle
signed main() {
//  ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
//  freopen(".in", "r", stdin);
//  freopen(".out", "w", stdout);
    scanf("%s", s+1); n=strlen(s+1);
    L(i, 1, n) a[i]=(s[i]=='H'?0:1); a[++n]=2;
    L(i, 1, n-2) ++cnt[a[i]][a[i+1]][a[i+2]];
    db ans=0;
    L(i, 0, 1) L(j, 0, 1) {
        int x=cnt[i][j][0], y=cnt[i][j][1], z=cnt[i][j][2];
        if (x+y+z) {
            if (x) ans+=x*log(1.0*x/(x+y+z));
            if (y) ans+=y*log(1.0*y/(x+y+z));
            if (z) ans+=z*log(1.0*z/(x+y+z));
        }
    }
    printf("%.6lf\n", ans);
    return 0;
}