题解:CF1859F Teleportation in Byteland

· · 题解

Teleportation in Byteland

学校联测搬了这道题,赛时脑抽了,没想到树剖是一款。

1. O(\sum \text{dis}(s,t)\log V)

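发现最多只会往下 \log V 次(第 k 层中边权 w 变为 \lceil w/2^k\rceil,当 2^k\ge V 时所有边权都已是 1,再往下没有意义),并且在同一个点做完全部的向下必然是不劣的。因此可以考虑枚举向下的次数 k,答案就是由原路径加上另一条分支组成的。设 S 为可以向下的点的集合,\text{dis}(s,t,k) 表示点 s 到点 t 在第 k 层(即每条边权为 \lceil w/2^k\rceil 时)的距离,则答案为: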

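\min\left(\text{dis}(s,t,0),\ \min_{k\ge 1}\left(kT+\min_{p\in S} \text{dis}(s,p,0)+\text{dis}(p,t,k)\right)\right)

其中 T 为每向下一层的代价(代码中为 tim)。由于 kT 只与 k 有关,下文推导中省略这一项。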

继续拆式子:

\min_k\min_{p\in S} \text{dis}(s,p,0)+\text{dis}(p,t,k)\\ =\min_k\min_{q\in \text{Road}(s,t)}\text{dis}(s,q,0)+\text{dis}(q,t,k)+\min_{p\in S} \text{dis}(p,q,0)+\text{dis}(p,q,k)\\

发现 d_{q,k}=\min_{p\in S} \text{dis}(p,q,0)+\text{dis}(p,q,k) 可以通过换根 DP 解决,之后枚举 q 即可。查询距离可以实现预处理 O(n\log n),查询 O(1)。换根 DP 预处理时间复杂度 O(n\log V),询问时间复杂度 O(\sum \text{dis}(s,t)\log V)

2. O(q\log n\log V)

如何优化求式子,考虑树剖。设 rt=\text{lca}(s,t),点 i 所在的树链顶端为 top_i,考虑分类讨论:

u=s 向上跳,若目前 q\in \text{Road}(u,top_u),可得:

\min_k\min_{q\in \text{Road}(u,top_u)}\text{dis}(s,q,0)+\text{dis}(q,t,k)+d_{q,k}\\ =\min_k\min_{q\in \text{Road}(u,top_u)}\text{dis}(s,u,0)+\text{dis}(u,q,0)+\text{dis}(q,top_u,k)+\text{dis}(top_u,t,k)+d_{q,k}\\ =\min_k(\min_{q\in \text{Road}(u,top_u)}\text{dis}(u,q,0)+\text{dis}(q,top_u,k)+d_{q,k})+\text{dis}(s,u,0)+\text{dis}(top_u,t,k)\\

式子的前半部分可以用递推预处理,具体的:

mn1_{k,top_i}=d_{k,top_i}\\ mn1_{k,i}=\min(mn1_{k,fa_i}+\text{wei}(i,fa_i,0),\text{dis}(top_i,i,k)+d_{k,i})

答案即为:

\min_k mn1_{k,u}+\text{dis}(s,u,0)+\text{dis}(top_u,t,k)

预处理 O(n\log V),询问 O(q\log n\log V)

u=t 向上跳,若目前 q\in \text{Road}(u,top_u),可得:

\min_k\min_{q\in \text{Road}(u,top_u)}\text{dis}(s,q,0)+\text{dis}(q,t,k)+d_{q,k}\\ =\min_k\min_{q\in \text{Road}(u,top_u)}\text{dis}(s,top_u,0)+\text{dis}(top_u,q,0)+\text{dis}(q,u,k)+\text{dis}(u,t,k)+d_{q,k}\\ =\min_k(\min_{q\in \text{Road}(u,top_u)}\text{dis}(top_u,q,0)+\text{dis}(q,u,k)+d_{q,k})+\text{dis}(s,top_u,0)+\text{dis}(u,t,k)\\

式子的前半部分可以用递推预处理,具体的:

mn2_{k,top_i}=d_{k,top_i}\\ mn2_{k,i}=\min(mn2_{k,fa_i}+\text{wei}(i,fa_i,k),\text{dis}(top_i,i,0)+d_{k,i})

答案即为:

\min_k mn2_{k,u}+\text{dis}(s,top_u,0)+\text{dis}(u,t,k)

预处理 O(n\log V),询问 O(q\log n\log V)

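当 u 跳到与 rt 同一条重链上(即 top_u=top_{rt})时,不能再使用上述方法计算:mn1/mn2 覆盖的是整段 \text{Road}(u,top_u),其中 \text{Road}(rt,top_{rt}) 这一段(除 rt 外)并不在 s 到 t 的路径上,会被错误地计入。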

考虑使用线段树维护 \text{dis}(s,q,0)+\text{dis}(q,t,k)+d_{k,q}

设 dist_{k,i}=\text{dis}(1,i,k)(即第 k 层中根 1 到 i 的距离),之后分类讨论。

情况一:q\in \text{Road}(s,rt)

此时 \text{dis}(s,q,0)=\text{dis}(s,rt,0)-(dist_{0,q}-dist_{0,rt}),\text{dis}(q,t,k)=\text{dis}(t,rt,k)+dist_{k,q}-dist_{k,rt}

原式变为

\text{dis}(s,rt,0)-(dist_{0,q}-dist_{0,rt})+\text{dis}(t,rt,k)+dist_{k,q}-dist_{k,rt}+d_{k,q}\\ =\text{dis}(s,rt,0)-dist_{0,q}+dist_{0,rt}+\text{dis}(t,rt,k)+dist_{k,q}-dist_{k,rt}+d_{k,q}\\ =(dist_{k,q}-dist_{0,q}+d_{k,q})+dist_{0,rt}-dist_{k,rt}+\text{dis}(s,rt,0)+\text{dis}(t,rt,k)

前一段丢进线段树中。因为 top_q=top_{rt},这些 q 的 dfn 是连续的,相当于在线段树上求区间 [dfn_{rt},dfn_u] 内前一段的最小值,容易维护。

情况二:q\in \text{Road}(t,rt)

此时 \text{dis}(s,q,0)=\text{dis}(s,rt,0)+dist_{0,q}-dist_{0,rt},\text{dis}(q,t,k)=\text{dis}(t,rt,k)-(dist_{k,q}-dist_{k,rt})

原式变为

\text{dis}(s,rt,0)+dist_{0,q}-dist_{0,rt}+\text{dis}(t,rt,k)-(dist_{k,q}-dist_{k,rt})+d_{k,q}\\ =\text{dis}(s,rt,0)+dist_{0,q}-dist_{0,rt}+\text{dis}(t,rt,k)-dist_{k,q}+dist_{k,rt}+d_{k,q}\\ =(dist_{0,q}-dist_{k,q}+d_{k,q})-dist_{0,rt}+dist_{k,rt}+\text{dis}(s,rt,0)+\text{dis}(t,rt,k)

同理,相当于求前段的最小值,容易维护。

因为对于每个 k 都要开两个线段树维护,因此预处理 O(n\log n\log V),询问 O(q\log n\log V)

综合起来就是大常数的 O((n+q)\log n\log V),可以通过。

AC Code

#include <bits/stdc++.h>
// Shorthand for the 64-bit type used for all distances.
#define ll long long
// Left / right child index of segment-tree node p (used inside Seg).
#define ls (p << 1)
#define rs (p << 1 | 1)
using namespace std;
// N: bound for every node-indexed array (node ids must stay below N - TODO confirm against the limits).
// inf: "unreachable" sentinel used as the neutral value of every minimum.
const ll N = 1e5 + 5 , inf = 1e18 + 7;
// n: node count of the current test (also the default right end of Seg ranges);
// tim: cost charged per level (queries add tim * i for level i); q: number of queries.
int n , tim , q;
// Segment tree over positions 1..n (the current test's global n) supporting point
// assignment and range minimum. tr[p] is the minimum of node p's segment; the children
// of node p are ls = 2p and rs = 2p + 1.
struct Seg
{
    ll tr[N << 2];
    // Recompute node p from its two children.
    void push_up (int p)
    {tr[p] = min (tr[ls] , tr[rs]);}
    // Assign w to position ps. (p, s, t) is the current node and its segment [s, t];
    // callers use the defaults to start at the root covering [1, n]. A position outside
    // [s, t] is ignored rather than silently overwriting a boundary leaf.
    void set (int ps , ll w , int p = 1 , int s = 1 , int t = n)
    {
        if (ps < s || ps > t) return;
        if (s == t)
        {
            tr[p] = w;
            return ;
        }
        int mid = (s + t) >> 1;
        if (ps <= mid) set (ps , w , ls , s , mid);
        else set (ps , w , rs , mid + 1 , t);
        push_up (p);
    }
    // Minimum over the positions in [l, r] that lie inside the current segment [s, t];
    // returns inf when that intersection is empty (including l > r). The explicit
    // disjointness test guarantees termination: without it an empty or out-of-range
    // query (e.g. qry(1, 0)) reaches an uncovered leaf and keeps recursing into
    // ever-larger node indices past tr[].
    ll qry (int l , int r , int p = 1 , int s = 1 , int t = n)
    {
        if (l > r || r < s || t < l) return inf;
        if (l <= s && t <= r) return tr[p];
        int mid = (s + t) >> 1;
        ll ans = inf;
        if (l <= mid) ans = min (ans , qry (l , r , ls , s , mid));
        if (r > mid) ans = min (ans , qry (l , r , rs , mid + 1 , t));
        return ans;
    }
// Tr1[k] / Tr2[k]: one tree per level k (main fills and queries k = 1..20).
} Tr1[21] , Tr2[21];
// sz: subtree size; dfn: DFS index (heavy child visited first, see init2); cc: dfn counter;
// lg: floor(log2) table; Fa: parent (never assigned for root 1, so it stays 0);
// f: sparse table over dfn order, f[i][k] = shallowest node among dfn i .. i + 2^k - 1;
// dep: depth, with the root at depth 1.
int sz[N] , dfn[N] , cc , lg[N] , Fa[N] , f[N][21] , dep[N];
// d[k][u]: distance from root 1 to u when every edge weight w is replaced by ceil(w / 2^k)
//          (d[0] uses the original weights).
// dis[k][u]: min over marked p of dis(u, p, 0) + dis(p, u, k), computed as min(DP2[u], DP1[u])
//            for level k (inf when no node is marked).
// DP2[u]: the same minimum restricted to marked p inside u's subtree;
// DP1[u]: restricted to marked p outside u's subtree.
ll dis[21][N] , d[21][N] , DP1[N] , DP2[N];
// mn1[k][top] / mn2[k][top]: prefix-minimum tables along the heavy chain that starts at top,
// indexed by position along the chain (built in main).
vector <ll> mn1[21][N] , mn2[21][N];
// vis[i]: node i is marked (belongs to the set S).
bool vis[N];
// top: top of the node's heavy chain; flr: bottom of the node's heavy chain;
// ds: heavy child (0 for a leaf); wei: weight of the edge from the node to ds[node].
int top[N] , flr[N] , ds[N] , wei[N];
// Adjacency list entries: (neighbour, edge weight).
vector <pair <int , int> > g[N];
// e: current divisor 2^k read by Ceil(); cur: level k being filled by dfs1 / dfs2.
ll e = 1 , cur = 0;
ll Ceil (ll x)
{
    // Integer ceiling of x / e, where the global e holds the current divisor 2^k.
    const ll divisor = e;
    return (x + divisor - 1) / divisor;
}
// Returns the shallower of x and y (smaller dep); on equal depth, y is returned.
int cmp (int x , int y)
{
    if (dep[y] <= dep[x]) return y;
    return x;
}
// Fills the base row of the LCA sparse table: f[dfn[x]][0] = x for every node x in u's subtree.
void Dfs (int u)
{
    f[dfn[u]][0] = u;
    for (const auto &edge : g[u])
    {
        if (edge.first == Fa[u]) continue;
        Dfs (edge.first);
    }
}
// Shallowest node whose dfn lies in [l, r] (requires l <= r): combines two overlapping
// power-of-two windows of the sparse table f.
int get (int l , int r)
{
    const int level = lg[r - l + 1];
    const int fromLeft = f[l][level];
    const int fromRight = f[r - (1 << level) + 1][level];
    return cmp (fromLeft , fromRight);
}
int lca (int u , int v)
{
    if (u == v) return u;
    if ((u = dfn[u]) > (v = dfn[v])) swap (u , v);
    return Fa[get (u + 1 , v)];
}
// First DFS from u (whose parent is already stored in Fa[u]): records depth, subtree size,
// parents, exact root distances d[0], the heavy child ds[u] (0 for a leaf) and the weight
// wei[u] of the edge leading to it.
void init (int u)
{
    dep[u] = dep[Fa[u]] + 1;
    sz[u] = 1;
    ds[u] = 0;
    for (const auto &edge : g[u])
    {
        const int v = edge.first;
        if (v == Fa[u]) continue;
        Fa[v] = u;
        d[0][v] = d[0][u] + edge.second;
        init (v);
        sz[u] += sz[v];
        // Strictly larger subtree wins, so the first maximal child is kept on ties.
        if (sz[v] > sz[ds[u]])
        {
            ds[u] = v;
            wei[u] = edge.second;
        }
    }
}
// Second DFS (heavy child first): numbers nodes with dfn so each heavy chain is contiguous,
// sets top[u] = t (the chain top) and flr[u] = the bottom node of u's heavy chain.
void init2 (int u , int t)
{
    top[u] = t;
    dfn[u] = ++ cc;
    if (!ds[u]) flr[u] = u;
    else
    {
        init2 (ds[u] , t);
        flr[u] = flr[ds[u]];
    }
    // Every light child starts a new chain with itself as top.
    for (const auto &edge : g[u])
    {
        const int v = edge.first;
        if (v != Fa[u] && v != ds[u]) init2 (v , v);
    }
}
void dfs1 (int u)
{
    if (vis[u]) DP2[u] = 0;
    else DP2[u] = inf;
    for (auto it : g[u]) if (it.first != Fa[u])
    {
        int v = it.first , w = it.second;
        d[cur][v] = d[cur][u] + Ceil (w);
        w += Ceil (w);
        dfs1 (v);
        DP2[u] = min (DP2[u] , DP2[v] + w);
    }
}
void dfs2 (int u)
{
    dis[cur][u] = min (DP2[u] , DP1[u]);
    ll MIN = inf , LMIN = inf;
    for (auto it : g[u]) if (it.first != Fa[u])
    {
        int v = it.first , w = it.second;
        w += Ceil (w);
        if (DP2[v] + w < MIN) LMIN = MIN , MIN = DP2[v] + w;
        else if (DP2[v] + w < LMIN) LMIN = DP2[v] + w;
    }
    for (auto it : g[u]) if (it.first != Fa[u])
    {
        int v = it.first , w = it.second;
        w += Ceil (w);
        if (DP2[v] + w == MIN) DP1[v] = min (LMIN , DP1[u]) + w;
        else DP1[v] = min (MIN , DP1[u]) + w;
        if (vis[u]) DP1[v] = min (DP1[v] , 1ll * w);
        dfs2 (v);
    }
}
// Length of the u-v tree path at level P (edge weights ceil(w / 2^P)), from root distances.
ll Dist (int u , int v , int P)
{
    const int anc = lca (u , v);
    return (d[P][u] - d[P][anc]) + (d[P][v] - d[P][anc]);
}
// Entry point: reads several test cases; for each one builds the per-level tables and then
// answers the queries (s, t), printing one minimum cost per line.
signed main ()
{
    // The root has no "outside" part, so its rerooting value is unreachable. dfs2 only writes
    // DP1 of children, so setting it once covers every test case.
    DP1[1] = inf;
    // lg[i] = floor(log2(i)), used by get() for sparse-table queries.
    lg[0] = -1;
    for (int i = 1;i < N;i ++) lg[i] = lg[i >> 1] + 1;
    ios::sync_with_stdio (0);
    cin.tie (0) , cout.tie (0);
    int test;
    cin >> test;
    while (test --)
    {
        cin >> n >> tim;
        for (int i = 1;i <= n;i ++) g[i].clear ();
        // n - 1 undirected weighted edges.
        for (int i = 1;i < n;i ++)
        {
            ll u , v , w;
            cin >> u >> v >> w;
            g[u].push_back ({v , w});
            g[v].push_back ({u , w});
        }
        // One character per node; '1' (ASCII 49) marks the node as a member of S.
        char x;
        for (int i = 1;i <= n;i ++)
        {
            cin >> x;
            if (x == 49) vis[i] = 1;
            else vis[i] = 0;
        }
        // Tree preprocessing rooted at node 1: sizes / heavy children, then the dfn order with
        // contiguous heavy chains, then the sparse table f used by lca().
        cc = 0;
        init (1);
        init2 (1 , 1);
        Dfs (1);
        for (int i = 1;i <= 20;i ++)
            for (int u = 1;u + (1 << i) - 1 <= n;u ++)
                f[u][i] = cmp (f[u][i - 1] , f[u + (1 << i - 1)][i - 1]);
        // For each level cur = 1..20 (divisor e = 2^cur): d[cur] from dfs1 and the round-trip
        // table dis[cur] from the rerooting pass dfs2.
        e = 1 , cur = 0;
        while ((++ cur) <= 20)
        {
            e = e << 1;
            dfs1 (1);
            dfs2 (1);
        }
        cin >> q;
        int s , t;
        // Prefix tables along every heavy chain with top i, indexed by the position c of a
        // chain node x (c = 0 is the top):
        //   mn1[j][i][c] = min over q on the chain between i and x of dis(x, q, 0) + dis(q, i, j) + dis[j][q]
        //   mn2[j][i][c] = min over the same q of dis(i, q, 0) + dis(q, x, j) + dis[j][q]
        for (int i = 1;i <= n;++ i) if (top[i] == i)
        {
            for (unsigned short j = 1;j <= 20;++ j)
            {
                // Ceil() divides by the global e, so e must be 2^j while mn2 is built.
                e = 1 << j;
                mn1[j][i].clear () , mn2[j][i].clear ();
                mn1[j][i].push_back (dis[j][i]);
                mn2[j][i].push_back (dis[j][i]);
                // NOTE: this local cc (index of pos within the chain) shadows the global dfn counter.
                ll pos = i , cc = 0;
                while (ds[pos])
                {
                    mn1[j][i].push_back (min (mn1[j][i][cc] + wei[pos] , Dist (ds[pos] , i , j) + dis[j][ds[pos]]));
                    mn2[j][i].push_back (min (mn2[j][i][cc] + Ceil (wei[pos]) , Dist (ds[pos] , i , 0) + dis[j][ds[pos]]));
                    pos = ds[pos];
                    cc ++;
                }
            }
        }
        // Per-level segment trees over dfn. After a query adds back the terms that depend only
        // on u and rt, the minimum over a chain segment [rt, u] equals
        //   Tr1: min over q of dis(u, q, 0) + dis(q, rt, j) + dis[j][q]
        //   Tr2: min over q of dis(rt, q, 0) + dis(q, u, j) + dis[j][q]
        for (int i = 1;i <= n;++ i)
        {
            for (unsigned short j = 1;j <= 20;++ j)
            {
                Tr1[j].set (dfn[i] , dis[j][i] + Dist (i , top[i] , j) - d[0][i]);
                Tr2[j].set (dfn[i] , dis[j][i] + Dist (i , top[i] , 0) - d[j][i]);
            }
        }
        while (q --)
        {
            cin >> s >> t;
            // Level 0: the plain shortest path with no detour.
            ll ans = Dist (s , t , 0);
            int rt = lca (s , t);
            for (unsigned short i = 1;i <= 20;++ i)
            {
                // Fixed cost of reaching level i.
                ll w = 1ll * tim * i;
                // w + Dist(s, t, i) is a lower bound for any route at level i: the detour only
                // adds cost, and level-0 distances are never below level-i ones. This bound is
                // convex in i because each edge's per-level saving ceil(w/2^i) - ceil(w/2^(i+1))
                // never grows. So once the bound reaches ans, later levels cannot improve it.
                if (w + Dist (s , t , i) >= ans) break;
                int u;
                // Detour taken at q = rt.
                ans = min (ans , Dist (s , rt , 0) + Dist (rt , t , i) + dis[i][rt] + w);
                // Detour point q on the s side: whole chains via mn1, then the remaining
                // segment from u up to rt on rt's chain via Tr1.
                u = s;
                while (top[u] ^ top[rt])
                {
                    ans = min (ans , mn1[i][top[u]][dfn[u] - dfn[top[u]]] + Dist (s , u , 0) + Dist (top[u] , t , i) + w);
                    u = Fa[top[u]];
                }
                if (u ^ rt)
                {
                    ll mn = Tr1[i].qry (dfn[rt] , dfn[u]) + d[0][u] - Dist (rt , top[u] , i);
                    ans = min (ans , mn + Dist (u , s , 0) + Dist (rt , t , i) + w);
                }
                // Detour point q on the t side: symmetric, using mn2 and Tr2.
                u = t;
                while (top[u] ^ top[rt])
                {
                    ans = min (ans , mn2[i][top[u]][dfn[u] - dfn[top[u]]] + Dist (s , top[u] , 0) + Dist (u , t , i) + w);
                    u = Fa[top[u]];
                }
                if (u ^ rt)
                {
                    ll mn = Tr2[i].qry (dfn[rt] , dfn[u]) + d[i][u] - Dist (rt , top[u] , 0);
                    ans = min (ans , mn + Dist (u , t , i) + Dist (rt , s , 0) + w);
                }
            }
            cout << ans << '\n';
        }
    }
    return 0;
}