题解 P5642 【 人造情感(emotion)】
膜拜 EI 鸽鸽!/崇拜
给一个
考虑对于给定的集合
对于路径
对于
接着,我们考虑如何求出
对这两种转移的解释:
-
对于第一种转移,考虑什么时候
c_u <0 ,显然这意味着有某条穿过u 的路径造成了贡献。假设这条路径往下走到了某个儿子v ,那么对于v 来说,由于断掉了(u,v) 这条边,于是只好让u 空出(这需要付出c_u 的代价)。但实际上我们并不知道有哪些儿子需要付出这个代价,我们不妨对所有儿子都进行这个转移,对于错误转移到的那些点一定不会优于第二种转移,所以对答案并没有影响。 -
第二种转移显然满足路径不交,所以一定是合法的。
第一种转移直接做即可;但对于第二种转移,如果我们暴力枚举所有路径进行判断,时间复杂度就达到了惊人的
先考虑一个暴力的维护,我们对每一个点开一个 vector 维护在它身上的 tag,每个 tag 都是一个形如
但可惜的是,即使避免了多余的枚举,标记数量仍然可能很多,时间复杂度并没有实质上的优化。还能再给力点吗?
对给定的树进行轻重链剖分,这时我们会惊喜地发现,大多数 tag 都形如
求出
注意操作顺序,先打 tag 后转移。代码写得有点丑,调试懒得删了就这样将就看看?
const int MN = 4e5 + 5;
const int Mod = 998244353;
// #define dbg
int N, M, ans;
vector <int> G[MN];
struct BIT {
int c[MN];
inline int lowbit(int x) {
return x & (-x);
}
inline void Modify(int x, int k) {
for (int i = x; i <= N; i += lowbit(i))
c[i] += k;
}
inline void Modify(int l, int r, int k) {
Modify(l, k), Modify(r + 1, -k);
}
inline int Query(int x) {
if (!x)
return 0;
int res = 0;
for (int i = x; i; i -= lowbit(i))
res += c[i];
return res;
}
} T;
int siz[MN], son[MN], dep[MN], fa[MN], fat[MN][25];
inline void DFS0(int u, int pr) {
siz[u] = 1, fat[u][0] = pr;
dep[u] = dep[fa[u] = pr] + 1;
for (int i = 1; i <= 20; i++)
fat[u][i] = fat[fat[u][i - 1]][i - 1];
for (int v : G[u])
if (v != pr) {
DFS0(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
}
int top[MN], dfn[MN], dfc, st[MN], ed[MN];
inline void DFS1(int u, int tp) {
top[u] = tp;
dfn[u] = st[u] = ++dfc;
if (son[u]) {
DFS1(son[u], tp);
for (int v : G[u])
if (v != fa[u] && v != son[u])
DFS1(v, v);
}
ed[u] = dfc;
}
inline int LCA(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
return x;
}
inline int climb(int x, int y) {
for (int i = 20; ~i; i--)
if (dep[fat[x][i]] > dep[y]) x = fat[x][i];
return x;
}
#define PI3 pair < pii, int >
#define mp3(x, y, z) mp(mp(x, y), z)
int f[MN], g[MN], c[MN];
vector <PI3> vr[MN];
inline int Qry(int x, int y, int r) {
return T.Query(st[x]) + T.Query(st[y]) - T.Query(st[fa[r]]) - T.Query(st[r]);
}
inline void DFS2(int u) {
for (int v : G[u])
if (v != fa[u])
DFS2(v), g[u] += f[v];
int del = 0;
for (auto [p, w] : vr[u]) {
del = max(del, Qry(p.fi, p.se, u) + w);
#ifdef dbg
printf("Qry(%lld %lld) = %lld\n", p.fi, p.se, Qry(p.fi, p.se, u));
#endif
}
f[u] = g[u] + del;
c[u] = g[u] - f[u];
T.Modify(st[u], ed[u], c[u]);
}
const int MS = MN << 2;
#define ls o << 1
#define rs o << 1 | 1
#define mid ((l + r) >> 1)
#define LS ls, l, mid
#define RS rs, mid + 1, r
int tr[MS], tg[MS];
inline void Mdf(int o, int l, int r, int L, int R, int v) {
if (r < L || l > R || L > R) return;
if (L <= l && R >= r) return tg[o] = max(tg[o], v), void();
Mdf(LS, L, R, v), Mdf(RS, L, R, v);
}
inline int Qry(int o, int l, int r, int p) {
if (l == r) return tg[o];
return max(tg[o], p <= mid ? Qry(LS, p) : Qry(RS, p));
}
struct Dat {
int w, a, b;
Dat() {}
Dat(int W, int X, int Y) : w(W), a(X), b(Y) {}
inline bool operator < (const Dat &p) const {
return w > p.w;
}
};
int h[MN];
vector <Dat> tag[MN];
inline void Cov(int x, int y, int w, int r) {
int u = x, v = y;
if (dep[x] < dep[y])
swap(x, y);
tag[x].pb(Dat(w, 0, 0));
if (y != r)
tag[y].pb(Dat(w, 0, 0));
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]])
swap(x, y);
Mdf(1, 1, N, dfn[top[x]], dfn[fa[x]], w);
if (fa[top[x]] != r)
tag[fa[top[x]]].pb(Dat(w, top[x], 0));
x = fa[top[x]];
}
u = climb(u, r);
v = climb(v, r);
tag[r].pb(Dat(w, u, v));
if (x == y)
return;
if (dep[x] < dep[y])
swap(x, y);
int p = climb(x, y);
Mdf(1, 1, N, dfn[p], dfn[fa[x]], w);
}
inline void DFS3(int u) {
for (auto [p, w] : vr[u]) {
int x = p.fi, y = p.se;
Cov(x, y, h[u] + Qry(x, y, u) + w, u);
#ifdef dbg
printf("%lld : pushtag (%lld, %lld, %lld)\n", u, x, y, h[u] + Qry(x, y, u) + w);
#endif
}
sort(tag[u].begin(), tag[u].end());
int w = Qry(1, 1, N, dfn[u]);
for (int v : G[u])
if (v != fa[u]) {
h[v] = max(h[u] + c[u], 0);
if (v != son[u])
h[v] = max(h[v], w);
for (auto tg : tag[u])
if (tg.a != v && tg.b != v) {
h[v] = max(h[v], tg.w);
break;
}
DFS3(v);
}
}
signed main(void) {
N = read(), M = read();
for (int i = 1; i < N; i++) {
int u = read(), v = read();
G[u].pb(v), G[v].pb(u);
}
DFS0(1, 0);
DFS1(1, 1);
for (int i = 1; i <= M; i++) {
int x = read(), y = read(), w = read();
int r = LCA(x, y);
#ifdef dbg
printf("LCA(%lld, %lld) = %lld\n", x, y, r);
#endif
vr[r].pb(mp3(x, y, w));
}
DFS2(1);
h[1] = f[1];
#ifdef dbg
puts("----------------");
#endif
DFS3(1);
#ifdef dbg
puts("----------------");
printf("maxW(U) = %lld\n", f[1]);
for (int i = 1; i <= N; i++)
printf("%lld : f = %lld, g = %lld, c = %lld, h = %lld\n", i, f[i], g[i], c[i], h[i]);
puts("----------------");
#endif
ans = f[1] % Mod * N % Mod * N % Mod;
for (int i = 1; i <= N; i++)
c[i] = Mod + c[i] % Mod, h[i] %= Mod;
for (int u = 1; u <= N; u++) {
int Coef = N * N % Mod;
Dec(Coef, (N - siz[u]) * (N - siz[u]) % Mod);
int w = 0;
for (int v : G[u])
if (v != fa[u]) Add(w, siz[v] * siz[v] % Mod);
Dec(Coef, w);
Dec(ans, Coef * c[u] % Mod);
Coef = siz[u] * siz[u] % Mod;
Dec(Coef, w);
Dec(ans, Coef * h[u] % Mod);
}
printf("%lld\n", ans);
return 0;
}