P11627 [迷宫寻路 Round 3] 游戏

· · 题解

思路:

容易发现在树上,必须经过点 t 的最短路径长度 \operatorname{dist}(u, v) = \operatorname{dis}(u, t) + \operatorname{dis}(v, t)

故有:

\begin{aligned} \sum_{u = 1}^n \sum_{v = 1}^n \operatorname{dist}(u, v) &= \sum_{u = 1}^n \sum_{v = 1}^n \operatorname{dis}(u, t) + \operatorname{dis}(v, t) \\ &= 2n \sum_{u = 1}^n \operatorname{dis}(u, t) \end{aligned}

故我们只需要最大化:

\sum_{u = 1}^n \operatorname{dis}(u, t)

钦定 t 为跟,则原式化为:

\sum_{u = 1}^n dep_t

容易预处理出 cnt_i 表示 i 这条边被经过的次数,设 w_i 表示边权,则要最大化:

\sum_{i = 1}^{n - 1} cnt_i w_i

由于 w 是一个排列,故考虑贪心即可,按照 cnt 的值从大到小分配 (n - 1) \to 1

这样就可以做到 O(N^2)

考虑 cnt_i 本质是什么?设 i 这条边为 (u, v),且以 t 为根时 v 较深一些,故有:

cnt_i = siz_v

故有式子为:

\sum_{i = 1}^{n - 1} siz_{v_i} w_i

由于我们需要枚举 t,故考虑换根 dp。

设从 fa_u 换到 u,发现修改了 siz 的点很少,只有 u, fa_u 这两个点;其中 siz_{fa_u} \gets siz_{fa_u} - siz_u, siz_u \gets n

即相当于在原来 \{siz_i\} 集合中删除 siz_u,加入新的 siz'_{fa_u}

故我们现在只需要做单点修改,全局下面式子的最大值:

\sum_{i = 1}^{n - 1} siz_{v_i} w_i

考虑维护值域线段树 [1, n],对于每个区间 [l, r] 维护有多少个 siz_i 在这里面(设有 cnt 个),以及这些 siz 的和 sum,以及考虑 w1 \sim cnt 的 排列时的最大值。

合并两个区间 cnt_l, sum_l, ans_l, cnt_r, sum_r, ans_r 时,显然有:

cnt = cnt_l + cnt_r sum = sum_l + sum_r

对于右区间的 w,显然要由 1 \sim cnt_r 的排列变为 cnt_l + 1 \sim cnt_l + cnt_r + 1 的排列,即 w 整体添加了 cnt_l,故:

ans = ans_l + ans_r + cnt_l sum_r

时间复杂度为 O(N \log N)

总体认为下位紫,最后我认为并不好想到值域线段树的做法,比较显然的应该是楼下老哥的平衡树做法。

完整代码:

#include<bits/stdc++.h>
#define lowbit(x) x & (-x)
#define pi pair<ll, ll>
#define ls(k) k << 1
#define rs(k) k << 1 | 1
#define fi first
#define se second
using namespace std;
typedef __int128 __;
typedef long double lb;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
const int N = 1e5 + 10;
inline ll read(){
    ll x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-')
          f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c ^ 48);
        c = getchar();
    }
    return x * f;
}
inline void write(ll x){
    if(x < 0){
        putchar('-');
        x = -x;
    }
    if(x > 9)
      write(x / 10);
    putchar(x % 10 + '0');
}
struct St{
    int data;
    int id;
    inline bool operator<(const St&rhs)const{
        if(data ^ rhs.data)
          return data > rhs.data;
        return id > rhs.id;
    }
}A[N];
struct Node{
    int l, r;
    int cnt;
    ll sum, ans;
}X[N << 2];
ll ans;
int n, id;
int siz[N], p[N], w[N];
vector<pair<int, int>> E[N];
inline void add(int u, int v, int id){
    E[u].push_back({v, id});
    E[v].push_back({u, id});
}
inline void pushup(int k){
    X[k].cnt = X[k << 1].cnt + X[k << 1 | 1].cnt;
    X[k].sum = X[k << 1].sum + X[k << 1 | 1].sum;
    X[k].ans = X[k << 1].ans + X[k << 1 | 1].ans + 1ll * X[k << 1].cnt * X[k << 1 | 1].sum;
}
inline void build(int k, int l, int r){
    X[k].l = l, X[k].r = r;
    if(l == r)
      return ;
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
}
inline void update(int k, int i, int v){
    if(X[k].l == i && i == X[k].r){
        X[k].cnt += v;
        X[k].sum += v * i;
        if(v == 1)
          X[k].ans += 1ll * X[k].cnt * i;
        else
          X[k].ans -= 1ll * (X[k].cnt + 1) * i;
        return ;
    }
    int mid = (X[k].l + X[k].r) >> 1;
    if(i <= mid)
      update(k << 1, i, v);
    else
      update(k << 1 | 1, i, v);
    pushup(k);
}
inline void dfs1(int u, int fa){
    siz[u] = 1;
    for(auto t : E[u]){
        int v = t.fi, w = t.se;
        if(v == fa)
          continue;
        p[v] = w;
        dfs1(v, u);
        siz[u] += siz[v];
    }
}
inline void dfs2(int u, int fa){
    for(auto t : E[u]){
        int v = t.fi;
        if(v == fa)
          continue;
        int pre = siz[v];
        update(1, siz[v], -1);
        update(1, n - siz[v], 1);
        siz[u] -= siz[v];
        siz[v] = n;
        if(X[1].ans > ans){
            ans = X[1].ans;
            id = v;
        }
        else if(X[1].ans == ans)
          id = min(id, v);
        dfs2(v, u);
        siz[v] = pre;
        siz[u] = n;
        update(1, n - siz[v], -1);
        update(1, siz[v], 1);
    }
}
int main(){
    n = read();
    for(int u, v, i = 1; i < n; ++i){
        u = read(), v = read();
        add(u, v, i);
    }
    build(1, 1, n);
    dfs1(1, 1);
    for(int i = 2; i <= n; ++i)
      update(1, siz[i], 1);
    ans = X[1].ans;
    dfs2(1, 1);
    write(2ll * n * ans);
    putchar('\n');
    write(id);
    putchar('\n');
    dfs1(id, id);
    int l = 0;
    for(int i = 1; i <= n; ++i){
        if(i == id)
          continue;
        A[++l] = {siz[i], p[i]};
    }
    sort(A + 1, A + n);
    for(int i = 1; i < n; ++i)
      w[A[i].id] = n - i;
    for(int i = 1; i < n; ++i){
        write(w[i]);
        putchar(' ');
    }
    return 0;
}