P7671
题解 P7671
主席树,树链剖分,标记永久化
给定一棵树,要求实现以下操作:
-
将树上
x 到y 的权值加k ,同时新建一个版本。 -
查询树上
x 到y 的答案。树上
x 到y 的答案的定义为:\sum\limits_{i \in R(x,y)}\!{a_i}\times\dfrac{dis(i,y)(dis(i,y)+1)}{2} -
回到第
x 个版本,初始状态视作0 版本。
可以发现,该题核心操作为查询答案,可以发现我们需要想办法维护
于是就有:
我们发现,如果强行分解这个式子,最后将非常难维护,我们不妨将一条路径拆分成两条链的形式,现在有了两个集合,
现在,我们只要在线段树上维护
但是,这题我们需要完成主席树上的区间修改、区间查询,因为主席树的不同版本节点时存在共用的,如果我们随意地 pushdown 以及 pushup 的话就会得到错误答案,因此我们每次修改都需要新建节点,这会导致空间非常大,其中有大量冗余节点。因此,我们可以使用标记永久化的方法,在修改时,如果节点区间被完全包含了,直接将懒标记打到节点上再返回,有交集就累加上更改的值然后向其遍历;在查询时,一路累加懒标记的贡献,同样在节点区间被完全包含时返回。
#include <iostream>
#include <cstdio>
using namespace std;
inline int read()
{
int x = 0;
char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9')
{
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return x;
}
const int N = 100100;
const int mod = 20160501;
int head[N], to[N * 2], nxt[N * 2], tot;
inline void add(int x, int y)
{
to[++tot] = y;
nxt[tot] = head[x];
head[x] = tot;
}
int d[N], s[N], f[N], hs[N];
inline void dfs1(int x)
{
s[x] = 1;
for (int i = head[x]; i; i = nxt[i])
{
if (d[to[i]]) continue;
d[to[i]] = d[x] + 1;
f[to[i]] = x;
dfs1(to[i]);
s[x] += s[to[i]];
if (s[hs[x]] < s[to[i]]) hs[x] = to[i];
}
}
int top[N], dfn[N], rk[N], res;
inline void dfs2(int x, int t)
{
top[x] = t;
dfn[x] = ++res;
rk[res] = x;
if (hs[x]) dfs2(hs[x], t);
for (int i = head[x]; i; i = nxt[i])
if (to[i] != f[x] && to[i] != hs[x]) dfs2(to[i], to[i]);
}
struct tree
{
long long k, dk, ddk, f;
int ls, rs, v[2];
tree() { k = dk = ddk = ls = rs = f = v[0] = v[1] = 0; }
} t[N << 6];
int rt[N], cnt, root, n, m;
long long a[N], e[N], ee[N];
inline tree operator + (tree a, tree b)
{
tree c;
c.k = (a.k + b.k) % mod;
c.dk = (a.dk + b.dk) % mod;
c.ddk = (a.ddk + b.ddk) % mod;
return c;
}
inline int D(int l, int r) { return e[r] - e[l - 1] + mod; }
inline int DD(int l, int r) { return ee[r] - ee[l - 1] + mod; }
inline void build(int &tp, int l, int r)
{
tp = ++cnt;
if (l == r)
{
int x = rk[l];
t[tp].k = (a[x]) % mod;
t[tp].dk = (a[x] * d[x]) % mod;
t[tp].ddk = (a[x] * d[x] * d[x]) % mod;
return;
}
int mid = (l + r) >> 1;
build(t[tp].ls, l, mid);
build(t[tp].rs, mid + 1, r);
t[tp].k = (t[t[tp].ls].k + t[t[tp].rs].k) % mod;
t[tp].dk = (t[t[tp].ls].dk + t[t[tp].rs].dk) % mod;
t[tp].ddk = (t[t[tp].ls].ddk + t[t[tp].rs].ddk) % mod;
}
inline void pushup(tree &tp, int l, int r, long long k)
{
tp.k = (tp.k + k * (r - l + 1) % mod) % mod;
tp.dk = (tp.dk + k * D(l, r) % mod) % mod;
tp.ddk = (tp.ddk + k * DD(l, r) % mod) % mod;
}
inline int copy(int tp)
{
t[++cnt] = t[tp];
return cnt;
}
inline void update(int tp, int l, int r, int ql, int qr, long long k)
{
pushup(t[tp], max(l, ql), min(r, qr), k);//只累加需要更改的区间的值
if (ql <= l && r <= qr) { t[tp].f = (t[tp].f + k) % mod; return; }
int mid = (l + r) >> 1;
if (ql <= mid)
{
if (t[tp].v[0])
{
t[tp].ls = copy(t[tp].ls);
t[t[tp].ls].v[0] = 1;
t[t[tp].ls].v[1] = 1;
t[tp].v[0] = 0;
}
update(t[tp].ls, l, mid, ql, qr, k);
}
if (mid < qr)
{
if (t[tp].v[1])
{
t[tp].rs = copy(t[tp].rs);
t[t[tp].rs].v[0] = 1;
t[t[tp].rs].v[1] = 1;
t[tp].v[1] = 0;
}
update(t[tp].rs, mid + 1, r, ql, qr, k);
}
}
inline tree query(int tp, int l, int r, int ql, int qr)
{
if (ql <= l && r <= qr) return t[tp];
int mid = (l + r) >> 1; tree ans;
pushup(ans, max(l, ql), min(r, qr), t[tp].f);
if (ql <= mid) ans = ans + query(t[tp].ls, l, mid, ql, qr);
if (mid < qr) ans = ans + query(t[tp].rs, mid + 1, r, ql, qr);
return ans;
}
inline int LCA(int x, int y)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]]) swap(x, y);
x = f[top[x]];
}
if (d[x] > d[y]) swap(x, y);
return x;
}
inline void updates(int x, int y, long long k)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]]) swap(x, y);
update(rt[root], 1, n, dfn[top[x]], dfn[x], k);
x = f[top[x]];
}
if (dfn[x] > dfn[y]) swap(x, y);
update(rt[root], 1, n, dfn[x], dfn[y], k);
}
inline int clac(long long c, int op, int l, int r)
{
tree g = query(rt[root], 1, n, l, r);
return (g.ddk + (op * g.dk * (c * 2 + 1)) % mod + g.k * ((c * c) % mod + c) + mod) % mod;
}
inline long long SUM(int x, int y)
{
int lca = LCA(x, y);
long long m1 = (d[y] - d[lca] * 2 + mod) % mod, m2 = d[y], ans = 0;
while (top[x] != top[y])
{
if (d[top[x]] <= d[top[y]]) ans = (ans + clac(m2, -1, dfn[top[y]], dfn[y])) % mod, y = f[top[y]];
else ans = (ans + clac(m1, 1, dfn[top[x]], dfn[x])) % mod, x = f[top[x]];
}
if (d[x] >= d[y]) ans = (ans + clac(m1, 1, dfn[y], dfn[x])) % mod;
else ans = (ans + clac(m2, -1, dfn[x], dfn[y])) % mod;
return ans * ((mod + 1) / 2) % mod;//这是2的逆元
}
signed main()
{
n = read(); m = read();
for (int i = 1; i < n; ++i)
{
int u = read(), v = read();
add(u, v); add(v, u);
}
for (int i = 1; i <= n; ++i) a[i] = read();
d[1] = 1; dfs1(1); dfs2(1, 1);
for (int i = 1; i <= n; ++i)
{
e[i] = (e[i - 1] + d[rk[i]]) % mod;
ee[i] = (ee[i - 1] + 1ll * d[rk[i]] * d[rk[i]]) % mod;
}//前缀和
build(rt[0], 1, n);
long long last = 0; int ddd = 0;
while (m--)
{
int op = read(), x, y; long long k;
if (op == 1)
{
x = read(); y = read(); k = read();
x ^= last, y ^= last;
rt[++ddd] = ++cnt;
t[rt[ddd]] = t[rt[root]];
root = ddd;
t[rt[root]].v[0] = t[rt[root]].v[1] = 1;
updates(x, y, k);
}//只有在修改时才新建版本
else if (op == 2)
{
x = read(); y = read();
x ^= last, y ^= last;
printf("%lld\n", (last = SUM(x, y)));
}
else
{
x = read();
x ^= last;
root = x;
}
}
return 0;
}