P1850 [NOIP 2016 提高组] 换教室题解

· · 题解

思路

看到期望就想到 dp。
题目中两个教室间的距离可以用 floyd 快速求出。

dp_{i,j,0/1} 表示到第 i 个课程,换了 j 次课(包括这一次),当前课是否选择换课的期望距离最小值。

转移方程

dis_{i,j} 表示第 i 个教室和第 j 个教室间的距离,第 i 个课交换成功的概率为 p_i

不交换

先考虑当前课不换的情况。

此时只有两种情况:上一次课换或不换。

对于上一次课不换的情况,显然有:

dp_{i,j,0}=dp_{i-1,j,0}+dis_{c_{i-1},c_{i}}

对于上一次课换的情况,由于上一次换课不一定成功,所以需要分类讨论。

当上一次换课成功时,两者距离为 dis_{d_{i-1},c_{i}},失败则为 dis_{c_{i-1},c_{i}}。概率分别为 p_{i-1}1-p_{i-1}

于是有:

dp_{i,j,0}=dp_{i-1,j,1}+dis_{d_{i-1},c_{i}}p_{i-1}+dis_{c_{i-1},c_{i}}\cdot (1-p_{i-1})

由于上次课换与不换的情况不能同时出现,且题目要求求期望最小值,所以两者取 \min 即可。

交换

当前课交换的情况与不交换的情况的转移方程推导类似(大力分讨),希望读者自己推导,故不再赘述,直接放出结果:

dp_{i,j,1}=\min(dp_{i-1,j-1,0}+dis_{c_{i-1},c_{i}}\cdot (1-p_{i-1})+dis_{c_{i-1},d_{i}}p_{i-1},\\ dp_{i-1,j-1,1}+\\ dis_{c_{i-1},d_i}\cdot(1-p_{i-1})p_i+dis_{d_{i-1},c_i}p_{i-1}\cdot(1-p_i)+\\dis_{c_{i-1},c_i}\cdot(1-p_{i-1})\cdot(1-p_i)+dis_{d_{i-1},d_{i}}p_{i-1}p_i

蒟蒻太菜式子丑见谅
注意取第一个式子的条件是 j\geq 1(即至少包含当前课),取第二个式子的条件是 j\geq 2(即至少包含当前课和上一节课)。

然后就可以愉快地 AC 啦!

代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dl double)
inline int read() {
    int x = 0, f = 1;
    char ch;
    while ((ch = getchar()) < 48 || ch > 57)if (ch == '-')f = -1;
    while (ch >= 48 && ch <= 57)x = x * 10 + ch - 48, ch = getchar();
    return x * f;
}
char __sta[1009], __len;
inline void write(int x, bool bo) {
    if (x < 0)putchar('-'), x = -x;
    do __sta[++__len] = x % 10 + 48, x /= 10;
    while (x);
    while (__len)putchar(__sta[__len--]);
    putchar(bo ? '\n' : ' ');
}
const ll V = 309, N = 2009, INF = 1e9+3, E = 90009;
ll n, m, v, e;
ll c[N], d[N];
dl p[N];
dl dp[N][N][2];//1换0不换
ll dis[V][V];
inline void init() {
    for (int i = 0; i <= v; i++)
        dis[i][i] = dis[0][i] = 0;
    for (int k = 1; k <= v; k++) {
        dis[0][k] = 0;
        for (int j = 1; j <= v; j++) {
            for (int i = 1; i <= v; i++) {
                if (i != k && j != k)dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);
            }
        }
    }
}
void solve() {
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j <= n; j++)
            dp[i][j][0] = dp[i][j][1] = INF;
    }
    for (ll i = 1; i <= n; i++) {
        for (int j = 0; j <= min(i, m); j++) {
            dp[i][j][0] = min(dp[i - 1][j][0] + dis[c[i - 1]][c[i]], dp[i - 1][j][1] + dis[d[i - 1]][c[i]] * p[i - 1] + dis[c[i - 1]][c[i]] * (1 - p[i - 1]));
            if (j > 0) {
                dp[i][j][1] = dp[i - 1][j - 1][0] + dis[c[i - 1]][d[i]] * p[i] + dis[c[i - 1]][c[i]] * (1 - p[i]);
                if (j > 1) {
                    dp[i][j][1] = min(dp[i][j][1], dp[i - 1][j - 1][1] + dis[c[i - 1]][c[i]] * (1 - p[i - 1]) * (1 - p[i]) +
                                                                         dis[c[i - 1]][d[i]] * (1 - p[i - 1]) * p[i] +
                                                                         dis[d[i - 1]][c[i]] * p[i - 1] * (1 - p[i]) +
                                                                         dis[d[i - 1]][d[i]] * p[i - 1] * p[i]);
                }
            }
        }
    }
}
int main() {
    n = read(), m = read(), v = read(), e = read();
    for (int i = 1; i <= n; i++)
        c[i] = read();
    for (int i = 1; i <= n; i++)
        d[i] = read();
    for (int i = 1; i <= n; i++)
        scanf("%lf", &p[i]);
    memset(dis, 0x3f, sizeof(dis));
    for (int i = 1; i <= e; i++) {
        ll u = read(), p = read(), w = read();
        dis[u][p] = dis[p][u] = min(w, dis[p][u]);
    }
    init();
    solve();
    dl ans = INF;
    for (int i = 0; i <= m; i++)
        ans = min(ans, min(dp[n][i][1], dp[n][i][0]));
    printf("%.2lf\n", ans);
//  for(int i=1;i<=n;i++){
//  for(int j=1;j<=i;j++){
//  cout<<"dhw "<<i<<' '<<j<<endl;
//  cout<<dp[i][j][0]<<' '<<dp[i][j][1]<<endl;
//  }
//  }
    return 0;
}