题解:P5642 人造情感(emotion)

· · 题解

更好的阅读体验

首先,我们考虑 W(S) 怎么求。我们以下称 W(S) 为路径集合的最大权独立集。

进行树形 dp。假设 f_u 表示只考虑 u 子树内的路径集合的最大权独立集,g_u 表示 u 子树内,钦定 u 不被经过的最大权独立集。显然 g_u = \sum \limits_{v \isin son_u} f_v,表示由于 u 不被经过,因此只能由它的子树拼凑答案。

至于 f 的转移,有一种很显然的情况就是 u 不选,由 g_u 转移来。

f_u \leftarrow g_u

如果 u 要选,我们枚举要选一条以 u 为 LCA 的路径 (x_i, y_i, w_i),路径上点的集合为 path_i。那么容易发现我们由最优状态,强制将一个点 u 设为不选,对答案产生的影响是 g_u - f_u。所以我们要先钦定 u 不选,然后将路径上除了 u 之外的所有点都设为不选,再加上这个路径的贡献。

f_u \leftarrow g_u + w_i + \sum_{i \not = u \land i \isin path_u}(g_i - f_i)

这个路径上的 g_i - f_i 之和可以用数据结构维护。具体地,对于单点修改、链查询问题,我们可以树上差分以下,挂到 dfn 序上用数状数组简单维护。

接下来考虑如何求 f(u, v)。由于 W(S \cup (u, v, x+1)) > W(S),所以我们保留 W(S \cup (u, v, x+1)) 中每条路径选或不选的状态不变,然后将 (u, v, x+1) 变成 (u, v, x),则总权值 = W(S)。所以这个事情告诉我们,f(u, v) 其实就是最小的 x 使得加入路径 (u, v, x) 之后这条路径被选。

我们设 val_i 表示必选 i 路径,以路径 i 的 LCA 为根的子树内的最大权独立集权值。

val_i = g_u + w_i + \sum_{i \not = u \land i \isin path_u}(g_i - f_i)

我们接着假设 h_u 表示以 u 为根的子树外的最大权独立集。那么

