P1850 [NOIP 2016 提高组] 换教室题解
思路
看到期望就想到 dp。
题目中两个教室间的距离可以用 floyd 快速求出。
设
转移方程
设
不交换
先考虑当前课不换的情况。
此时只有两种情况:上一次课换或不换。
对于上一次课不换的情况,显然有:
对于上一次课换的情况,由于上一次换课不一定成功,所以需要分类讨论。
当上一次换课成功时,两者距离为
于是有:
由于上次课换与不换的情况不能同时出现,且题目要求求期望最小值,所以两者取
交换
当前课交换的情况与不交换的情况的转移方程推导类似(大力分讨),希望读者自己推导,故不再赘述,直接放出结果:
(蒟蒻太菜式子丑见谅)
注意取第一个式子的条件是
然后就可以愉快地 AC 啦!
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dl double)
inline int read() {
int x = 0, f = 1;
char ch;
while ((ch = getchar()) < 48 || ch > 57)if (ch == '-')f = -1;
while (ch >= 48 && ch <= 57)x = x * 10 + ch - 48, ch = getchar();
return x * f;
}
char __sta[1009], __len;
inline void write(int x, bool bo) {
if (x < 0)putchar('-'), x = -x;
do __sta[++__len] = x % 10 + 48, x /= 10;
while (x);
while (__len)putchar(__sta[__len--]);
putchar(bo ? '\n' : ' ');
}
const ll V = 309, N = 2009, INF = 1e9+3, E = 90009;
ll n, m, v, e;
ll c[N], d[N];
dl p[N];
dl dp[N][N][2];//1换0不换
ll dis[V][V];
inline void init() {
for (int i = 0; i <= v; i++)
dis[i][i] = dis[0][i] = 0;
for (int k = 1; k <= v; k++) {
dis[0][k] = 0;
for (int j = 1; j <= v; j++) {
for (int i = 1; i <= v; i++) {
if (i != k && j != k)dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);
}
}
}
}
void solve() {
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= n; j++)
dp[i][j][0] = dp[i][j][1] = INF;
}
for (ll i = 1; i <= n; i++) {
for (int j = 0; j <= min(i, m); j++) {
dp[i][j][0] = min(dp[i - 1][j][0] + dis[c[i - 1]][c[i]], dp[i - 1][j][1] + dis[d[i - 1]][c[i]] * p[i - 1] + dis[c[i - 1]][c[i]] * (1 - p[i - 1]));
if (j > 0) {
dp[i][j][1] = dp[i - 1][j - 1][0] + dis[c[i - 1]][d[i]] * p[i] + dis[c[i - 1]][c[i]] * (1 - p[i]);
if (j > 1) {
dp[i][j][1] = min(dp[i][j][1], dp[i - 1][j - 1][1] + dis[c[i - 1]][c[i]] * (1 - p[i - 1]) * (1 - p[i]) +
dis[c[i - 1]][d[i]] * (1 - p[i - 1]) * p[i] +
dis[d[i - 1]][c[i]] * p[i - 1] * (1 - p[i]) +
dis[d[i - 1]][d[i]] * p[i - 1] * p[i]);
}
}
}
}
}
int main() {
n = read(), m = read(), v = read(), e = read();
for (int i = 1; i <= n; i++)
c[i] = read();
for (int i = 1; i <= n; i++)
d[i] = read();
for (int i = 1; i <= n; i++)
scanf("%lf", &p[i]);
memset(dis, 0x3f, sizeof(dis));
for (int i = 1; i <= e; i++) {
ll u = read(), p = read(), w = read();
dis[u][p] = dis[p][u] = min(w, dis[p][u]);
}
init();
solve();
dl ans = INF;
for (int i = 0; i <= m; i++)
ans = min(ans, min(dp[n][i][1], dp[n][i][0]));
printf("%.2lf\n", ans);
// for(int i=1;i<=n;i++){
// for(int j=1;j<=i;j++){
// cout<<"dhw "<<i<<' '<<j<<endl;
// cout<<dp[i][j][0]<<' '<<dp[i][j][1]<<endl;
// }
// }
return 0;
}