P5642
首先将所有路径挂在 LCA 上。
先考虑求
树状数组辅助转移,单点加链求和转换为子树加单点求值,即可做到单 log。
再设
再刚刚求
-
u$ 被独立集中路径经过,且路径 LCA 为 $h
对于第二种情况,
得出了
视
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N=300005;
const int mod=998244353;
int n,idx,m,ans,fa[N],son[N],top[N],dep[N],siz[N],dfn[N]; vector<int>e[N]; ll h[N],f[N],sum[N],g[N];
struct Q{ int u,v; ll w; bool operator < (const Q &a) const { return w>a.w; } }; vector<Q>o[N];
namespace BIT{
ll c[N];
#define lowbit(x) (x&(-x))
inline void add(int x,ll y){ while(x<=n) c[x]+=y,x+=lowbit(x); }
inline ll query(int x){ ll r=0; while(x) r+=c[x],x-=lowbit(x); return r; }
inline void update(int l,int r,ll x){ add(l,x); add(r+1,-x); }
}
namespace SGT{
ll mx[N<<2];
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
inline void pushup(int k){ mx[k]=max(mx[ls(k)],mx[rs(k)]); }
inline void cmx(int k,int l,int r,int x,ll y){
if(l==r){ mx[k]=max(mx[k],y); return; }
int mid=l+r>>1;
if(x<=mid) cmx(ls(k),l,mid,x,y);
else cmx(rs(k),mid+1,r,x,y);
pushup(k);
}
inline ll qmx(int k,int l,int r,int ql,int qr){
if(l>=ql&&r<=qr) return mx[k];
int mid=l+r>>1;
if(qr<=mid) return qmx(ls(k),l,mid,ql,qr);
if(ql>mid) return qmx(rs(k),mid+1,r,ql,qr);
return max(qmx(ls(k),l,mid,ql,qr),qmx(rs(k),mid+1,r,ql,qr));
}
inline void update(int x,ll y){ cmx(1,1,n,x,y); }
inline ll query(int L,int R,int l,int r){
ll ret=0;
if(L<l) ret=max(ret,qmx(1,1,n,L,l-1));
if(R>r) ret=max(ret,qmx(1,1,n,r+1,R));
return ret;
}
}
inline int lca(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
} return dep[u]<dep[v]?u:v;
}
inline void dfs1(int u){
siz[u]=1;
for(int v:e[u]){
if(v==fa[u]) continue;
fa[v]=u; dep[v]=dep[u]+1;
dfs1(v); siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
h[u]+=1ll*siz[v]*siz[v];
}
h[u]+=1ll*(n-siz[u])*(n-siz[u]);
}
inline void dfs2(int u,int tp){
top[u]=tp; dfn[u]=++idx;
if(!son[u]) return;
dfs2(son[u],tp);
for(int v:e[u]) if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
inline void dfs3(int u){
for(int v:e[u]){
if(v==fa[u]) continue;
dfs3(v); sum[u]+=f[v];
}
f[u]=sum[u];
for(int i=0;i<o[u].size();i++){
o[u][i].w+=sum[u]+BIT::query(dfn[o[u][i].u])+BIT::query(dfn[o[u][i].v]);
f[u]=max(f[u],o[u][i].w);
}
BIT::update(dfn[u],dfn[u]+siz[u]-1,sum[u]-f[u]);
}
inline void dfs4(int u){
sort(o[u].begin(),o[u].end());
for(int v:e[u]){
if(v==fa[u]) continue;
g[v]=g[u]+sum[u]-f[v];
for(Q i:o[u])
if(lca(i.u,v)!=v&&lca(i.v,v)!=v){
g[v]=max(g[v],g[u]+i.w-f[v]);
break;
}
g[v]=max(g[v],SGT::query(dfn[u],dfn[u]+siz[u]-1,dfn[v],dfn[v]+siz[v]-1)-f[v]);
}
for(Q i:o[u])
SGT::update(dfn[i.u],g[u]+i.w),SGT::update(dfn[i.v],g[u]+i.w);
for(int v:e[u]) if(v!=fa[u]) dfs4(v);
}
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=1,u,v;i<n;i++){
cin>>u>>v;
e[u].emplace_back(v);
e[v].emplace_back(u);
}
dfs1(1); dfs2(1,1);
for(int i=1,u,v,w,t;i<=m;i++){
cin>>u>>v>>w; t=lca(u,v);
o[t].emplace_back(Q{u,v,w});
}
dfs3(1); dfs4(1);
ans=1ll*n*n%mod*(f[1]%mod)%mod;
for(int u=1;u<=n;u++){
if(u!=1) ans=(ans-(1ll*n*n-2ll*siz[u]*(n-siz[u])-h[fa[u]])%mod*(f[u]%mod)%mod+mod)%mod;
ans=(ans-(1ll*n*n-2ll*siz[u]*(n-siz[u])-h[u])%mod*(g[u]%mod)%mod+mod)%mod;
}
cout<<ans<<'\n';
return 0;
}