h_{LCA(u, v)} + f_{LCA(u, v)} + F(u, v) + \sum_{i \isin path(u, v)}g_i - f_i = f_1 F(u, v) = f_1 - h_{LCA} - f_{LCA} - \sum_{i \isin {path(u, v)}}g_i - f_i \sum_{i = 1}^n \sum_{j=1}^n f(i, j) = n^2 f_1 - \sum_{i = 1}^nP_i(h_i + f_i) - \sum_{i=1}^nQ_i(g_i - f_i) $$P_i = sz_{i}^2 - \sum_{j \isin son_i}sz_j^2$$ $$Q_i = 2 \cdot sz_i \cdot (n - sz_i) + P_i$$ 接下来考虑怎么求 $h_i$。考虑换根 dp,从 $u$ 转移到儿子 $v$。 第一种情况,不选经过 $u$ 的路径,只加上除了 $v$ 以外 $u$ 子树的贡献。 $$h_v \leftarrow h_u + g_u - f_v$$ 第二种情况,选择一条经过 $u$ 的路径 $i$。 $$h_v \leftarrow h_{LCA_{i}} + val_i - f_v$$ 对于这种情况,如果路径的 LCA 不为 $u$,那么我们直接在数据结构上把 $h_{LCA_i} + val_i$ 挂到路径两个端点上。否则我们把以 $u$ 为 LCA 的路径按照 $val$ 排序,然后从前往后枚举。这样一条路径最多被两次两次,复杂度正确。 我们维护一个支持单点 chkmax,区间查 max 的数据结构即可。 复杂度 $O(n \log n)$。 ```cpp #include<bits/stdc++.h> #define endl '\n' #define N 600006 #define MOD 998244353 using namespace std; template <typename T> inline void chkmax(T &x,T y){x=x<y?y:x;} template <typename T> inline void chkmin(T &x,T y){x=x<y?x:y;} template <typename T> inline void add(T &x,T y){x+=y,x-=x>=MOD?MOD:0;} template <typename T> inline void dec(T &x,T y){x+=MOD-y,x-=x>=MOD?MOD:0;} using i64=long long; struct Node {int u,v,w; i64 val;}; int n,m,dfs_clock,sz[N],son[N],dep[N],fa[N],dfn[N],tp[N]; i64 f[N],g[N],h[N]; vector<int> G[N]; vector<Node> vec[N]; struct BIT { i64 tree[N]; void update(int k,i64 x){for(;k<=n;k+=k&-k)tree[k]+=x;} void update(int l,int r,i64 x){update(l,x),update(r+1,-x);} i64 query(int k){i64 ret=0; for(;k;k-=k&-k)ret+=tree[k]; return ret;} } T; struct Segtree { i64 tree[N<<2]; void update(int p,int l,int r,int k,i64 x) { if(l==r)return chkmax(tree[p],x); int mid=l+r>>1; k<=mid?update(p<<1,l,mid,k,x):update(p<<1|1,mid+1,r,k,x); tree[p]=max(tree[p<<1],tree[p<<1|1]); } i64 query(int p,int l,int r,int L,int R) { if(L<=l&&r<=R)return tree[p]; int mid=l+r>>1; i64 ret=-4e18; if(L<=mid)chkmax(ret,query(p<<1,l,mid,L,R)); if(R>mid)chkmax(ret,query(p<<1|1,mid+1,r,L,R)); return ret; } i64 query(int l,int r,int pl,int pr) { i64 ret=-4e18; if(l<pl)chkmax(ret,query(1,1,n,l,pl-1)); if(r>pr)chkmax(ret,query(1,1,n,pr+1,r)); return ret; } } T1; void dfs1(int u,int ft) { dep[u]=dep[ft]+1,fa[u]=ft,sz[u]=1; for(int v:G[u])if(v!=ft) dfs1(v,u),sz[u]+=sz[v],son[u]=sz[v]>sz[son[u]]?v:son[u]; } void dfs2(int u,int t) { dfn[u]=++dfs_clock,tp[u]=t; if(son[u])dfs2(son[u],t); for(int v:G[u])if(v!=fa[u]&&v!=son[u])dfs2(v,v); } void dfs3(int u) { for(int v:G[u]) if(v!=fa[u])dfs3(v),g[u]+=f[v]; f[u]=g[u]; for(auto &i:vec[u]) i.val=i.w+g[u]+T.query(dfn[i.u])+T.query(dfn[i.v]),chkmax(f[u],i.val); T.update(dfn[u],dfn[u]+sz[u]-1,g[u]-f[u]); } int in(int u,int v){return dfn[u]<=dfn[v]&&dfn[v]<=dfn[u]+sz[u]-1;} void dfs4(int u) { sort(vec[u].begin(),vec[u].end(),[](Node x,Node y) { return x.val>y.val; }); for(int v:G[u])if(v!=fa[u]) { h[v]=h[u]+g[u]-f[v]; for(auto i:vec[u]) { int x=i.u,y=i.v; i64 val=i.val; if(in(v,x)||in(v,y))continue; chkmax(h[v],h[u]+val-f[v]); break; } chkmax(h[v],T1.query(dfn[u],dfn[u]+sz[u]-1,dfn[v],dfn[v]+sz[v]-1)-f[v]); } for(auto i:vec[u]) { int x=i.u,y=i.v; i64 val=i.val; T1.update(1,1,n,dfn[x],h[u]+val),T1.update(1,1,n,dfn[y],h[u]+val); } for(int v:G[u])if(v!=fa[u])dfs4(v); } int getLCA(int u,int v) { for(;tp[u]!=tp[v];u=fa[tp[u]]) if(dep[tp[u]]<dep[tp[v]])swap(u,v); return dep[u]<dep[v]?u:v; } int calc1(int u) { int ret=1ll*sz[u]*sz[u]%MOD; for(int v:G[u])if(v!=fa[u])dec(ret,(int)(1ll*sz[v]*sz[v]%MOD)); return ret; } int calc2(int u) { int ret=2ll*sz[u]*(n-sz[u])%MOD; add(ret,calc1(u)); return ret; } main() { scanf("%d%d",&n,&m); for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),G[u].push_back(v),G[v].push_back(u); dfs1(1,0),dfs2(1,1); for(int i=1,u,v,w;i<=m;i++) scanf("%d%d%d",&u,&v,&w),vec[getLCA(u,v)].push_back({u,v,w,0}); dfs3(1),dfs4(1); int ans=1ll*n*n%MOD*(f[1]%MOD)%MOD; for(int i=1;i<=n;i++)dec(ans,(int)(1ll*(h[i]+f[i])%MOD*calc1(i)%MOD)); for(int i=1;i<=n;i++)dec(ans,(int)(1ll*((g[i]-f[i])%MOD+MOD)%MOD*calc2(i)%MOD)); printf("%d\n",ans%MOD); return 0; } ```