题解 P5439 【【XR-2】永恒】

· · 题解

题目大意

给两颗树 T_1, T_2 以及映射 f: T_1 \rightarrow T_2,其中 T_2 有根,记 [u,v]T_1 上两点 u, v 的简单路径上点集,求

\sum_{x \in T_1} \sum_{y \in T_1} \sum_{u\in T_1} \sum_{v\in T_1} [x<y][u<v][[u,v] \subseteq [x, y]] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1)

其中求 LCA 和深度均为在 T_2 上。

前置知识

题解

改变求和顺序,考虑将与值主要相关的 u, v 放到前面:

\sum_{u\in T_1} \sum_{v\in T_1} [u<v] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1)\sum_{x \in T_1} \sum_{y \in T_1} [x<y][[u,v] \subseteq [x, y]]

注意 [u<v]+[v<u] = \frac{[u\neq v]}2,不妨将问题转化为求不同点之间关系:

\frac14 \sum_{u\in T_1} \sum_{v\in T_1} [u\neq v] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1)\sum_{x \in T_1} \sum_{y \in T_1} [x\neq y][[u,v] \subseteq [x, y]]

考虑进行点分治,对穿过重心 c 的所有 [u,v] 对答案的贡献进行统计。记当前分治时 c 为根,\operatorname{sub}(u)u 节点在整个 T_1 的子树大小。首先考虑 u, v 均不为 c 的情况,若它们在不同子树内,则贡献为

(\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v)

考虑按照此定义将整个分治的连通块的贡献进行计算,然后容斥掉每个子树内的贡献。即我们要解决对于 T_1 上一个点集 S,求出

\left(\sum_{u\in S} \sum_{v\in S} (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v) \right) - \sum_{u\in S} (\operatorname{dep}(f(u)) - 1)\operatorname{sub}(u)^2

考虑在 T_2(\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v) 的意义,这等价于将 f(u)T_2 的根上的每一个节点加上 \operatorname{sub}(u) 的权值,询问 f(v) 的父亲到根节点的路径上权值和与 \operatorname{sub}(v) 的乘积。如果通过数据结构进行计算会多出一个 \log 的代价,考虑该问题是离线的,通过虚树优化复杂度。

\{f(u) | u\in S\} \cup \{ \operatorname{root}(T_2) \} 的虚树建出,考虑先将每个点映射到的点加上对应权值 w(f(u)) \leftarrow w(f(u)) + \operatorname{sub}(u),这维护了实际修改对应的差分,第二步将虚树从下至上进行累加,这时每个点与原本的点权概念一致,第三步从上至下累加,记 p(x) 为在虚树上的父节点,s(f(u)) = s(p(f(u))) + w(p(f(u)))[\operatorname{dep} (f(u)) - \operatorname{dep}(p(f(u)))]

这时 \sum_{v\in S} s(f(v)) \operatorname{sub}(v) 即为所求。虚树计算的时间复杂度为 \Theta(|S|)。但注意建立虚树时通常有一个对 DFS 序排序的问题,考虑到本题所求全部为离线问题,所以离线进行多数组的基数排序即可做到严格的 \Theta\left(m + \sum |S|\right) = \Theta(m + n\log n)

对于 u 为重心的情况,可以枚举 v 进行计算。

综上所述,点分治内部所消耗的时间复杂度为 \Theta(|S|),故点分治的总体复杂度为 \Theta(n\log n)。为了辅助求出虚树,可能需要 \Theta(1) LCA 计算,预处理的理论复杂度可以做到 \Theta(m),使用朴素的 ST 表可以做到 \Theta(m\log m)

因此,本题的理论时间复杂度为 \Theta(m + n\log n),需要 \Theta(m + n\log n) 的空间。本题 std(by PinkRabbit) 的实现时间复杂度为 \Theta(m\log m + n\log n),空间复杂度为 \Theta(m\log m + n\log n)

取决于求权值这一部分的不同实现,可能有例如以下一些复杂度的算法可能卡过时限:

树链剖分:\Theta(m + n\log n\log^2 m)

LCT / 全局平衡二叉树优化:\Theta(m\log m + n\log n \log m)

使用 std::sort() 进行 DFS 序的排序:\Theta(m\log m + n\log^2 n),空间 \Theta(m\log m + n)

代码

#include <cstdio>
#include <algorithm>
// #include <ctime>
#define il inline
typedef long long LL;
const int MN = 300005;
const int MV = 12000005; // max virtual tree nodes < 2 N log N
const int Mod = 998244353, Inv2 = (Mod + 1) / 2;
il void ad(int &x, int y) { x -= (x += y) >= Mod ? Mod : 0; }

int N, M, Ans;
int id[MN], h[MN], g[MN], nxt[MN * 3], to[MN * 3], tot;
il void ins(int *a, int x, int y) { nxt[++tot] = a[x], to[tot] = y, a[x] = tot; }
il void init() {
    int x;
    scanf("%d%d", &N, &M);
    for (int i = 1; i <= N; ++i) {
        scanf("%d", &x);
        if (x) ins(g, x, i), ins(g, i, x);
    }
    scanf("%*d");
    for (int i = 2; i <= M; ++i) scanf("%d", &x), ins(h, x, i);
    scanf("%*s");
    for (int i = 1; i <= N; ++i) scanf("%d", &id[i]);
} // read & edge-linking part

