题解:P8145 [JRKSJ R4] kth

· · 题解

很牛的题啊,感觉自己的思维需要更深一层了。

Hint

::::info[Hint 1] 题意可以转化成在一个排列上走。那么考虑走的过程中需要什么?以及如何实现这个走的过程。

:::success[Answer 1] 走的时候肯定是需要从 ij 个点(包括 i 这个点)的方案数的,设其为 f_{i, j}。走的过程就是每一次判断走小的那一边是否方案数 \ge k,若满足,就走小的那边;否则,走大的那边,且 k 减去 f_{mx, j}。 :::

::::

::::info[Hint 2] 考虑 n \ge 3 时,从任意一个点走至少多少步,都满足方案数 \ge k

:::success[Answer 2] 答案是 j \ge 2\log_2 k,且仅在 n = 3 时取到下界。 :::

::::

::::info[Hint 3] 考虑是不是可以利用 Hint 2 优化一下 dp 的状态数。 ::::

::::info[Hint 4] 对于任意的 j \le \min\{i, n - i + 1\}f_{i, j} 的值是多少?

:::success[Answer 4] 答案是 2^{j - 1}。 :::

::::

::::info[Hint 5] 能否利用 Hint 4 再次优化一下 dp 状态数。 ::::

Solution

题意就转化成在一个排列上任取起点开始走 m 步,求出所有路径中字典序第 k 小的路径的权值和。

首先自然的,想到 f_{i, j} 表示从 ij 个点(包括 i 这个点)的方案数。求最终答案是容易的,不会的可以看下 Hint 1。时间复杂度 O(nm)

考虑优化。我们发现,在 n(n \ge 3) 时,对于任意的正整数 k,都有在 j \ge 2\log_2 k 时(且仅在 n = 3 时取到下界),f_{i, j} \ge k。那就可以 O(n\log_2 k) 求出所有对我们有用的 f 值。设 B = \lceil2\log_2 k\rceil,考虑前 k - B 步怎么走。若 k - B > n,那么最后一定在某一个 ii + 1 之间交替走;若 k - B \le n,那就直接暴力走即可。所以我们直接 O(n) 就可以求出 k - B 步的答案以及最后的停留位置。后面 B 步的过程和一开始暴力做法一样。时间复杂度 O(n\log_2 k)。竟然还没完。

接下来,考虑现在瓶颈是在求 f 上,继续考虑如何优化这个过程。简单但或许并不是那么显然的想到,对于 i \in \left[B, n - B + 1\right],其 f_{i, j}(1 \le j \le B) 就直接是 2^{j - 1}。所以我们只用单独对 i \in \left[1, B\right)\cup\left(n - B + 1, n\right]f 值即可。

别忘了特判一下 n = 1n = 2 的情况。

最终,时间复杂度 O(n + \log_2^2k)

:::success[AC Code]

#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define mp(Tx, Ty) make_pair(Tx, Ty)
#define For(Ti, Ta, Tb) for(auto Ti = (Ta); Ti <= (Tb); Ti++)
#define Dec(Ti, Ta, Tb) for(auto Ti = (Ta); Ti >= (Tb); Ti--)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define range(Tx) begin(Tx),end(Tx)
const int N = 2e7 + 5;
const __int128 INF = 1e30;
int n, K = 125, p[N], mp[N], W[N];
long long m, k;
__int128 f[300][130], fac2[130];
__int128 get(int x, long long k) {
    if (x <= 0 || x > n) return 0;
    if (k > K) return INF;
    if (K < x && x < n - K + 1) return fac2[k - 1];
    return f[mp[x]][k];
}
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    fac2[0] = 1;
    For(i, 1, 125) fac2[i] = fac2[i - 1] * 2;
    cin >> n >> m >> k;
    For(i, 1, n) cin >> p[i];
    For(i, 1, n) W[p[i]] = i;
    int w = 0;
    For(i, 1, n) if (p[i] == 1) w = i;
    if (n == 1) { cout << -1; return 0; }
    if (n == 2) {
        if (k > 2) cout << -1;
        else if (k == 1) cout << (unsigned int)((m + 1) / 2 + m / 2 * 2);
        else cout << (unsigned int)((m + 1) / 2 * 2 + m / 2);
        return 0;
    }
    p[0] = p[n + 1] = 1e9;
    vector<int> G;
    int L = min(n, K);
    For(i, 1, L) G.push_back(i), mp[i] = G.size() - 1;
    For(i, max(L + 1, n - K + 1), n) G.push_back(i), mp[i] = G.size() - 1;
    int sz = G.size();
    For(i, 0, sz - 1) f[i][1] = 1;
    For(i, 2, K) For(j, 0, sz - 1) f[j][i] = get(G[j] - 1, i - 1) + get(G[j] + 1, i - 1);
    unsigned int ans = 0;
    w = 0;
    For(i, 1, n) {
        if (get(W[i], m) < k) k -= get(W[i], m);
        else { w = W[i]; break; }
    }
    if (w == 0) { cout << -1; return 0; }
    int minn = (p[w - 1] < p[w + 1] ? w - 1 : w + 1);
    m--;
    int S = w;
    long long yu = m - K;
    ans = p[S];
    if (yu > 0) {
        while (yu) {
            int x = (p[minn - 1] < p[minn + 1] ? minn - 1 : minn + 1);
            ans += p[minn];
            yu--;
            if (x == w || yu == 0) break;
            w = minn, minn = x;
        }
        ans += (unsigned int)p[w] * (unsigned int)((yu + 1) / 2) 
            +  (unsigned int)p[minn] * (unsigned int)(yu / 2);
        if (yu & 1) S = w;
        else S = minn;
    }
    m = min(m, 1ll * K);
    Dec(i, m, 1) {
        if (S == 1) ans += p[++S];
        else if (S == n) ans += p[--S];
        else {
            if (p[S - 1] < p[S + 1]) {
                if (get(S - 1, i) >= k) ans += p[--S];
                else k -= get(S - 1, i), ans += p[++S];
            } else {
                if (get(S + 1, i) >= k) ans += p[++S];
                else k -= get(S + 1, i), ans += p[--S];
            }
        }
    }
    cout << ans;
    return 0;
}

:::