【题解】CF1749F Distance to the Path

· · 题解

部分参考了 这篇 xianggl巨佬的博客。

由于 d \le 20,考虑将不等关系的信息存放在数组中,转换为最大距离 d 时的信息。考虑单独处理每一个点时,直接暴力修改与某一点 x 距离为 [0,d] 的点,此时不难分析出时间复杂度为 O(qnd)

由于是区间更新,需要同时更新 x \to y 这条链上的点,也就是提示要用数据结构去维护。将单点数据的处理分为子树内与子树外两种情况进行分类讨论。

对于在子树内的点,倘若是一段连续的区间,则可以利用数据结构一次更新完毕。不难想到利用树链剖分维护 \texttt{dfs} 序,这样子树内的 \texttt{dfs} 序便是连续的。利用常数较小的树状数组,设 c_{i,j} 存的是最大距离为 i 时每一个点产生的贡献。则在一次处理时直接更新 x,y,\rm{LCA(x,y)} 三点即可,也就是 modify (c[d],x,k);modify (c[d],y,k);modify (c[d],LCA(x,y),-2 * k),其中 modify 函数即为树状数组的操作。用 in[x]out[x] 分别表示第一次和最后一次经过 x 点时的 \texttt{dfs} 序,因此在查询答案时子树内的贡献即用差分的思想进行树状数组维护操作。

对于在子树外的点,显然都可以归结到 z = \rm LCA(x,y) 这一点产生的贡献。由于是单点,直接进行暴力修改。设 dp_{i,j} 表示以 i 为根距离为 j 时的贡献的权值。从 z 开始向上走,z 对上面距离不超过 j 的节点产生影响,然后发现,z 的父亲会对距离不超过 j - 2 的节点产生的影响被重复计算,直接进行容斥即可(需要减去的条件是距离超过 j - 2z 上方还存在节点)。

综上分析,子树内修改为 O(q\log n),查询为 O(qd \log n);子树外修改为 O(qd),查询为 (qd),由于 n,q 同阶,d,\log n 大小相近,所以最后的时间复杂度为 O(n \log^2 n)

代码如下:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
#define lowbit(x) x & (-x)
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int n,m,cnt,in[MAX],out[MAX],top[MAX],sz[MAX],hson[MAX],f[MAX],dep[MAX],dp[MAX][25],c[25][MAX]; 
vector <int> ve[MAX];
void dfs1 (int u,int fa);
void dfs2 (int u,int fa);
int lca (int u,int v);
void modify (int tr[],int x,int v);
int query (int tr[],int x);
int main ()
{
    //freopen (".in","r",stdin);
    //freopen (".out","w",stdout);
    n = read ();
    for (int i = 1;i < n;++i)
    {
        int x = read (),y = read ();
        ve[x].push_back (y);ve[y].push_back (x);
    }
    dfs1 (1,0);
    in[1] = ++cnt,top[1] = 1;
    dfs2 (1,0);
    //for (int i = 1;i <= n;++i) cout<<in[i]<<" "<<out[i]<<endl;
    m = read ();
    for (int i = 1;i <= m;++i)
    {
        int ty = read ();
        if (ty == 1)
        {
            int x = read (),ans = 0,p = 0;
            for (int i = 0;i <= 20;++i)
            {
                ans += query (c[i],out[x]) - query (c[i],in[x] - 1);//差分思想
                for (int j = p;j <= 20;++j) ans += dp[x][j];//每向上走一次,距离至少为 p
                ++p;x = f[x];
                if (!x) break;//向上面已经没有节点就停止
            }
            printf ("%d\n",ans);
        }
        else
        {
            int u = read (),v = read (),k = read (),d = read ();
            int z = lca (u,v);
            modify (c[d],in[u],k);modify (c[d],in[v],k);modify (c[d],in[z],-2 * k);//重复部分的去除
            for (int i = d;~i;--i)
            {   
                dp[z][i] += k;
                if (i - 2 >= 0 && f[z]) dp[z][i - 2] -= k;//容斥 距离 >= i - 2 且 z 上还有节点 
                z = f[z];
                if (!z) break;
            }
        }
    } 
    return 0;
}
inline int read ()
{
    int s = 0;int f = 1;
    char ch = getchar ();
    while ((ch < '0' || ch > '9') && ch != EOF)
    {
        if (ch == '-') f = -1;
        ch = getchar ();
    }
    while (ch >= '0' && ch <= '9')
    {
        s = s * 10 + ch - '0';
        ch = getchar ();
    }
    return s * f;
}
void dfs1 (int u,int fa)//树链剖分
{
    dep[u] = dep[fa] + 1;
    f[u] = fa;sz[u] = 1;
    for (auto v : ve[u])
    {
        if (v == fa) continue;
        dfs1 (v,u);
        sz[u] += sz[v];
        if (sz[v] > sz[hson[u]]) hson[u] = v;
    }
}
void dfs2 (int u,int fa)
{
    if (hson[u])
    {
        in[hson[u]] = ++cnt;//dfs 序
        top[hson[u]] = top[u];
        dfs2 (hson[u],u);
    }
    for (auto v : ve[u])
    {
        if (top[v]) continue;
        in[v] = ++cnt;
        top[v] = v;
        dfs2 (v,v);
    }
    out[u] = cnt;//dfs 序
}
int lca (int u,int v)//树链剖分求 LCA
{
    int fx = top[u],fy = top[v];
    while (fx != fy)
    {
        if (dep[fx] < dep[fy]) swap (fx,fy),swap (u,v);
        u = f[fx],fx = top[u]; 
    }
    if (dep[u] > dep[v]) swap (u,v);
    return u;
}
void modify (int tr[],int x,int v) {for (int i = x;i <= n;i += lowbit (i)) tr[i] += v;}
int query (int tr[],int x) {int sum = 0;for (int i = x;i;i -= lowbit (i)) sum += tr[i];return sum;}//树状数组基本操作