P11885 [RMI 2024] 跑酷 / Jump Civilization

· · 题解

宝宝题。

思路:

考虑倒序维护最短路,即每次 i + 1 \gets i,动态维护 dis_j 表示 i \to j 的最短路,答案即 dis_j \le k 的数量。

由于题目的性质,容易得到 \forall u \in (i, v_i), v_u \le v_i,这说明,从 i 出发到 (i, v_i) 内的点一定走 i \to i + 1 这条边,到 [v_i, n] 内的点一定走 i \to v_i 这条边。

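我们设 dis'_j 表示 i + 1 \to j 的最短路,则 dis 相对 dis' 的变化是:

- dis_i = 0;
- \forall j \in (i, v_i), dis_j = dis'_j + 1;
- \forall j \in [v_i, n], dis_j = dis'_j - dis'_{v_i} + 1。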

所以我们需要维护区间加减,查询全局有多少个数 \le k(初始每个 dis_i = +\inf)。

考虑分块,维护块内排序后的下标,重构时归并即可。

时间复杂度为 O(n \sqrt{n \log n})

完整代码:

 #include<bits/stdc++.h>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define fi first
#define se second
#define popcnt(x) __builtin_popcount(x)
#define open(s1, s2) freopen(s1, "r", stdin), freopen(s2, "w", stdout);
using namespace std;
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const int N = 3e5 + 10;
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character that is not a digit is skipped. A '-' negates the number
 * only when it is the character immediately before the first digit, so a
 * stray '-' earlier in the stream cannot flip the sign of a later value.
 *
 * @return The parsed value, or 0 when end of input is reached before any
 *         digit is seen (the previous version looped forever in that case).
 */
inline long long read(){
    long long x = 0;
    bool negative = false;
    // getchar() returns int; keeping it as int lets EOF be told apart from data.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')){
        negative = (c == '-');
        c = getchar();
    }
    // EOF is outside '0'..'9', so this loop also stops at end of input.
    while(c >= '0' && c <= '9'){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -x : x;
}
/**
 * Prints x in decimal to stdout using putchar, with a leading '-' for
 * negative values and no trailing separator.
 *
 * @param x Value to print; every long long value is supported, including the
 *          minimum one (negating it as a signed value would overflow).
 */
inline void write(long long x){
    unsigned long long magnitude = static_cast<unsigned long long>(x);
    if(x < 0){
        putchar('-');
        // Negate in unsigned arithmetic, which is well defined for all inputs.
        magnitude = 0ULL - magnitude;
    }
    // 2^63 has 19 decimal digits, so 20 slots are always enough.
    char digits[20];
    int len = 0;
    do{
        digits[len++] = static_cast<char>('0' + magnitude % 10);
        magnitude /= 10;
    }while(magnitude != 0);
    // Digits were produced least-significant first; emit them in reverse.
    while(len > 0)
      putchar(digits[--len]);
}
// n   : number of positions; k: distance limit used in the queries (both read in main).
// now : global offset shared by all stored distances; main increments it once per
//       processed start position.
// t   : block length of the sqrt decomposition (computed in main).
int n, k, now, t;
// v[i]   : jump target of position i, read from input for i in [1, n).
// h[]    : scratch buffer holding a rebuilt block order (used by update and add).
// a[],b[]: scratch lists in add(): shifted vs. untouched positions of one block.
// dis[p] : stored distance of position p; main reads the real distance as
//          dis[p] + tag[id[p]] + now.
// ans[i] : answer for start position i, printed by main.
// id[p]  : index of the block containing position p.
// L[x],R[x]: first and last position of block x.
// tag[x] : pending value added to every position of block x.
// f[]    : for each block, slots L[x]..R[x] hold that block's positions ordered by
//          non-decreasing dis (maintained by update and add, searched by get).
int v[N], h[N], a[N], b[N], dis[N], ans[N], id[N], L[N], R[N], tag[N], f[N];
// vis[p]: temporary mark set and cleared inside add() for the positions it shifts.
bool vis[N];
inline void update(int u, int w){
    bool flag = 1;
    int x = id[u], mid = 0, cnt = 0;
    for(int i = L[x]; i <= R[x]; ++i){
        if(f[i] == u){
            mid = i;
            break;
        }
    }
    dis[u] = w;
    for(int i = L[x]; i < mid; ++i){
        if(flag && dis[u] <= dis[f[i]])
          h[++cnt] = u, flag = 0;
        h[++cnt] = f[i];
    }
    for(int i = mid + 1; i <= R[x]; ++i){
        if(flag && dis[u] <= dis[f[i]])
          h[++cnt] = u, flag = 0;
        h[++cnt] = f[i];          
    }
    for(int i = L[x]; i <= R[x]; ++i)
      f[i] = h[i - L[x] + 1];
}
inline void add(int l, int w){
    int x = id[l];
    for(int i = l; i <= R[x]; ++i){
        vis[i] = 1;
        dis[i] += w;
    }
    int c0 = 0, c1 = 0;
    for(int i = L[x]; i <= R[x]; ++i){
        if(vis[f[i]])
          a[++c0] = f[i];
        else
          b[++c1] = f[i];
    }
    int cnt = 0, i = 1, j = 1;
    while(i <= c0 && j <= c1){
        if(dis[a[i]] <= dis[b[j]])
          h[++cnt] = a[i++];
        else
          h[++cnt] = b[j++];
    }
    while(i <= c0)
      h[++cnt] = a[i++];
    while(j <= c1)
      h[++cnt] = b[j++];
    for(int i = L[x]; i <= R[x]; ++i)
      f[i] = h[i - L[x] + 1];
    for(int i = l; i <= R[x]; ++i)
      vis[i] = 0;
    for(int i = x + 1; i <= id[n]; ++i)
      tag[i] += w;
}
inline int get(int x, int w){
    w -= tag[x];
    if(dis[f[L[x]]] > w)
      return 0;
    if(dis[f[R[x]]] <= w)
      return R[x] - L[x] + 1;
    int l = L[x], r = R[x], ans = 0;
    while(l <= r){
        int mid = (l + r) >> 1;
        if(dis[f[mid]] <= w){
            ans = mid;
            l = mid + 1;
        }
        else
          r = mid - 1;
    }
    return ans - L[x] + 1;
}
/**
 * Counts the positions p >= l with dis[p] + tag[id[p]] <= w.
 *
 * Positions l..R of l's own block are checked one by one; every later block
 * is answered by get().
 *
 * @param l First position to consider.
 * @param w Upper bound on stored distance plus block tag.
 * @return Number of qualifying positions.
 */
inline int ask(int l, int w){
    const int blk = id[l];
    const int shift = tag[blk];
    int total = 0;
    for(int p = l, stop = R[blk]; p <= stop; ++p)
      total += (dis[p] + shift <= w) ? 1 : 0;
    const int blocks = id[n];
    for(int later = blk + 1; later <= blocks; ++later)
      total += get(later, w);
    return total;
}
bool End;
int main(){
    n = read(), k = read();
    for(int i = 1; i < n; ++i){
        v[i] = read();
        f[i] = i;
    }
    f[n] = n;
    t = __builtin_sqrt(n) / 2 + 1;
    for(int i = 1; i <= n; ++i)
      id[i] = (i - 1) / t + 1;
    for(int i = 1; i <= (n + t - 1) / t; ++i){
        L[i] = (i - 1) * t + 1;
        R[i] = min(i * t, n);
    }
    ans[n] = 1;
    for(int i = n - 1; i >= 1; --i){
        int t = dis[v[i]] + tag[id[v[i]]] + now;
        ++now;
        update(i, - now - tag[id[i]]);
        add(v[i], -t);
        ans[i] = ask(i, k - now);
    }
    for(int i = 1; i <= n; ++i){
        write(ans[i]);
        putchar(' ');
    }
    //cerr << '\n' << abs(&Begin - &End) / 1048576 << "MB";
    return 0;
}