CF1904E 题解

· · 题解

首先我们有经典结论:距离点 x 最远的点是树的直径端点之一。那么我们每次只需要维护连通子图的直径。

考虑每次不能到达的点是一段连续的 dfn 区间,容易求出能到达的点范围也是 \mathcal O(k) 个连续的 dfn 区间。用线段树维护区间点的直径,把每一次询问可以到的点合并起来得到直径即可。

怎么用线段树维护区间点的直径?这需要结论:两个连通块合在一起的直径端点一定是原本连通块直径的端点。据此枚举端点,取最长的保留即可。

尽量写个 \mathcal O(1) 的 LCA 吧,否则合并会多带个 \log。用 st 表 LCA,时间复杂度 \mathcal O(n \log n + \sum k \log n)

下面给出的代码使用倍增 LCA,时间复杂度 \mathcal O(n \log^2 n + \sum k \log ^2 n),常数巨大,跑了 3.8 秒。

upd:倍增被 hack TLE 了,换成 st 表了。

// LUOGU_RID: 139275498
/**
 *    author: sunkuangzheng
 *    created: 12.12.2023 07:52:35
**/
#include<bits/stdc++.h>
#ifdef DEBUG_LOCAL
#include <mydebug/debug.h>
debug_helper deg;
#endif
using namespace std;
const int N = 5e5+5;
int n,fa[N][23],dep[N],dfn[N],nfd[N],st[N][22],tot,siz[N],m,u,v,k,x;vector<int> g[N];
int cmp(int u,int v){return dfn[u] < dfn[v] ? u : v;}
inline void dfs(int u,int f){
    fa[u][0] = f,dep[u] = dep[f] + 1,dfn[u] = ++tot,nfd[tot] = u,siz[u] = 1,st[dfn[u]][0] = f;
    for(int i = 1;i <= 22;i ++) fa[u][i] = fa[fa[u][i-1]][i-1];
    for(int v : g[u]) if(v != f) dfs(v,u),siz[u] += siz[v];
}inline int lca(int u,int v){
    if(u == v) return u;
    if((u = dfn[u]) > (v = dfn[v])) swap(u,v);
    int k = __lg(v - u);
    return cmp(st[u+1][k],st[v-(1<<k)+1][k]);
}int kfa(int u,int k){
    for(int i = 22;i >= 0;i --) if((k >> i) & 1) u = fa[u][i];
    return u;
}struct node{int u,v;}t[N*4];
inline int dis(node x){return dep[x.u] + dep[x.v] - 2 * dep[lca(x.u,x.v)];}
inline node mg(node a,node b){
    if(!a.u) return b; if(!b.u) return a;
    vector<int> acc = {a.u,b.u,a.v,b.v};int mx = -1;node ans = {0,0};
    for(int i : acc) for(int j : acc) if(auto x = ((node){i,j});dis(x) > mx) mx = dis(x),ans = x;
    return ans;
}inline void build(int s,int l,int r){
    if(l == r) return t[s] = (node){nfd[l],nfd[l]},void();
    int mid = (l + r) / 2;build(s*2,l,mid),build(s*2+1,mid+1,r);
    t[s] = mg(t[s*2],t[s*2+1]);
}inline node qry(int s,int l,int r,int ql,int qr){
    if(ql <= l && r <= qr) return t[s];
    int mid = (l + r) / 2;
    if(qr <= mid) return qry(s*2,l,mid,ql,qr); if(ql > mid) return qry(s*2+1,mid+1,r,ql,qr);
    return mg(qry(s*2,l,mid,ql,qr),qry(s*2+1,mid+1,r,ql,qr));
}int main(){
    ios::sync_with_stdio(0),cin.tie(0);
    cin >> n >> m;vector<pair<int,int>> acc,cjr;
    for(int i = 1;i < n;i ++) cin >> u >> v,g[u].push_back(v),g[v].push_back(u);
    dfs(1,0);
    for(int j = 1;j <= __lg(n);j ++) for(int i = 1;i + (1 << j) - 1 <= n;i ++)
        st[i][j] = cmp(st[i][j-1],st[i+(1<<(j-1))][j-1]);
    build(1,1,n);
    auto add = [&](int l,int r){if(l <= r) acc.emplace_back(l,r);};
    while(m --){
        cin >> u >> k,acc.clear(),cjr.clear();
        while(k --){
            cin >> x;
            if(lca(u,x) == x){
                int v = kfa(u,dep[u] - dep[x] - 1);
                add(1,dfn[v] - 1),add(dfn[v] + siz[v],n);
            }else add(dfn[x],dfn[x] + siz[x] - 1); 
        }sort(acc.begin(),acc.end());
        int lst = 0;node ans = {0,0};
        for(auto [l,r] : acc){
            if(l > lst + 1) cjr.emplace_back(lst + 1,l - 1);
            lst = max(lst,r);
        }if(lst != n) cjr.emplace_back(lst+1,n);
        for(auto [l,r] : cjr) ans = mg(ans,qry(1,1,n,l,r));
        node p = (node){ans.u,u},q = (node){ans.v,u};cout << max(dis(p),dis(q)) << "\n";
    }
}

题外话:不知道为什么官方题解提供的两种解法都用欧拉序,dfs 序不是更好理解吗 qwq。