int idf[MN], dfn[MN], dfc;
int dep[MN], lg[MN * 2], st[MN * 2][20], ldf[MN], rdf[MN], ecn;
void tdfs(int u) {
    idf[++dfc] = u, dfn[u] = dfc;
    st[++ecn][0] = u, ldf[u] = ecn;
    for (int i = h[u]; i; i = nxt[i])
        dep[to[i]] = dep[u] + 1, tdfs(to[i]), st[++ecn][0] = u;
    rdf[u] = ecn;
}
il int chkdep(int i, int j) { return dep[i] < dep[j] ? i : j; }
il void _st() {
    lg[0] = -1;
    for (int i = 1; i <= ecn; ++i) lg[i] = lg[i >> 1] + 1;
    for (int j = 0; j < lg[ecn]; ++j)
        for (int i = 2 << j; i <= ecn; ++i)
            st[i][j + 1] = chkdep(st[i - (1 << j)][j], st[i][j]);
}
il int qurdep(int l, int r) { int b = lg[r - l + 1]; return chkdep(st[l + (1 << b) - 1][b], st[r][b]); }
il int lca(int x, int y) {
    if (rdf[x] < ldf[y]) return qurdep(rdf[x], ldf[y]);
    else if (rdf[y] < ldf[x]) return qurdep(rdf[y], ldf[x]);
    else return dep[x] < dep[y] ? x : y;
} // dfn & O(M log M) - O(1) RMQ-LCA part

int Q, ty[MN * 2], qh[MN * 2], buk[MN], nx[MV * 2], b[MV * 2], w[MV * 2], vcn;
il void add(int *h, int x, int y, int z) { nx[++vcn] = h[x], b[vcn] = y, w[vcn] = z, h[x] = vcn; }
// queries part

int faz[MN], siz[MN], _siz[MN];
void _dfs(int u, int fz) {
    faz[u] = fz, siz[u] = 1;
    for (int i = g[u]; i; i = nxt[i]) if (to[i] != fz)
        _dfs(to[i], u), siz[u] += siz[to[i]];
    _siz[u] = siz[u];
}
int vis[MN], sz[MN], ts, rsz, rt;
void getrt(int u, int fz) {
    int mxs = 0; sz[u] = 1;
    for (int i = g[u]; i; i = nxt[i]) {
        if (to[i] == fz || vis[to[i]]) continue;
        getrt(to[i], u), sz[u] += sz[to[i]];
        mxs = std::max(mxs, sz[to[i]]);
    }
    mxs = std::max(mxs, ts - sz[u]);
    if (mxs < rsz) rt = u, rsz = mxs;
}
void fors(int u, int fz) {
    add(buk, dfn[id[u]], Q, siz[u]);
    for (int i = g[u]; i; i = nxt[i]) if (to[i] != fz && !vis[to[i]]) fors(to[i], u);
}
void dfs(int u) {
    if (ts == 1) return ; // no-op
    siz[u] = N;
    for (int x = u; !vis[faz[x]]; x = faz[x]) siz[faz[x]] = N - _siz[x];
    ty[++Q] = 0, fors(u, 0);
    for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]])
        ty[++Q] = 1, add(buk, dfn[id[u]], Q, siz[to[i]]), fors(to[i], u);
    for (int x = u; !vis[x]; x = faz[x]) siz[x] = _siz[x];
    vis[u] = 1;
    int nsz = ts;
    for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]]) {
        rsz = ts = sz[to[i]] < sz[u] ? sz[to[i]] : nsz - sz[u];
        getrt(to[i], 0), dfs(rt);
    }
} // O(N log N) tree decomposition part

int Sum, sum[MN], rch[MN], tp;
il void C(int x, int y) { if (x) ad(Sum, (LL)y * y % Mod * (dep[x] - dep[rch[tp]]) % Mod); }

int main() {
//  freopen("eternal-8-3.in", "r", stdin);
//  int tim1, tim2, tim3, tim4, tim5, tim6, tim7;
//  tim1 = clock();
    init();
//  tim2 = clock();
    dep[1] = 0, tdfs(1);
//  tim3 = clock();
    _st();
//  tim4 = clock();
    _dfs(1, 0), vis[0] = 1, rsz = ts = N, getrt(1, 0), dfs(rt);
//  tim5 = clock();
    for (int i = M; i >= 1; --i)
        for (int j = buk[i]; j; j = nx[j])
            add(qh, b[j], idf[i], w[j]);
    // O(N log N) radix sort part
//  tim6 = clock();
    for (int q = 1; q <= Q; ++q) {
        int s = 0, x, y; Sum = 0;
        rch[tp = 1] = 1, sum[1] = 0;
        for (int i = qh[q]; i; i = nx[i]) {
            int u = b[i];
            ad(Sum, Mod - (LL)w[i] * w[i] * dep[u] % Mod);
            ad(s, w[i]);
            if (nx[i] && b[nx[i]] == u) continue;
            int lc = lca(rch[tp], u);
            for (x = y = 0; dep[rch[tp]] > dep[lc]; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
            if (dep[rch[tp]] < dep[lc]) rch[++tp] = lc, sum[tp] = 0;
            ad(sum[tp], y);
            C(x, y);
            rch[++tp] = u, sum[tp] = s;
            s = 0;
        }
        for (x = y = 0; tp; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
        ad(Ans, ty[q] ? Mod - Sum : Sum);
    } // virtual tree part
//  tim7 = clock();
    Ans = (LL)Ans * Inv2 % Mod, printf("%d\n", Ans);
/*  printf(" Init : %dms\n", tim2 - tim1);
    printf(" Trie : %dms\n", tim3 - tim2);
    printf(" ST   : %dms\n", tim4 - tim3);
    printf(" DFZ  : %dms\n", tim5 - tim4);
    printf(" sort : %dms\n", tim6 - tim5);
    printf(" vt   : %dms\n", tim7 - tim6);*/
    return 0;
}