题解:P5642 人造情感(emotion)
dyc2022
·
2026-01-04 18:08:50
·
题解
更好的阅读体验
首先,我们考虑 W(S) 怎么求。我们以下称 W(S) 为路径集合的最大权独立集。
进行树形 dp。假设 f_u 表示只考虑 u 子树内的路径集合的最大权独立集,g_u 表示 u 子树内,钦定 u 不被经过的最大权独立集。显然 g_u = \sum \limits_{v \isin son_u} f_v ,表示由于 u 不被经过,因此只能由它的子树拼凑答案。
至于 f 的转移,有一种很显然的情况就是 u 不选,由 g_u 转移来。
f_u \leftarrow g_u
如果 u 要选,我们枚举要选一条以 u 为 LCA 的路径 (x_i, y_i, w_i) ,路径上点的集合为 path_i 。那么容易发现我们由最优状态,强制将一个点 u 设为不选,对答案产生的影响是 g_u - f_u 。所以我们要先钦定 u 不选,然后将路径上除了 u 之外的所有点都设为不选,再加上这个路径的贡献。
f_u \leftarrow g_u + w_i + \sum_{i \not = u \land i \isin path_u}(g_i - f_i)
这个路径上的 g_i - f_i 之和可以用数据结构维护。具体地,对于单点修改、链查询问题,我们可以树上差分以下,挂到 dfn 序上用数状数组简单维护。
接下来考虑如何求 f(u, v) 。由于 W(S \cup (u, v, x+1)) > W(S) ,所以我们保留 W(S \cup (u, v, x+1)) 中每条路径选或不选的状态不变,然后将 (u, v, x+1) 变成 (u, v, x) ,则总权值 = W(S) 。所以这个事情告诉我们,f(u, v) 其实就是最小的 x 使得加入路径 (u, v, x) 之后这条路径被选。
我们设 val_i 表示必选 i 路径,以路径 i 的 LCA 为根的子树内的最大权独立集权值。
val_i = g_u + w_i + \sum_{i \not = u \land i \isin path_u}(g_i - f_i)
我们接着假设 h_u 表示以 u 为根的子树外的最大权独立集。那么
h_{LCA(u, v)} + f_{LCA(u, v)} + F(u, v) + \sum_{i \isin path(u, v)}g_i - f_i = f_1
F(u, v) = f_1 - h_{LCA} - f_{LCA} - \sum_{i \isin {path(u, v)}}g_i - f_i
\sum_{i = 1}^n \sum_{j=1}^n f(i, j) = n^2 f_1 - \sum_{i = 1}^nP_i(h_i + f_i) - \sum_{i=1}^nQ_i(g_i - f_i)
$$P_i = sz_{i}^2 - \sum_{j \isin son_i}sz_j^2$$
$$Q_i = 2 \cdot sz_i \cdot (n - sz_i) + P_i$$
接下来考虑怎么求 $h_i$。考虑换根 dp,从 $u$ 转移到儿子 $v$。
第一种情况,不选经过 $u$ 的路径,只加上除了 $v$ 以外 $u$ 子树的贡献。
$$h_v \leftarrow h_u + g_u - f_v$$
第二种情况,选择一条经过 $u$ 的路径 $i$。
$$h_v \leftarrow h_{LCA_{i}} + val_i - f_v$$
对于这种情况,如果路径的 LCA 不为 $u$,那么我们直接在数据结构上把 $h_{LCA_i} + val_i$ 挂到路径两个端点上。否则我们把以 $u$ 为 LCA 的路径按照 $val$ 排序,然后从前往后枚举。这样一条路径最多被两次两次,复杂度正确。
我们维护一个支持单点 chkmax,区间查 max 的数据结构即可。
复杂度 $O(n \log n)$。
```cpp
#include<bits/stdc++.h>
#define endl '\n'
#define N 600006
#define MOD 998244353
using namespace std;
template <typename T> inline void chkmax(T &x,T y){x=x<y?y:x;}
template <typename T> inline void chkmin(T &x,T y){x=x<y?x:y;}
template <typename T> inline void add(T &x,T y){x+=y,x-=x>=MOD?MOD:0;}
template <typename T> inline void dec(T &x,T y){x+=MOD-y,x-=x>=MOD?MOD:0;}
using i64=long long;
struct Node {int u,v,w; i64 val;};
int n,m,dfs_clock,sz[N],son[N],dep[N],fa[N],dfn[N],tp[N];
i64 f[N],g[N],h[N];
vector<int> G[N];
vector<Node> vec[N];
struct BIT {
i64 tree[N];
void update(int k,i64 x){for(;k<=n;k+=k&-k)tree[k]+=x;}
void update(int l,int r,i64 x){update(l,x),update(r+1,-x);}
i64 query(int k){i64 ret=0; for(;k;k-=k&-k)ret+=tree[k]; return ret;}
} T;
struct Segtree {
i64 tree[N<<2];
void update(int p,int l,int r,int k,i64 x)
{
if(l==r)return chkmax(tree[p],x);
int mid=l+r>>1; k<=mid?update(p<<1,l,mid,k,x):update(p<<1|1,mid+1,r,k,x);
tree[p]=max(tree[p<<1],tree[p<<1|1]);
}
i64 query(int p,int l,int r,int L,int R)
{
if(L<=l&&r<=R)return tree[p];
int mid=l+r>>1; i64 ret=-4e18;
if(L<=mid)chkmax(ret,query(p<<1,l,mid,L,R));
if(R>mid)chkmax(ret,query(p<<1|1,mid+1,r,L,R));
return ret;
}
i64 query(int l,int r,int pl,int pr)
{
i64 ret=-4e18;
if(l<pl)chkmax(ret,query(1,1,n,l,pl-1));
if(r>pr)chkmax(ret,query(1,1,n,pr+1,r));
return ret;
}
} T1;
void dfs1(int u,int ft)
{
dep[u]=dep[ft]+1,fa[u]=ft,sz[u]=1;
for(int v:G[u])if(v!=ft)
dfs1(v,u),sz[u]+=sz[v],son[u]=sz[v]>sz[son[u]]?v:son[u];
}
void dfs2(int u,int t)
{
dfn[u]=++dfs_clock,tp[u]=t;
if(son[u])dfs2(son[u],t);
for(int v:G[u])if(v!=fa[u]&&v!=son[u])dfs2(v,v);
}
void dfs3(int u)
{
for(int v:G[u])
if(v!=fa[u])dfs3(v),g[u]+=f[v];
f[u]=g[u];
for(auto &i:vec[u])
i.val=i.w+g[u]+T.query(dfn[i.u])+T.query(dfn[i.v]),chkmax(f[u],i.val);
T.update(dfn[u],dfn[u]+sz[u]-1,g[u]-f[u]);
}
int in(int u,int v){return dfn[u]<=dfn[v]&&dfn[v]<=dfn[u]+sz[u]-1;}
void dfs4(int u)
{
sort(vec[u].begin(),vec[u].end(),[](Node x,Node y) {
return x.val>y.val;
}); for(int v:G[u])if(v!=fa[u])
{
h[v]=h[u]+g[u]-f[v];
for(auto i:vec[u])
{
int x=i.u,y=i.v; i64 val=i.val; if(in(v,x)||in(v,y))continue;
chkmax(h[v],h[u]+val-f[v]); break;
}
chkmax(h[v],T1.query(dfn[u],dfn[u]+sz[u]-1,dfn[v],dfn[v]+sz[v]-1)-f[v]);
}
for(auto i:vec[u])
{
int x=i.u,y=i.v; i64 val=i.val;
T1.update(1,1,n,dfn[x],h[u]+val),T1.update(1,1,n,dfn[y],h[u]+val);
}
for(int v:G[u])if(v!=fa[u])dfs4(v);
}
int getLCA(int u,int v)
{
for(;tp[u]!=tp[v];u=fa[tp[u]])
if(dep[tp[u]]<dep[tp[v]])swap(u,v);
return dep[u]<dep[v]?u:v;
}
int calc1(int u)
{
int ret=1ll*sz[u]*sz[u]%MOD;
for(int v:G[u])if(v!=fa[u])dec(ret,(int)(1ll*sz[v]*sz[v]%MOD));
return ret;
}
int calc2(int u)
{
int ret=2ll*sz[u]*(n-sz[u])%MOD; add(ret,calc1(u));
return ret;
}
main()
{
scanf("%d%d",&n,&m);
for(int i=1,u,v;i<n;i++)
scanf("%d%d",&u,&v),G[u].push_back(v),G[v].push_back(u);
dfs1(1,0),dfs2(1,1);
for(int i=1,u,v,w;i<=m;i++)
scanf("%d%d%d",&u,&v,&w),vec[getLCA(u,v)].push_back({u,v,w,0});
dfs3(1),dfs4(1);
int ans=1ll*n*n%MOD*(f[1]%MOD)%MOD;
for(int i=1;i<=n;i++)dec(ans,(int)(1ll*(h[i]+f[i])%MOD*calc1(i)%MOD));
for(int i=1;i<=n;i++)dec(ans,(int)(1ll*((g[i]-f[i])%MOD+MOD)%MOD*calc2(i)%MOD));
printf("%d\n",ans%MOD);
return 0;
}
```