题解:AT_abc369_e [ABC369E] Sightseeing Tour

· · 题解

做法一

定义 dp_{s,i,0/1} 表示当前走过的桥的状态是 s,最后一次路过的桥是 i0/1 表示最后一次过桥的方向(u_i \rightarrow v_i 或者 v_i \rightarrow u_i)。转移时先枚举一座 s 中的桥 x,从去掉该桥的状态中再枚举一个最后一次路过的桥 y,由 dp_{s-2^x,y,0/1} 转移即可。这个做法不需要枚举全排列

时间复杂度为:O(n^3+m+q\cdot2^k\cdot k^2)

Submission #57315524 - AtCoder Beginner Contest 369

#include <bits/stdc++.h>

using namespace std;

int n, m;
long long g[405][405];
int u[200005], v[200005], w[200005];
int q;
int b[5];
long long dp[1 << 5][5][2];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        for (int j = i + 1; j <= n; j++)
            g[i][j] = g[j][i] = LLONG_MAX / 2;
    for (int i = 1; i <= m; i++) {
        cin >> u[i] >> v[i] >> w[i];
        g[u[i]][v[i]] = min<long long>(g[u[i]][v[i]], w[i]);
        g[v[i]][u[i]] = min<long long>(g[v[i]][u[i]], w[i]);
    }
    for (int k = 1; k <= n; k++)
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                g[i][j] = min(g[i][j], g[i][k] + g[k][j]);
    for (cin >> q; q--; ) {
        int k;
        cin >> k;
        for (int i = 0; i < k; i++)
            cin >> b[i];
        for (int i = 1; i < (1 << k); i++)
            for (int j = 0; j < k; j++)
                dp[i][j][0] = dp[i][j][1] = LLONG_MAX / 2;
        for (int i = 1; i < (1 << k); i++) {
            for (int x = 0; x < k; x++) {
                if ((i & (1 << x)) != 0) {
                    if ((i & (i - 1)) == 0) {
                        dp[i][x][0] = g[1][v[b[x]]] + w[b[x]];
                        dp[i][x][1] = g[1][u[b[x]]] + w[b[x]];
                        continue;
                    }
                    for (int y = 0; y < k; y++) {
                        if ((i & (1 << y)) != 0 && x != y) {
                            dp[i][x][0] = min(dp[i][x][0], min(dp[i ^ (1 << x)][y][0] + g[u[b[y]]][v[b[x]]], dp[i ^ (1 << x)][y][1] + g[v[b[y]]][v[b[x]]]) + w[b[x]]);
                            dp[i][x][1] = min(dp[i][x][1], min(dp[i ^ (1 << x)][y][0] + g[u[b[y]]][u[b[x]]], dp[i ^ (1 << x)][y][1] + g[v[b[y]]][u[b[x]]]) + w[b[x]]);
                        }
                    }
                }
            }
        }
        long long res = LLONG_MAX / 2;
        for (int i = 0; i < k; i++)
            res = min(res, min(dp[(1 << k) - 1][i][0] + g[u[b[i]]][n], dp[(1 << k) - 1][i][1] + g[v[b[i]]][n]));
        cout << res << '\n';
    }
    return 0;
}

做法二

全排列枚举过桥的顺序,然后模拟即可。看到很多人都又套了一个 O(2^k) 的状压,实际上没有必要。与做法一类似,记录上一次过桥的方向及答案即可做到 O(n^3+m+q\cdot k!\cdot k)

Submission #57335924 - AtCoder Beginner Contest 369

#include <bits/stdc++.h>

using namespace std;

int n, m;
long long g[405][405];
int u[200005], v[200005], w[200005];
int q;
int b[5];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        for (int j = i + 1; j <= n; j++)
            g[i][j] = g[j][i] = LLONG_MAX / 2;
    for (int i = 1; i <= m; i++) {
        cin >> u[i] >> v[i] >> w[i];
        g[u[i]][v[i]] = min<long long>(g[u[i]][v[i]], w[i]);
        g[v[i]][u[i]] = min<long long>(g[v[i]][u[i]], w[i]);
    }
    for (int k = 1; k <= n; k++)
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                g[i][j] = min(g[i][j], g[i][k] + g[k][j]);
    for (cin >> q; q--; ) {
        int k;
        cin >> k;
        for (int i = 0; i < k; i++)
            cin >> b[i];
        sort(b, b + k);
        long long res = LLONG_MAX;
        do {
            int pu = 1, pv = 1;
            long long ansu = 0, ansv = 0;
            for (int i = 0; i < k; i++) {
                const long long preu = ansu, prev = ansv;
                ansu = min(preu + g[pu][v[b[i]]], prev + g[pv][v[b[i]]]) + w[b[i]];
                ansv = min(preu + g[pu][u[b[i]]], prev + g[pv][u[b[i]]]) + w[b[i]];
                pu = u[b[i]];
                pv = v[b[i]];
            }
            res = min(res, min(ansu + g[pu][n], ansv + g[pv][n]));
        } while (next_permutation(b, b + k));
        cout << res << '\n';
    }
    return 0;
}

总结

实际上这个记录最后位置的优化从 TSP 问题出发可以很自然地想到,写过 Hamilton 路径/环路的话应该是比较显然的。

Update: 套了个快读以后拿到了最优解榜一,4 秒时限只跑了 60 毫秒。

Submission #57336507 - AtCoder Beginner Contest 369

#include <bits/stdc++.h>

using namespace std;

inline char nextchar() {
    static char buf[1 << 20], *p = buf, *q = buf;
    return p == q && (q = (p = buf) + fread(buf, 1, 1 << 20, stdin), p == q) ? EOF : *p++;
}

template <typename T, typename = enable_if_t<is_integral_v<T>>>
inline void read(T& x) {
    x = 0;
    char c = nextchar();
    for (; !isdigit(c); c = nextchar()) ;
    for (;  isdigit(c); c = nextchar()) x = x * 10 + c - '0';
}

template <typename T, typename... other>
inline void read(T& x, other&... y) {
    read(x);
    read(y...);
}

int n, m;
long long g[405][405];
int u[200005], v[200005], w[200005];
int q;
int b[5];

int main() {
    read(n, m);
    for (int i = 1; i <= n; i++)
        for (int j = i + 1; j <= n; j++)
            g[i][j] = g[j][i] = LLONG_MAX / 2;
    for (int i = 1; i <= m; i++) {
        read(u[i], v[i], w[i]);
        g[u[i]][v[i]] = min<long long>(g[u[i]][v[i]], w[i]);
        g[v[i]][u[i]] = min<long long>(g[v[i]][u[i]], w[i]);
    }
    for (int k = 1; k <= n; k++)
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                g[i][j] = min(g[i][j], g[i][k] + g[k][j]);
    for (read(q); q--; ) {
        int k;
        read(k);
        for (int i = 0; i < k; i++)
            read(b[i]);
        sort(b, b + k);
        long long res = LLONG_MAX;
        do {
            int pu = 1, pv = 1;
            long long ansu = 0, ansv = 0;
            for (int i = 0; i < k; i++) {
                const long long preu = ansu, prev = ansv;
                ansu = min(preu + g[pu][v[b[i]]], prev + g[pv][v[b[i]]]) + w[b[i]];
                ansv = min(preu + g[pu][u[b[i]]], prev + g[pv][u[b[i]]]) + w[b[i]];
                pu = u[b[i]];
                pv = v[b[i]];
            }
            res = min(res, min(ansu + g[pu][n], ansv + g[pv][n]));
        } while (next_permutation(b, b + k));
        printf("%lld\n", res);
    }
    return 0;
}