题解 P5439 【【XR-2】永恒】
题目大意
给两棵树 $T_1, T_2$ 以及映射 $f: T_1 \rightarrow T_2$,其中 $T_2$ 有根。记 $[u,v]$ 为 $T_1$ 上两点 $u, v$ 的简单路径上的点集,求
$$\sum_{x \in T_1} \sum_{y \in T_1} \sum_{u\in T_1} \sum_{v\in T_1} [x<y]\,[u<v]\,\bigl[[u,v] \subseteq [x, y]\bigr]\,\bigl(\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1\bigr)$$
其中 LCA 和深度均在 $T_2$ 上计算。
前置知识
- 点分治
- 虚树
题解
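改变求和顺序,考虑将与值主要相关的 $u, v$ 放到外层枚举:对每一对 $u < v$,其值 $\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1$ 被计入的次数等于满足 $[u,v] \subseteq [x,y]$ 的 $(x, y)$ 的对数。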
注意
考虑进行点分治,对穿过重心的路径 $[u, v]$ 统计贡献。
考虑按照此定义将整个分治的连通块的贡献进行计算,然后容斥掉每个子树内的贡献。即我们要解决对于
考虑在
将
这时
对于
综上所述,点分治内部所消耗的时间复杂度为 $O(N \log N)$。
因此,本题的理论时间复杂度为 $O(N \log N + M \log M)$。
取决于求权值这一部分的不同实现,可能有例如以下一些复杂度的算法可能卡过时限:
树链剖分:
LCT / 全局平衡二叉树优化:
使用 std::sort() 进行 DFS 序的排序:
代码
#include <cstdio>
#include <algorithm>
// #include <ctime>
// Shorthand used on the small helper functions of this file.
#define il inline
typedef long long LL;
const int MN = 300005;   // capacity of the per-vertex arrays
const int MV = 12000005; // max virtual tree nodes < 2 N log N
// Modulus for all arithmetic; Inv2 is the inverse of 2 modulo Mod
// (Mod is odd, so 2 * Inv2 == Mod + 1).
const int Mod = 998244353, Inv2 = (Mod + 1) / 2;

// Adds y to x in place and reduces the result once modulo Mod.
//
// x: accumulator, expected in [0, Mod).
// y: addend, expected in [0, Mod] (callers pass values such as `Mod - v`).
// With those ranges x + y stays below 2 * Mod, which fits in an int, and a
// single conditional subtraction brings the result back into [0, Mod).
//
// The previous one-liner `x -= (x += y) >= Mod ? Mod : 0;` modified x twice in
// one expression. That is only well-defined under the C++17 evaluation-order
// rules and is undefined behaviour in earlier standards, so the two steps are
// now separate statements.
il void ad(int &x, int y) {
    x += y;
    if (x >= Mod) x -= Mod;
}
// N, M: the two sizes read by init(); Ans: running answer, kept modulo Mod.
int N, M, Ans;
// id[i]: the value read for vertex i in the last input line.
int id[MN];
// Heads of the linked adjacency lists: h holds the parent -> child edges of
// the second tree, g holds both directions of every edge of the first tree.
int h[MN], g[MN];
// Shared edge pool: nxt[e] is the next edge in the same list, to[e] its target.
int nxt[MN * 3], to[MN * 3];
// Number of edges allocated from the pool so far (edge ids start at 1).
int tot;

// Pushes a new edge x -> y onto the front of list `a[x]`.
// a: array of list heads (h or g); 0 marks an empty list.
il void ins(int *a, int x, int y) {
    const int e = ++tot;
    nxt[e] = a[x];
    to[e] = y;
    a[x] = e;
}
// Reads the whole input from stdin and links the edges of both trees.
// Input order: N and M; N parent values for the first tree (0 means "no
// parent"); M parent values for the second tree, the first of which is
// skipped; one token that is skipped; then id[1..N].
il void init() {
    int parent;
    scanf("%d%d", &N, &M);

    // First tree: stored undirected in g.
    for (int v = 1; v <= N; ++v) {
        scanf("%d", &parent);
        if (parent != 0) {
            ins(g, parent, v);
            ins(g, v, parent);
        }
    }

    // Second tree: vertex 1's parent entry is discarded, the remaining
    // vertices are attached below their parents in h.
    scanf("%*d");
    for (int v = 2; v <= M; ++v) {
        scanf("%d", &parent);
        ins(h, parent, v);
    }

    // One whitespace-delimited token is read and ignored.
    scanf("%*s");

    for (int v = 1; v <= N; ++v) {
        scanf("%d", &id[v]);
    }
} // read & edge-linking part
// Preorder numbering of the h-tree: dfn[u] is u's number, idf[k] the vertex
// numbered k, dfc the amount of numbers handed out.
int idf[MN], dfn[MN], dfc;
// dep[u]: depth of u in the h-tree (set by the caller for the root).
int dep[MN];
// lg[]: logarithm table filled by _st().
int lg[MN * 2];
// st[p][0]: vertex at position p of the Euler tour; higher levels by _st().
int st[MN * 2][20];
// ldf[u] / rdf[u]: first / last position of u in the Euler tour.
int ldf[MN], rdf[MN];
// ecn: current length of the Euler tour.
int ecn;

// Depth-first walk over the h-tree starting at u. Assigns preorder numbers,
// child depths, and records the Euler tour (u is appended on entry and again
// after each child returns).
void tdfs(int u) {
    ++dfc;
    idf[dfc] = u;
    dfn[u] = dfc;

    ++ecn;
    st[ecn][0] = u;
    ldf[u] = ecn;

    for (int e = h[u]; e != 0; e = nxt[e]) {
        const int child = to[e];
        dep[child] = dep[u] + 1;
        tdfs(child);
        st[++ecn][0] = u;
    }

    rdf[u] = ecn;
}
// Returns whichever of the two vertices has the strictly smaller depth;
// on a tie the second argument is returned.
il int chkdep(int i, int j) {
    if (dep[i] < dep[j]) return i;
    return j;
}
// Builds the logarithm table and the sparse table over the Euler tour.
// After the call, st[p][k] is the shallowest vertex among the 2^k tour
// positions ending at p, and lg[n] is floor(log2(n)) for 1 <= n <= ecn.
il void _st() {
    lg[0] = -1;
    for (int len = 1; len <= ecn; ++len) {
        lg[len] = lg[len / 2] + 1;
    }

    const int top = lg[ecn];
    for (int k = 1; k <= top; ++k) {
        const int half = 1 << (k - 1);
        // A window of length 2^k ending at pos is the union of the two
        // half-length windows ending at pos - half and at pos.
        for (int pos = 2 * half; pos <= ecn; ++pos) {
            st[pos][k] = chkdep(st[pos - half][k - 1], st[pos][k - 1]);
        }
    }
}
// Range query on the sparse table: returns the shallowest vertex among the
// Euler-tour positions l..r (inclusive, l <= r), using two overlapping
// power-of-two windows.
il int qurdep(int l, int r) {
    const int k = lg[r - l + 1];
    const int leftEnd = l + (1 << k) - 1; // last position of the window starting at l
    return chkdep(st[leftEnd][k], st[r][k]);
}
// Lowest common ancestor of x and y in the h-tree.
// If the Euler-tour ranges of the two vertices are disjoint, the answer is
// the shallowest vertex between them; otherwise the ranges overlap and the
// shallower of the two vertices is returned.
il int lca(int x, int y) {
    const bool xBeforeY = rdf[x] < ldf[y];
    const bool yBeforeX = rdf[y] < ldf[x];
    if (!xBeforeY && !yBeforeX) {
        return chkdep(x, y);
    }
    return xBeforeY ? qurdep(rdf[x], ldf[y]) : qurdep(rdf[y], ldf[x]);
} // dfn & O(M log M) - O(1) RMQ-LCA part
// Q: number of queries created so far; ty[q] is 1 when query q's result is
// subtracted from the answer and 0 when it is added.
int Q, ty[MN * 2];
// qh[q]: head of query q's entry list; buk[k]: head of the bucket for key k.
int qh[MN * 2], buk[MN];
// Entry pool shared by both kinds of list: nx[e] is the next entry,
// b[e] the stored value and w[e] the stored weight.
int nx[MV * 2], b[MV * 2], w[MV * 2];
// Number of pool entries allocated so far (entry ids start at 1).
int vcn;

// Pushes the pair (y, z) onto the front of list `h[x]`.
// Note: the parameter `h` is a local array of list heads (buk or qh) and
// hides the global `h` inside this function.
il void add(int *h, int x, int y, int z) {
    const int e = ++vcn;
    nx[e] = h[x];
    b[e] = y;
    w[e] = z;
    h[x] = e;
}
// queries part
// faz[u]: parent of u in the g-tree as rooted by _dfs().
// siz[u]: working subtree size; _siz[u]: copy of the size computed here.
int faz[MN], siz[MN], _siz[MN];

// Roots the g-tree: records each vertex's parent and its subtree size.
// u: current vertex, fz: the vertex we arrived from (skipped among neighbours).
void _dfs(int u, int fz) {
    faz[u] = fz;
    int count = 1;
    for (int e = g[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == fz) continue;
        _dfs(v, u);
        count += siz[v];
    }
    siz[u] = count;
    _siz[u] = count;
}
// vis[u]: non-zero once u has been removed as a centroid.
// sz[u]: subtree size inside the current component, as computed by getrt().
int vis[MN], sz[MN];
// ts: size of the current component; rsz: best (smallest) "largest piece"
// found so far; rt: the vertex achieving rsz.
int ts, rsz, rt;

// Searches the component containing u (ignoring removed vertices) for the
// vertex whose removal leaves the smallest largest piece, storing it in rt.
// The caller must set ts to the component size and rsz to an upper bound.
void getrt(int u, int fz) {
    int heaviest = 0;
    sz[u] = 1;
    for (int e = g[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v != fz && !vis[v]) {
            getrt(v, u);
            sz[u] += sz[v];
            if (heaviest < sz[v]) heaviest = sz[v];
        }
    }

    // The piece on the far side of u (through fz) counts as well.
    const int above = ts - sz[u];
    if (heaviest < above) heaviest = above;

    if (heaviest < rsz) {
        rt = u;
        rsz = heaviest;
    }
}
// Walks the not-yet-removed part of the g-tree reachable from u (without
// going back through fz) and, for every vertex visited, files an entry for
// the current query Q with weight siz[vertex] in the bucket keyed by
// dfn[id[vertex]].
void fors(int u, int fz) {
    add(buk, dfn[id[u]], Q, siz[u]);
    for (int e = g[u]; e != 0; e = nxt[e]) {
        const int v = to[e];
        if (v == fz || vis[v]) continue;
        fors(v, u);
    }
}
void dfs(int u) {
if (ts == 1) return ; // no-op
siz[u] = N;
for (int x = u; !vis[faz[x]]; x = faz[x]) siz[faz[x]] = N - _siz[x];
ty[++Q] = 0, fors(u, 0);
for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]])
ty[++Q] = 1, add(buk, dfn[id[u]], Q, siz[to[i]]), fors(to[i], u);
for (int x = u; !vis[x]; x = faz[x]) siz[x] = _siz[x];
vis[u] = 1;
int nsz = ts;
for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]]) {
rsz = ts = sz[to[i]] < sz[u] ? sz[to[i]] : nsz - sz[u];
getrt(to[i], 0), dfs(rt);
}
} // O(N log N) tree decomposition part
// Sum: running total of the query being evaluated (modulo Mod).
int Sum;
// rch[1..tp]: stack of h-tree vertices; sum[k]: weight accumulated for rch[k].
int sum[MN], rch[MN];
// tp: current stack height.
int tp;

// Adds y^2 * (dep[x] - dep[rch[tp]]) to Sum, modulo Mod.
// x: a vertex just taken off the stack, or 0 for "nothing yet" (then no-op).
// y: the weight gathered beneath x, already reduced modulo Mod.
il void C(int x, int y) {
    if (x == 0) return;
    const LL squared = static_cast<LL>(y) * y % Mod;
    const LL length = dep[x] - dep[rch[tp]];
    ad(Sum, squared * length % Mod);
}
// Program entry point. Phases, in order: read the input; number the h-tree
// and build its LCA structure; decompose the g-tree into queries; regroup the
// recorded entries by query; evaluate every query with a stack sweep; print
// the result modulo Mod. The commented-out tim*/clock() lines are leftover
// per-phase timing instrumentation.
int main() {
// freopen("eternal-8-3.in", "r", stdin);
// int tim1, tim2, tim3, tim4, tim5, tim6, tim7;
// tim1 = clock();
init();
// tim2 = clock();
// Vertex 1 is used as the root of the h-tree, at depth 0.
dep[1] = 0, tdfs(1);
// tim3 = clock();
_st();
// tim4 = clock();
// Root the g-tree at 1, mark the sentinel vertex 0 as removed so that upward
// walks stop there, then start the decomposition from the first centroid.
_dfs(1, 0), vis[0] = 1, rsz = ts = N, getrt(1, 0), dfs(rt);
// tim5 = clock();
// Move every bucket entry into the list of its query. Buckets are visited
// from M down to 1 and add() pushes at the front, so each query list ends up
// in increasing bucket order; idf[i] turns the bucket index back into the
// h-tree vertex with that preorder number.
for (int i = M; i >= 1; --i)
for (int j = buk[i]; j; j = nx[j])
add(qh, b[j], idf[i], w[j]);
// O(N log N) radix sort part
// tim6 = clock();
for (int q = 1; q <= Q; ++q) {
// s: weight gathered for the vertex currently being grouped;
// x, y: last popped vertex and the weight popped so far.
int s = 0, x, y; Sum = 0;
// The stack starts with vertex 1 carrying no weight.
rch[tp = 1] = 1, sum[1] = 0;
for (int i = qh[q]; i; i = nx[i]) {
int u = b[i];
// Every entry subtracts its own w^2 * dep[u] from the total.
ad(Sum, Mod - (LL)w[i] * w[i] * dep[u] % Mod);
ad(s, w[i]);
// More entries for the same vertex follow: keep accumulating into s.
if (nx[i] && b[nx[i]] == u) continue;
int lc = lca(rch[tp], u);
// Pop everything deeper than lc. C() is called before each pop, so it
// charges the previously popped vertex against the current stack top.
for (x = y = 0; dep[rch[tp]] > dep[lc]; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
// Push lc itself when it is not already the top of the stack.
if (dep[rch[tp]] < dep[lc]) rch[++tp] = lc, sum[tp] = 0;
// Hand the popped weight to the new top and charge the last popped vertex.
ad(sum[tp], y);
C(x, y);
rch[++tp] = u, sum[tp] = s;
s = 0;
}
// Drain the remaining stack the same way.
for (x = y = 0; tp; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
// ty[q] selects whether this query's total is subtracted or added.
ad(Ans, ty[q] ? Mod - Sum : Sum);
} // virtual tree part
// tim7 = clock();
// Multiply by Inv2, i.e. halve modulo Mod, then print.
// NOTE(review): presumably every unordered pair was counted twice above — confirm.
Ans = (LL)Ans * Inv2 % Mod, printf("%d\n", Ans);
/* printf(" Init : %dms\n", tim2 - tim1);
printf(" Trie : %dms\n", tim3 - tim2);
printf(" ST : %dms\n", tim4 - tim3);
printf(" DFZ : %dms\n", tim5 - tim4);
printf(" sort : %dms\n", tim6 - tim5);
printf(" vt : %dms\n", tim7 - tim6);*/
return 0;
}