P7815 自伤无色 题解

· · 题解

首先设答案为 \frac{p}{q} 推柿子(柿子中的uvw含义与题面相同):

先处理比较简单的分母部分:

q=\sum_{w=1}^{n}\sum_{u\in son_u}\sum_{v\in son_v \cap dep_v\le dep_u}(2 \times dep_v - 2 \times dep_w - 1)

如果这样直接做是 N^3 的,完全过不了。

拆开,考虑每个u对于w的贡献其实就是w子树内,u子树外dep比u小的个数 \times (2 \times dep_w - 1) + dep比u小的dep和

考虑这两个东西怎么算

对于w的每个轻儿子u,每次用树状数组算出前面子树与u对答案的贡献,再把u的子树加入树状数组中。注意,要把dep>u的v和dep\le uv都算一遍

重儿子的话已经保留了,放在第一个,前面都没有数,所以不用统计,就顺利做到log级的复杂度了

对于分子,柿子更加复杂一些:

p =\sum_{w=1}^{n}\sum_{u\in son_u}\sum_{v\in son_v \cap dep_v\le dep_u}(2 \times dep_v - 2 \times dep_w - 1)\times (2 \times dep_u + dep_v - 3\times dep_w) =\sum_{w=1}^{n}\sum_{u\in son_u}\sum_{v\in son_v \cap dep_v\le dep_u}(2\times dep_v^2 + 6\times dep_w^2 + 4\times dep_u \times dep_v +3\times dep_w -4\times dep_w\times dep_u -8\times dep_w\times dep_v - 2\times dep_u - dep_v)

所以树状数组里面还要维护一个前缀二次方和

时间复杂度 O(N log_2 N)

注意一下dep有点大,所以要开ll,离散化,注意取模

Code by zbs2006

#include<bits/stdc++.h>
using namespace std;
const int N=1000005;
int n;
typedef long long ll;
const ll mod=1e9+7;
int cnt,h[N],to[N<<1],nxt[N<<1],sz[N],son[N];
ll val[N<<1],dep[N],D[N];
int d[N];
void add(int x,int y,ll z){
    cnt++;
    nxt[cnt]=h[x];
    h[x]=cnt;
    to[cnt]=y,val[cnt]=z;
}
int tim,In[N],Out[N],dfn[N];
void M(ll &x){
    x=(x%mod+mod)%mod;
}
void dfs(int u,int fa){
    int i;
    sz[u]=1;int mx=0;
    In[u]=++tim,dfn[tim]=u;
    for(i=h[u];i;i=nxt[i]){
        int v=to[i];ll w=val[i];
        if(v!=fa){
            dep[v]=dep[u]+w;
            dfs(v,u);
            sz[u]+=sz[v];
            if(sz[v]>mx)
                mx=sz[v],son[u]=v;
        }
    }
    Out[u]=tim;
}
#define lowbit(x) (x&(-x))
ll c[N],e[N],g[N];
void add(int x,ll y,int fl){
    while(x<=n){
        c[x]+=fl;M(c[x]);
        e[x]+=fl*y;M(e[x]);
        g[x]+=fl*y*y;M(g[x]);
        x+=lowbit(x);
    } 
}
struct Ans{
    ll fi,se,th;
};
Ans sum(int x){
    ll C=0,E=0,G=0;
    while(x){
        C=(C+c[x])%mod;
        E=(E+e[x])%mod;
        G=(G+g[x])%mod;
        x-=lowbit(x);
    }
    return (Ans){C,E,G};
}
ll qp(ll x,ll y){
    ll res=1;
    while(y){
        if(y&1)res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
ll MU,ZI; 
void deal(int w,int fa,int keep){
    int i;
    for(i=h[w];i;i=nxt[i]){
        int To=to[i];
        if(To!=fa&&To!=son[w])
            deal(To,w,0);
    }
    if(son[w])
        deal(son[w],w,1);
    for(i=h[w];i;i=nxt[i]){
        int To=to[i];
        if(To!=fa&&To!=son[w]){
            for(int j=In[To];j<=Out[To];j++){
                int v=dfn[j];Ans r=sum(d[v]);
                ll da=sum(n).fi-r.fi;
                ll dg=sum(n).se-r.se;
                //calc contributions of V
                //as U 
                MU+=(2*r.se+mod-r.fi*(2*D[w]+1)%mod)%mod,M(MU);
                //as V 
                MU+=da*(-2*D[w]%mod+mod+2*D[v]%mod-1)%mod,M(MU);
                //as U  
                ZI+=r.fi*(6*D[w]%mod*D[w]%mod+3*D[w]%mod)%mod,M(ZI);
                ZI+=r.fi*D[v]%mod*(mod-2-4*D[w])%mod,M(ZI);
                ZI+=r.th*2%mod,M(ZI);
                ZI+=r.se*(4*D[v]-8*D[w]-1)%mod,M(ZI);
                //as V
                ZI+=da*(6*D[w]*D[w]%mod+3*D[w])%mod,M(ZI);
                ZI+=dg*(mod-2-4*D[w])%mod,M(ZI);
                ZI+=da*2*D[v]%mod*D[v]%mod,M(ZI);
                ZI+=da*(mod-1-8*D[w]%mod)%mod*D[v],M(ZI); 
                ZI+=4*dg*D[v]%mod,M(ZI);
            }
            for(int j=In[To];j<=Out[To];j++){
                int v=dfn[j];
                add(d[v],D[v],1);
            }
        }
    }
    add(d[w],D[w],1);
    if(!keep){
        for(int j=In[w];j<=Out[w];j++){
            int v=dfn[j];
            add(d[v],D[v],-1);
        }
    }
}
int main(){
    int i;
    cin>>n;
    for(i=1;i<n;i++){
        int x,y;ll z;
        scanf("%d%d%lld",&x,&y,&z);
        add(x,y,z),add(y,x,z);
    }
    dfs(1,0);
    for(i=1;i<=n;i++)D[i]=dep[i];
    sort(dep+1,dep+n+1);
    int tot=unique(dep+1,dep+n+1)-dep-1;
    for(i=1;i<=n;i++)d[i]=lower_bound(dep+1,dep+tot+1,D[i])-dep,D[i]%=mod;
    deal(1,0,1);
    cout<<ZI*qp(MU,mod-2)%mod<<endl;
    return 0;
}