P5642

· · 题解

首先将所有路径挂在 LCA 上。

先考虑求 W(U),设 f_u 表示 u 子树内的最大独立集,sum_u=\sum\limits_{v \in son(u)} {f_v},自下而上 dp:

f_u=sum_u f_u=\max\limits_{(x_i,y_i,w_i)}\{{sum_u+w_i-\sum\limits_{z \in path(x_i,y_i),z \neq u} (f_z-sum_z)}\}

树状数组辅助转移,单点加链求和转换为子树加单点求值,即可做到单 log。

再设 g_u 表示 u 子树外的最大独立集。

再刚刚求 f 时,我们在遇到边 (x_i,y_i,w_i) 时记下 t_i={sum_u+w_i-\sum\limits_{z \in path(x_i,y_i),z \neq u} (f_z-sum_z)} 的值,即选择这条边情况下子树内的最优答案,以便我们接下来的转移。自上而下 dp:

g_v=g_u+sum_u-f_v g_v=\max\limits_{(x_i,y_i,w_i)}\{{g_h+t_i-f_v}\}

对于第二种情况,h=u 可以直接按照 t_i 排序后暴力扫,否则一个端点在祖先的其他子树中,另一端点在 u 的非 v 子树中,可以在挂在线段树上查询,可以做到单 log。

得出了 f_ug_u 就可以统计答案了:对于任意 (i,j),所求的 F(i,j) 都是由若干 f_u 和至多一个 g_u 贡献而来,一个 f_u 有贡献当且仅当 u 不在路径上且 fa_u 在路径上,一个 g_u 有贡献当且仅当 fa_u 不在路径上且 u 在路径上,统计答案是容易的。

n,m 同阶,总时间复杂度 O(n \log n)

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N=300005;
const int mod=998244353;
int n,idx,m,ans,fa[N],son[N],top[N],dep[N],siz[N],dfn[N]; vector<int>e[N]; ll h[N],f[N],sum[N],g[N];
struct Q{ int u,v; ll w; bool operator < (const Q &a) const { return w>a.w; } }; vector<Q>o[N];
namespace BIT{
    ll c[N];
    #define lowbit(x) (x&(-x))
    inline void add(int x,ll y){ while(x<=n) c[x]+=y,x+=lowbit(x); }
    inline ll query(int x){ ll r=0; while(x) r+=c[x],x-=lowbit(x); return r; }
    inline void update(int l,int r,ll x){ add(l,x); add(r+1,-x); }
}
namespace SGT{
    ll mx[N<<2];
    #define ls(x) (x<<1)
    #define rs(x) (x<<1|1)
    inline void pushup(int k){ mx[k]=max(mx[ls(k)],mx[rs(k)]); }
    inline void cmx(int k,int l,int r,int x,ll y){
        if(l==r){ mx[k]=max(mx[k],y); return; }
        int mid=l+r>>1;
        if(x<=mid) cmx(ls(k),l,mid,x,y);
        else cmx(rs(k),mid+1,r,x,y);
        pushup(k);
    }
    inline ll qmx(int k,int l,int r,int ql,int qr){
        if(l>=ql&&r<=qr) return mx[k];
        int mid=l+r>>1;
        if(qr<=mid) return qmx(ls(k),l,mid,ql,qr);
        if(ql>mid) return qmx(rs(k),mid+1,r,ql,qr);
        return max(qmx(ls(k),l,mid,ql,qr),qmx(rs(k),mid+1,r,ql,qr));
    }
    inline void update(int x,ll y){ cmx(1,1,n,x,y); }
    inline ll query(int L,int R,int l,int r){
        ll ret=0;
        if(L<l) ret=max(ret,qmx(1,1,n,L,l-1));
        if(R>r) ret=max(ret,qmx(1,1,n,r+1,R));
        return ret;
    }
}
inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }   return dep[u]<dep[v]?u:v;
}
inline void dfs1(int u){
    siz[u]=1;
    for(int v:e[u]){
        if(v==fa[u]) continue;
        fa[v]=u; dep[v]=dep[u]+1;
        dfs1(v); siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
        h[u]+=1ll*siz[v]*siz[v];
    }
    h[u]+=1ll*(n-siz[u])*(n-siz[u]);
}
inline void dfs2(int u,int tp){
    top[u]=tp; dfn[u]=++idx;
    if(!son[u]) return;
    dfs2(son[u],tp);
    for(int v:e[u]) if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
inline void dfs3(int u){
    for(int v:e[u]){
        if(v==fa[u]) continue;
        dfs3(v); sum[u]+=f[v];
    }
    f[u]=sum[u];
    for(int i=0;i<o[u].size();i++){
        o[u][i].w+=sum[u]+BIT::query(dfn[o[u][i].u])+BIT::query(dfn[o[u][i].v]);
        f[u]=max(f[u],o[u][i].w);
    }
    BIT::update(dfn[u],dfn[u]+siz[u]-1,sum[u]-f[u]);
}
inline void dfs4(int u){
    sort(o[u].begin(),o[u].end());
    for(int v:e[u]){
        if(v==fa[u]) continue;
        g[v]=g[u]+sum[u]-f[v];
        for(Q i:o[u])
            if(lca(i.u,v)!=v&&lca(i.v,v)!=v){
                g[v]=max(g[v],g[u]+i.w-f[v]);
                break;
            }
        g[v]=max(g[v],SGT::query(dfn[u],dfn[u]+siz[u]-1,dfn[v],dfn[v]+siz[v]-1)-f[v]);
    }
    for(Q i:o[u])
        SGT::update(dfn[i.u],g[u]+i.w),SGT::update(dfn[i.v],g[u]+i.w);
    for(int v:e[u]) if(v!=fa[u]) dfs4(v);
}
int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin>>n>>m;
    for(int i=1,u,v;i<n;i++){
        cin>>u>>v;
        e[u].emplace_back(v);
        e[v].emplace_back(u);
    }
    dfs1(1); dfs2(1,1);
    for(int i=1,u,v,w,t;i<=m;i++){
        cin>>u>>v>>w; t=lca(u,v);
        o[t].emplace_back(Q{u,v,w});
    }
    dfs3(1); dfs4(1);
    ans=1ll*n*n%mod*(f[1]%mod)%mod;
    for(int u=1;u<=n;u++){
        if(u!=1) ans=(ans-(1ll*n*n-2ll*siz[u]*(n-siz[u])-h[fa[u]])%mod*(f[u]%mod)%mod+mod)%mod;
        ans=(ans-(1ll*n*n-2ll*siz[u]*(n-siz[u])-h[u])%mod*(g[u]%mod)%mod+mod)%mod;
    }
    cout<<ans<<'\n';
    return 0;
}