题解 CF1254D 【Tree Queries】

· · 题解

随便刷 CF 刷到这道题,然后想了个常数巨大的 \mathcal O(n\sqrt n) 做法。(大概和其它题解不太一样)

考虑修改相当于给 u 的所有相邻的点 v 所在子树加 (n - {\rm siz}_v)\cdot \dfrac dn.(这里的 {\rm siz}_v 指的是断掉 (u,v) 这条边后 v 所在连通块的大小)我们不妨令 d\leftarrow \dfrac dn.

我们先考虑假如查询全在修改之后怎么做,我们可以把对于同一个点的修改合并(d 全部加起来),然后暴力给每个点打树上差分标记,再暴力给树上差分做树上前缀和,做一次的时间复杂度是 \mathcal O(n) 的。

于是想到,能否想到另外一个做法,使得我们每隔 B 个操作进行一次上述操作暴力应用修改,然后对于这 B 个操作暴力查它们之间的影响。

注意到,我们要判断一个操作对一个询问的影响,唯一需要知道的事就是询问的点 v 在修改作用于的点 u哪一棵子树内。

随便拿个点做根,记 v 的深度为 d_vv 的父亲为 {\rm fa}_v,则(下文的 {\rm siz}_v 指的是断掉 ({\rm fa}_v, v) 这条边后 v 所在连通块的大小):

假如用倍增跳祖先的方法,时间复杂度是 \mathcal O(\log n) 的,这样总的时间复杂度是 \mathcal O\left(n\log n + qB\log n+n\dfrac qB \right) 的,在 B = \sqrt{\dfrac n {\log n}} 时取到最小值 \mathcal O(n\log n + n\sqrt{n\log n}),没试过,但是感觉不是很跑的过去。虽然常数小就是了

注意到跳 k 级祖先可以使用长链剖分,对于每个点预处理出 2^i 级祖先,对于每条长链,在链顶向上下预处理长链长度级祖先。对于一个询问,跳到令 ik 二进制下的最高位,跳到 2^i 级祖先所在的长链的链顶。显然这条长链长度不小于 2^i,更不小于剩下的 k(或者多跳的 k),直接查处理过的表即可。时间复杂度 \mathcal O(n\log n)-\mathcal O(1).

这样时间复杂度就变成了 \mathcal O\left(n\log n+qB+n\dfrac qB \right) 的,在 B = \sqrt{n} 时取到最小值 \mathcal O(n\log n +q\sqrt n),实测可过。

ll QPow(ll a, ll b) {
    ll ret = 1, bas = a;
    for(; b; b >>= 1, bas = bas * bas % kMod) if(b & 1) ret = ret * bas % kMod;
    return ret;
}

int n, q;
std::vector <int> E[kN];
void Add(int u, int v) { E[u].push_back(v); E[v].push_back(u); }

int fa[kN][20], h_bit[kN], top[kN], d[kN], hvy[kN], dep[kN],
    dfn[kN], siz[kN], dfv = 0;
void Dfs(int u) {
    dfn[u] = ++dfv; siz[u] = 1;
    dep[u] = dep[fa[u][0]] + 1;
    for(auto v : E[u]) if(v != fa[u][0]) {
        fa[v][0] = u; Dfs(v);
        if(d[v] > d[hvy[u]]) hvy[u] = v;
        siz[u] += siz[v];
    }
    d[u] = d[hvy[u]] + 1;
}
void GetTop(int u, int t) {
    top[u] = t;
    if(hvy[u]) GetTop(hvy[u], t);
    for(auto v : E[u]) if(v != fa[u][0] && v != hvy[u]) {
        GetTop(v, v);
    }
}

std::vector <int> lower[kN], upper[kN];
void Init() {
    Dfs(1); GetTop(1, 1);
    for(int i = 1; i <= 18; ++i)
        for(int j = 1; j <= n; ++j)
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
    h_bit[0] = -1;
    for(int i = 1; i <= n; ++i)
        h_bit[i] = h_bit[i >> 1] + 1;
    for(int i = 1; i <= n; ++i) if(i == top[i]) {
        for(int x = i, y = i; top[x] == i; x = hvy[x], y = fa[y][0]) {
            lower[i].push_back(x);
            upper[i].push_back(y);
        }
    }
}
int FindAnc(int x, int k) {
    if(!k) return x;
    int t = top[fa[x][h_bit[k]]];
    k -= dep[x] - dep[t];
    int ret;
    if(k > 0) ret = upper[t][k];
    else ret = lower[t][-k];
    return ret;
}

struct Mod {
    int v; ll d;
} M[kN];
int cnt = 0;
ll t[kN], A[kN], dt[kN];
void Push() {
    cnt = 0;
    for(int u = 1; u <= n; ++u) if(t[u]) {
        ll d = t[u];
        for(auto v : E[u]) if(v != fa[u][0]) {
            dt[dfn[v]] = (dt[dfn[v]] + (n - siz[v]) * d) % kMod;
            dt[dfn[v] + siz[v]] = (dt[dfn[v] + siz[v]] - (n - siz[v]) * d) % kMod;
        }
        dt[1] = (dt[1] + siz[u] * d) % kMod;
        dt[dfn[u]] = (dt[dfn[u]] - siz[u] * d + n * d) % kMod;
        dt[dfn[u] + 1] = (dt[dfn[u] + 1] - n * d) % kMod;
        dt[dfn[u] + siz[u]] = (dt[dfn[u] + siz[u]] + siz[u] * d) % kMod;
    }
    for(int i = 1; i <= n; ++i)
        dt[i] = ((dt[i - 1] + dt[i]) % kMod + kMod) % kMod;
    for(int u = 1; u <= n; ++u)
        A[u] = (A[u] + dt[dfn[u]]) % kMod;
    memset(t, 0, sizeof(t));
    memset(dt, 0, sizeof(dt));
}
void Modify(int v, ll d) {
    t[v] = (t[v] + d) % kMod; 
    M[++cnt] = (Mod) { v, d };
}
ll Query(int v) {
    ll ret = A[v];
    for(int i = 1; i <= cnt; ++i) {
        int u = M[i].v; ll d = M[i].d;
        if(u == v) ret = (ret + n * d) % kMod;
        else if(dep[u] >= dep[v]) ret = (ret + siz[u] * d) % kMod;
        else {
            int x = FindAnc(v, dep[v] - dep[u] - 1);
            if(fa[x][0] != u) ret = (ret + siz[u] * d) % kMod;
            else ret = (ret + (n - siz[x]) * d) % kMod;
        }
    }
    return ret;
}

int main() { 
    ll n_inv; int B;
    rd(n, q); n_inv = QPow(n, kMod - 2);
    B = sqrt(n);
    for(int i = 1; i < n; ++i) {
        int u, v; rd(u, v);
        Add(u, v);
    }
    Init();
    for(int i = 1; i <= q; ++i)  {
        int opt, v; rd(opt, v);
        if(i / B != (i - 1) / B) Push();
        if(opt == 1) {
            ll d; rd(d);
            Modify(v, d * n_inv % kMod);
        } else printf("%lld\n", Query(v));
    }
    return 0;
}