题解 P7735【[NOI2021] 轻重边】
提供一种魔改树链剖分的做法。
分析
把题中的重边看成边权
容易想到树链剖分。但传统的树链剖分的编号原则是「重链编号连续、子树编号连续」,我们这里不需要维护子树信息,需要知道儿子的信息,所以比较自然想到,用某种方式使得「重链上所有点的儿子编号连续」。
如何实现呢?我的方法是,先按传统的方法把重链都分出来,然后类似于 BFS,建立一个队列,把以根节点为
此外,对每个点
上述方案唯一的美中不足是,一条重链的
显然,与传统树剖的时间复杂度一样,为
考场上跑自己脚造的数据会超
代码
具体实现的时候可以不把链先赋值为
考场上的核心代码,稍微修得好看了一点:
struct SegT {
int l, r, sum, tg;
#define ls (p << 1)
#define rs (p << 1 | 1)
} t[N<<2];
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r, t[p].sum = 0, t[p].tg = -1;
if (l == r) return;
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
}
inline void push_up(int p) { t[p].sum = t[ls].sum + t[rs].sum; }
inline void push_down(int p) {
if (t[p].tg != -1) {
t[ls].sum = t[p].tg * (t[ls].r - t[ls].l + 1), t[ls].tg = t[p].tg;
t[rs].sum = t[p].tg * (t[rs].r - t[rs].l + 1), t[rs].tg = t[p].tg;
t[p].tg = -1;
}
}
void upd(int p, int l, int r, int v) {
if (l <= t[p].l && t[p].r <= r) return t[p].sum = v * (t[p].r - t[p].l + 1), t[p].tg = v, void();
int mid = (t[p].l + t[p].r) >> 1;
push_down(p);
if (l <= mid) upd(ls, l, r, v);
if (r > mid) upd(rs, l, r, v);
push_up(p);
}
int qry(int p, int l, int r) {
if (l <= t[p].l && t[p].r <= r) return t[p].sum;
int mid = (t[p].l + t[p].r) >> 1, res = 0;
push_down(p);
if (l <= mid) res += qry(ls, l, r);
if (r > mid) res += qry(rs, l, r);
return res;
}
int fa[N], siz[N], son[N], dep[N], tp[N];
int dfn[N], rk[N], sec;
int upmx[N], upmn[N], dwmx[N], dwmn[N];
vector<int> V[N];
int q[N], hd, tl;
void dfs1(int x, int ff, int dpt) {
fa[x] = ff, dep[x] = dpt;
siz[x] = 1, son[x] = 0;
for (int i = g[x]; i; i = nxt[i]) {
int y = v[i];
if (y == fa[x]) continue;
dfs1(y, x, dep[x] + 1);
if (!son[x] || siz[y] > siz[son[x]]) son[x] = y;
siz[x] += siz[y];
}
}
void dfs2(int x) {
if (!tp[x]) tp[x] = x;
V[tp[x]].pb(x);
for (int i = g[x]; i; i = nxt[i]) {
int y = v[i];
if (y == fa[x]) continue;
if (y == son[x]) tp[y] = tp[x];
dfs2(y);
}
}
inline void upd0(int x, int y) {
while (tp[x] != tp[y]) {
if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
if (son[x]) upd(1, dfn[son[x]], dfn[son[x]], 0);
if (dwmn[tp[x]] <= upmx[x]) upd(1, dwmn[tp[x]], upmx[x], 0);
x = fa[tp[x]];
}
if (dep[x] < dep[y]) swap(x, y);
if (son[x]) upd(1, dfn[son[x]], dfn[son[x]], 0);
if (dwmn[y] <= upmx[x]) upd(1, dwmn[y], upmx[x], 0);
upd(1, dfn[y], dfn[y], 0);
}
inline void upd1(int x, int y) {
while (tp[x] != tp[y]) {
if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
if (tp[x] != x) upd(1, dfn[son[tp[x]]], dfn[x], 1);
upd(1, dfn[tp[x]], dfn[tp[x]], 1);
x = fa[tp[x]];
}
if (dep[x] < dep[y]) swap(x, y);
if (x != y) upd(1, dfn[son[y]], dfn[x], 1);
}
inline int qry(int x, int y) {
int res = 0;
while (tp[x] != tp[y]) {
if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
if (tp[x] != x) res += qry(1, dfn[son[tp[x]]], dfn[x]);
res += qry(1, dfn[tp[x]], dfn[tp[x]]);
x = fa[tp[x]];
}
if (dep[x] < dep[y]) swap(x, y);
if (x != y) res += qry(1, dfn[son[y]], dfn[x]);
return res;
}
inline void init() {
memset(g, 0, sizeof(g)), tot = 0;
memset(tp, 0, sizeof(tp)), sec = 0;
for (int i = 1; i <= n; i++) V[i].clear();
}
inline void solve() {
n = rd(), m = rd();
init();
for (int i = 1; i < n; i++) {
int x = rd(), y = rd();
add(x, y), add(y, x);
}
for (int i = 1; i <= m; i++) {
qs[i].op = rd(), qs[i].x = rd(), qs[i].y = rd();
}
if (n <= 5000 && m <= 5000) {
sub1::main();
return;
}
dfs1(1, 0, 0);
dfs2(1);
hd = tl = 0;
q[tl++] = 1, dfn[1] = ++sec, rk[sec] = 1;
while (hd != tl) {
int x = q[hd++], len = V[x].size();
for (int i = 1; i < len; i++) {
int y = V[x][i];
dfn[y] = ++sec, rk[sec] = y;
}
for (int i = 0; i < len; i++) {
int y = V[x][i];
upmx[y] = -inf, upmn[y] = inf;
for (int j = g[y]; j; j = nxt[j]) {
int z = v[j];
if (z == fa[y]) continue;
if (z == son[y]) continue;
dfn[z] = ++sec, rk[sec] = z;
q[tl++] = z;
chkmax(upmx[y], dfn[z]);
chkmin(upmn[y], dfn[z]);
}
dwmx[y] = upmx[y], dwmn[y] = upmn[y];
}
for (int i = 1; i < len; i++) {
chkmax(upmx[V[x][i]], upmx[V[x][i-1]]);
chkmin(upmn[V[x][i]], upmn[V[x][i-1]]);
}
for (int i = len - 2; i >= 0; i--) {
chkmax(dwmx[V[x][i]], dwmx[V[x][i+1]]);
chkmin(dwmn[V[x][i]], dwmn[V[x][i+1]]);
}
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int op = qs[i].op, x = qs[i].x, y = qs[i].y;
if (op == 1) upd0(x, y), upd1(x, y);
else printf("%d\n", qry(x, y));
}
}